Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
e23ddd84b8 x/grammar: add experimental GPU accelerated constrained decoding package 2026-01-11 00:50:11 -08:00
62 changed files with 6075 additions and 1179 deletions

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 8 * format.MegaByte
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -100,8 +100,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
return imagegenclient.CreateModel(args[0], ".", p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -465,6 +464,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -526,14 +533,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
@@ -672,11 +671,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -842,6 +837,11 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err

18
go.mod
View File

@@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.39.0
)
require (
@@ -30,8 +30,8 @@ require (
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
golang.org/x/tools v0.38.0
golang.org/x/mod v0.31.0
golang.org/x/tools v0.40.0
gonum.org/v1/gonum v0.15.0
)
@@ -81,11 +81,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/term v0.36.0
golang.org/x/text v0.30.0
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
golang.org/x/net v0.48.0 // indirect
golang.org/x/term v0.38.0
golang.org/x/text v0.32.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
go.sum
View File

@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -1124,15 +1124,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1215,10 +1206,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err

185
x/grammar/README.md Normal file
View File

@@ -0,0 +1,185 @@
# grammar
Grammar-constrained decoding for LLM outputs using MLX.
## Performance
Performance depends on hardware, vocabulary size, grammar, and whether you
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on your
setup.
### Design choices that keep masking fast
| Technique | Impact |
|-----------|--------|
| Precomputed token analysis | Terminal matches computed once at startup |
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
| Partitioned tokens | Exact matches separated from DP candidates |
### Comparison Notes
- **llama.cpp**: Decodes each token to UTF-8, checks against PDA. No caching.
- **Outlines**: FSM-based. Compilation can take 40s-10min for complex schemas. Fast after compile.
- **XGrammar**: PDA with 99% context-independent tokens precomputed. State-of-the-art before this.
- **x/grammar**: Precomputed token analysis + mask caching by grammar state signature.
## Usage
```go
import (
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
)
// Use built-in JSON grammar
g, _ := grammar.JSONGrammar()
// Or from JSON Schema (OpenAI-compatible)
g, _ := schema.Grammar(`{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`)
// Or parse custom EBNF
g, _ := grammar.ParseEBNF(myGrammar, "root")
// Create engine with model vocabulary
engine, _ := grammar.NewEngine(g, vocab)
defer engine.Close()
// Generation loop
for !engine.IsComplete() {
logits := model.Forward(tokens)
masked := engine.ApplyMask(logits) // Invalid tokens → -inf
nextToken := sample(masked)
engine.Accept(nextToken)
}
// Output conforms to the grammar when you only sample from masked tokens and call Accept
```
## EBNF Syntax
```ebnf
rule = expression . # Rule definition (ends with .)
"literal" # Literal string
"a" "z" # Character range (inclusive)
( a | b ) # Grouping with alternation
[ optional ] # Optional (0 or 1)
{ repeated } # Repetition (0 or more)
```
### Example: JSON Grammar
```ebnf
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .
```
### Example: Custom Schema
```ebnf
root = "{" ws name_field "," ws age_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
string = "\"" { char } "\"" .
char = " " | "!" | "#" "~" .
number = [ "-" ] digit { digit } .
digit = "0" "9" .
ws = { " " | "\n" } .
```
## JSON Schema Support
OpenAI-compatible JSON Schema support with automatic EBNF generation:
```go
schema := `{
"type": "object",
"properties": {
"user": {"$ref": "#/$defs/User"}
},
"required": ["user"],
"$defs": {
"User": {
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string", "format": "email"},
"role": {"enum": ["admin", "user", "guest"]}
},
"required": ["name", "email", "role"]
}
}
}`
grammar, _ := schema.Grammar(schema)
```
### Supported Features
| Feature | Example |
|---------|---------|
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
| Objects | `properties`, `required` |
| Arrays | `items`, `minItems`, `maxItems` |
| Enums | `enum: ["a", "b", "c"]` |
| Constants | `const: "value"` |
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
| References | `$ref: "#/$defs/Name"`, `$defs` |
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
## Benchmarks
```bash
# Run all tests
go test -tags mlx ./x/grammar/...
# Run benchmarks
go test -tags mlx ./x/grammar/ -bench=.
# Compare with llama.cpp (outputs JSON)
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500
# Compare with a more complex schema
go run -tags mlx ./x/grammar/cmd/compare \
-gbnf x/grammar/cmd/compare/complex.gbnf \
-schema x/grammar/cmd/compare/complex.schema.json \
-vocab-size 128000 -iterations 500
```
## References
- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs

161
x/grammar/analyzer.go Normal file
View File

@@ -0,0 +1,161 @@
//go:build mlx
package grammar
// terminalTokenGroups contains pre-partitioned tokens for a terminal.
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
type terminalTokenGroups struct {
// ExactMatches are tokens that exactly match this terminal (O(1) validation)
ExactMatches []int32
// DPCandidates are tokens that start with this terminal but need DP validation
DPCandidates []int
}
// tokenAnalysis contains precomputed terminal matches for a token
type tokenAnalysis struct {
// The token string
Token string
// TokenID in the vocabulary
TokenID int
// Matches at each byte position
// MatchesAtPos[i] = terminals matching at position i with their lengths
MatchesAtPos [][]terminalMatch
// Fast path: if token exactly matches one terminal
// -1 if no exact match
exactMatch int
// Whether this token can be consumed at all (has at least one match)
HasMatches bool
}
// analyzer precomputes terminal matches for a vocabulary
type analyzer struct {
matcher *terminalMatcher
analyses []tokenAnalysis // Indexed by token ID
vocab []string
// Pre-partitioned tokens by terminal (exact match vs DP candidates)
// This enables direct slice appends instead of per-token branching
tokensByTerminal []terminalTokenGroups
}
// newAnalyzer creates an analyzer for the given vocabulary and terminals
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
a := &analyzer{
matcher: matcher,
analyses: make([]tokenAnalysis, len(vocab)),
vocab: vocab,
}
// Precompute analysis for each token
for i, token := range vocab {
a.analyses[i] = a.analyze(token, i)
}
// Build pre-partitioned token groups for fast ApplyMask
a.buildTokenPartitions()
return a
}
// analyze computes terminal matches for a single token
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
analysis := tokenAnalysis{
Token: token,
TokenID: tokenID,
MatchesAtPos: make([][]terminalMatch, len(token)),
exactMatch: -1,
HasMatches: false,
}
if len(token) == 0 {
return analysis
}
// Compute matches at each position
data := []byte(token)
for pos := 0; pos < len(data); pos++ {
matches := a.matcher.matchesAt(data, pos)
analysis.MatchesAtPos[pos] = matches
if len(matches) > 0 {
analysis.HasMatches = true
}
}
// Exact match is only valid when a single terminal spans the entire token
if len(analysis.MatchesAtPos) > 0 {
var exactID int = -1
for _, match := range analysis.MatchesAtPos[0] {
if match.Length != len(token) {
continue
}
if exactID >= 0 && exactID != match.TerminalID {
exactID = -1
break
}
exactID = match.TerminalID
}
analysis.exactMatch = exactID
}
return analysis
}
// analysis returns the precomputed analysis for a token ID
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
if tokenID < 0 || tokenID >= len(a.analyses) {
return tokenAnalysis{exactMatch: -1}
}
return a.analyses[tokenID]
}
// vocabSize returns the vocabulary size
func (a *analyzer) vocabSize() int {
return len(a.vocab)
}
// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
// This enables ApplyMask to use direct slice appends instead of per-token branching.
func (a *analyzer) buildTokenPartitions() {
numTerminals := a.matcher.terminalCount()
a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)
for tokenID, analysis := range a.analyses {
if !analysis.HasMatches {
continue
}
if analysis.exactMatch >= 0 {
// Token exactly matches one terminal - fast path (O(1) validation)
tid := analysis.exactMatch
a.tokensByTerminal[tid].ExactMatches = append(
a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
} else {
// Token needs DP validation - add to all terminals it can start with
// This way, when a terminal is valid, we know exactly which tokens need DP
if len(analysis.MatchesAtPos) > 0 {
seen := make(map[int]bool)
for _, match := range analysis.MatchesAtPos[0] {
tid := match.TerminalID
if !seen[tid] {
seen[tid] = true
a.tokensByTerminal[tid].DPCandidates = append(
a.tokensByTerminal[tid].DPCandidates, tokenID)
}
}
}
}
}
}
// terminalGroups returns the pre-partitioned token groups for a terminal ID
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
return terminalTokenGroups{}
}
return a.tokensByTerminal[terminalID]
}

648
x/grammar/bridge.go Normal file
View File

@@ -0,0 +1,648 @@
//go:build mlx
package grammar
import (
"encoding/binary"
"hash/fnv"
"sort"
"sync"
)
// visitedMapPool reduces allocations for visited maps in bridge operations
var visitedMapPool = sync.Pool{
New: func() interface{} {
return make(map[stateStackKey]bool, 16)
},
}
// getVisitedMap gets a map from the pool
func getVisitedMap() map[stateStackKey]bool {
return visitedMapPool.Get().(map[stateStackKey]bool)
}
// putVisitedMap returns a map to the pool after clearing it
func putVisitedMap(m map[stateStackKey]bool) {
for k := range m {
delete(m, k)
}
visitedMapPool.Put(m)
}
// parserConfig represents a pda state+stack combination
type parserConfig struct {
state state
Stack []stackSymbol
}
// clone creates a deep copy of the config
func (c *parserConfig) clone() *parserConfig {
newStack := make([]stackSymbol, len(c.Stack))
copy(newStack, c.Stack)
return &parserConfig{
state: c.state,
Stack: newStack,
}
}
// key returns a unique key for this config for deduplication
func (c *parserConfig) key() uint64 {
h := fnv.New64a()
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
h.Write(buf[:])
for _, sym := range c.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
return h.Sum64()
}
// configSet represents a set of parser configurations (for nondeterminism)
type configSet struct {
configs []*parserConfig
normalized bool // true if already deduplicated and sorted
cachedSig uint64 // cached signature after normalization
}
// newConfigSet creates a new config set with a single configuration
func newConfigSet(state state, stack []stackSymbol) *configSet {
return &configSet{
configs: []*parserConfig{
{state: state, Stack: stack},
},
normalized: true, // single config is already normalized
}
}
// normalize deduplicates and sorts configs for stable signatures
func (c *configSet) normalize() {
if c.normalized || len(c.configs) <= 1 {
c.normalized = true
return
}
// Deduplicate using a map
seen := make(map[uint64]*parserConfig, len(c.configs))
for _, cfg := range c.configs {
key := cfg.key()
if _, exists := seen[key]; !exists {
seen[key] = cfg
}
}
// Extract unique configs
unique := make([]*parserConfig, 0, len(seen))
for _, cfg := range seen {
unique = append(unique, cfg)
}
// Sort by key for deterministic ordering
sort.Slice(unique, func(i, j int) bool {
return unique[i].key() < unique[j].key()
})
c.configs = unique
c.normalized = true
}
// signature returns a hash for cache lookup (normalizes first)
func (c *configSet) signature() uint64 {
c.normalize()
// Return cached signature if available
if c.cachedSig != 0 {
return c.cachedSig
}
h := fnv.New64a()
// Hash number of configs
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
h.Write(buf[:])
// Hash each config (already sorted)
for _, cfg := range c.configs {
binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
h.Write(buf[:])
binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
h.Write(buf[:])
for _, sym := range cfg.Stack {
binary.LittleEndian.PutUint64(buf[:], uint64(sym))
h.Write(buf[:])
}
}
c.cachedSig = h.Sum64()
return c.cachedSig
}
// isEmpty returns true if there are no configurations
func (c *configSet) isEmpty() bool {
return len(c.configs) == 0
}
// clone creates a deep copy of the config set
func (c *configSet) clone() *configSet {
newConfigs := make([]*parserConfig, len(c.configs))
for i, cfg := range c.configs {
newConfigs[i] = cfg.clone()
}
return &configSet{configs: newConfigs}
}
// bridge connects token analysis to pda validation
type bridge struct {
pda *pda
analyzer *analyzer
}
// newBridge creates a new bridge
func newBridge(pda *pda, analyzer *analyzer) *bridge {
return &bridge{
pda: pda,
analyzer: analyzer,
}
}
// IsTokenValid checks if token T can be consumed from the current config
// This is the main entry point for token validation
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return false
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
return b.canAcceptTerminal(config, terminal.Pattern)
}
// General path: DP over (pos, config)
return b.dpValidate(&analysis, config)
}
// canAcceptTerminal checks if any config can accept the terminal
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
for _, cfg := range config.configs {
if b.canConfigAcceptTerminal(cfg, pattern) {
return true
}
}
return false
}
// canConfigAcceptTerminal checks if a single config can accept the terminal
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
// Use pooled visited map to reduce allocations
visited := getVisitedMap()
result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
putVisitedMap(visited)
return result
}
// tryAcceptTerminal recursively tries to accept a terminal from a state
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Direct match
return true
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
// Pop
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
return true
}
}
}
return false
}
// dpValidate runs DP for multi-terminal tokens
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
// state: (pos, configSet)
// Memoize by (pos, configSig)
type dpKey struct {
pos int
sig uint64
}
memo := make(map[dpKey]bool)
var dp func(pos int, config *configSet) bool
dp = func(pos int, config *configSet) bool {
if pos == len(analysis.Token) {
return true // Consumed entire token
}
if config.isEmpty() {
return false
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
memo[key] = true
return true
}
}
memo[key] = false
return false
}
return dp(0, startConfig)
}
// advanceConfig advances all configs that can accept the terminal
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
var newConfigs []*parserConfig
for _, cfg := range config.configs {
advanced := b.advanceSingleConfig(cfg, pattern)
newConfigs = append(newConfigs, advanced...)
}
if len(newConfigs) == 0 {
return nil
}
return &configSet{configs: newConfigs}
}
// advanceSingleConfig advances a single config by accepting a terminal
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
var results []*parserConfig
visited := getVisitedMap()
b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
putVisitedMap(visited)
return results
}
// collectAdvanced collects all configs reachable by accepting the pattern
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Can't pop more than we have
if t.StackPop > len(stack) {
continue
}
if t.Pattern == pattern {
// Match! Create new config after transition
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
*results = append(*results, &parserConfig{
state: t.ToState,
Stack: newStack,
})
}
if t.Pattern == "" {
// Epsilon transition - follow it
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
}
}
}
// validTokens returns all token IDs that are valid from the given config
func (b *bridge) validTokens(config *configSet) []int {
var valid []int
for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
if b.IsTokenValid(tokenID, config) {
valid = append(valid, tokenID)
}
}
return valid
}
// acceptToken attempts to accept a token and returns the new config set
// Returns nil if the token is not valid from this config
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
analysis := b.analyzer.analysis(tokenID)
if !analysis.HasMatches {
return nil
}
// Fast path: exact terminal match
if analysis.exactMatch >= 0 {
terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
newConfig.normalize()
return newConfig
}
return nil
}
// General path: DP to find final config after consuming token
return b.dpAccept(&analysis, config)
}
// dpAccept runs DP to accept a multi-terminal token and return final config
// Returns the union of all possible end configurations (preserves nondeterminism)
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
type dpKey struct {
pos int
sig uint64
}
// Memoize the configs reachable at each (pos, sig)
memo := make(map[dpKey]*configSet)
var dp func(pos int, config *configSet) *configSet
dp = func(pos int, config *configSet) *configSet {
if pos == len(analysis.Token) {
return config // Consumed entire token, return final config
}
if config.isEmpty() {
return nil
}
key := dpKey{pos, config.signature()}
if result, ok := memo[key]; ok {
return result
}
// Collect all valid result configs from all possible paths
var allConfigs []*parserConfig
// Try each terminal that matches at this position
for _, match := range analysis.MatchesAtPos[pos] {
terminal := b.analyzer.matcher.terminals[match.TerminalID]
newConfig := b.advanceConfig(config, terminal.Pattern)
if newConfig != nil && !newConfig.isEmpty() {
finalConfig := dp(pos+match.Length, newConfig)
if finalConfig != nil {
// Collect all configs, don't return early
allConfigs = append(allConfigs, finalConfig.configs...)
}
}
}
// Build result: nil if no valid paths, normalized configSet otherwise
var result *configSet
if len(allConfigs) > 0 {
result = &configSet{configs: allConfigs}
result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
}
memo[key] = result // Cache normalized result
return result
}
return dp(0, startConfig)
}
// isAccepting returns true if any config can reach an accepting state
func (b *bridge) isAccepting(config *configSet) bool {
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config check
for k := range visited {
delete(visited, k)
}
if b.canReachAccept(cfg.state, cfg.Stack, visited) {
return true
}
}
return false
}
// canReachAccept checks if we can reach an accepting state via epsilon transitions
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if b.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if b.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// validTerminals returns the valid terminal patterns from the given config
func (b *bridge) validTerminals(config *configSet) []string {
seen := make(map[string]bool)
var terminals []string
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
}
return terminals
}
// collectValidTerminals collects all reachable terminals
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" && !seen[t.Pattern] {
seen[t.Pattern] = true
*terminals = append(*terminals, t.Pattern)
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
}
}
}
// validTerminalIDs returns the IDs of valid terminals from the given config
func (b *bridge) validTerminalIDs(config *configSet) []int {
seen := make(map[int]bool)
var terminalIDs []int
visited := getVisitedMap()
defer putVisitedMap(visited)
for _, cfg := range config.configs {
// Clear visited for each config
for k := range visited {
delete(visited, k)
}
b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
}
return terminalIDs
}
// collectValidTerminalIDs collects IDs of all reachable terminals
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return
}
visited[key] = true
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
for _, t := range b.pda.Transitions[state] {
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.StackPop > len(stack) {
continue
}
if t.Pattern != "" {
// Look up terminal ID from pattern
if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
seen[tid] = true
*terminalIDs = append(*terminalIDs, tid)
}
}
if t.Pattern == "" {
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
}
}
}

View File

@@ -0,0 +1,45 @@
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws
id-field ::= "\"id\"" ws ":" ws uuid
kind-field ::= "\"kind\"" ws ":" ws kind
items-field ::= "\"items\"" ws ":" ws items
alt-field ::= "\"alt\"" ws ":" ws alt
flags-field ::= "\"flags\"" ws ":" ws flags
meta-field ::= "\"meta\"" ws ":" ws meta
priority-field ::= "\"priority\"" ws ":" ws int
kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
source ::= "\"api\"" | "\"batch\"" | "\"import\""
items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"
item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
item-sku ::= "\"sku\"" ws ":" ws string
item-qty ::= "\"qty\"" ws ":" ws int
item-status ::= "\"status\"" ws ":" ws status
item-notes ::= "\"notes\"" ws ":" ws string
meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
meta-created ::= "\"created\"" ws ":" ws date-time
meta-source ::= "\"source\"" ws ":" ws source
meta-ip ::= "\"ip\"" ws ":" ws ipv4
alt ::= string | int | "null"
uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
int ::= "-"? digit+
digit ::= [0-9]
hex ::= [0-9a-fA-F]
ws ::= [ \t\n\r]*

View File

@@ -0,0 +1,46 @@
{
"type": "object",
"properties": {
"id": { "type": "string", "format": "uuid" },
"kind": { "enum": ["order", "invoice", "shipment"] },
"items": {
"type": "array",
"minItems": 1,
"maxItems": 3,
"items": {
"type": "object",
"properties": {
"sku": { "type": "string" },
"qty": { "type": "integer" },
"status": { "enum": ["new", "backorder", "shipped"] },
"notes": { "type": "string" }
},
"required": ["sku", "qty", "status", "notes"]
}
},
"alt": {
"oneOf": [
{ "type": "string" },
{ "type": "null" },
{ "type": "integer" }
]
},
"flags": {
"type": "array",
"minItems": 0,
"maxItems": 4,
"items": { "enum": ["fragile", "gift", "priority", "insured"] }
},
"meta": {
"type": "object",
"properties": {
"created": { "type": "string", "format": "date-time" },
"source": { "enum": ["api", "batch", "import"] },
"ip": { "type": "string", "format": "ipv4" }
},
"required": ["created", "source", "ip"]
},
"priority": { "type": "integer" }
},
"required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
}

View File

@@ -0,0 +1,235 @@
//go:build mlx
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"time"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/grammar/schema"
"github.com/ollama/ollama/x/imagegen/mlx"
)
const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`
type result struct {
vocabSize int `json:"vocab_size"`
Iterations int `json:"iterations"`
Warmup int `json:"warmup"`
ConstrainedSource string `json:"constrained_source"`
LlamaSource string `json:"llama_source"`
LlamaApply string `json:"llama_apply"`
ConstrainedGraph string `json:"constrained_graph"`
ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
EvalOnly string `json:"eval_only,omitempty"`
ConstrainedEvalNet string `json:"constrained_eval_net,omitempty"`
}
func main() {
var (
vocabSize = flag.Int("vocab-size", 128000, "Vocabulary size")
iterations = flag.Int("iterations", 500, "Benchmark iterations")
warmup = flag.Int("warmup", 50, "Warmup iterations")
withEval = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
gbnfPath = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
ebnfPath = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
startRule = flag.String("start", "root", "Start rule for EBNF")
)
flag.Parse()
if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
fmt.Fprintln(os.Stderr, "invalid flags")
os.Exit(2)
}
vocab := createVocab(*vocabSize)
if *schemaPath != "" && *ebnfPath != "" {
fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
os.Exit(2)
}
var constrainedSource string
var compiled *grammar.Grammar
var err error
switch {
case *schemaPath != "":
data, readErr := os.ReadFile(*schemaPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
os.Exit(1)
}
compiled, err = schema.Grammar(string(data))
constrainedSource = "schema:" + *schemaPath
case *ebnfPath != "":
data, readErr := os.ReadFile(*ebnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
os.Exit(1)
}
compiled, err = grammar.ParseEBNF(string(data), *startRule)
constrainedSource = "ebnf:" + *ebnfPath
default:
compiled, err = grammar.JSONGrammar()
constrainedSource = "json"
}
if err != nil {
fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
os.Exit(1)
}
engine, err := grammar.NewEngine(compiled, vocab)
if err != nil {
fmt.Fprintf(os.Stderr, "engine: %v\n", err)
os.Exit(1)
}
defer engine.Close()
logits := mlx.Ones(int32(*vocabSize))
mlx.Keep(logits)
for i := 0; i < *warmup; i++ {
masked := engine.ApplyMask(logits)
if *withEval {
mlx.Eval(masked)
}
}
graphAvg := measure(*iterations, func() {
_ = engine.ApplyMask(logits)
})
var evalAvg time.Duration
var evalOnlyAvg time.Duration
if *withEval {
evalOnlyAvg = measure(*iterations, func() {
baseline := mlx.MulScalar(logits, 1)
mlx.Eval(baseline)
baseline.Free()
})
evalAvg = measure(*iterations, func() {
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
})
}
vocabIDs := make([]uint32, *vocabSize)
for i := range vocabIDs {
vocabIDs[i] = uint32(i)
}
eogTokens := []int32{0}
gbnf := jsonGBNF
llamaSource := "json"
if *gbnfPath != "" {
data, readErr := os.ReadFile(*gbnfPath)
if readErr != nil {
fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
os.Exit(1)
}
gbnf = string(data)
llamaSource = *gbnfPath
}
llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
if llamaGrammar == nil {
fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
os.Exit(1)
}
defer llamaGrammar.Free()
llamaTokens := make([]llama.TokenData, *vocabSize)
for i := 0; i < *warmup; i++ {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
}
llamaAvg := measure(*iterations, func() {
for j := range llamaTokens {
llamaTokens[j].Logit = 1.0
}
llamaGrammar.Apply(llamaTokens)
})
out := result{
vocabSize: *vocabSize,
Iterations: *iterations,
Warmup: *warmup,
LlamaApply: llamaAvg.String(),
ConstrainedGraph: graphAvg.String(),
ConstrainedSource: constrainedSource,
LlamaSource: llamaSource,
}
if *withEval {
out.ConstrainedWithEval = evalAvg.String()
out.EvalOnly = evalOnlyAvg.String()
if evalAvg > evalOnlyAvg {
out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
} else {
out.ConstrainedEvalNet = "0s"
}
}
enc := json.NewEncoder(os.Stdout)
if err := enc.Encode(out); err != nil {
fmt.Fprintf(os.Stderr, "encode: %v\n", err)
os.Exit(1)
}
}
func measure(iterations int, fn func()) time.Duration {
start := time.Now()
for i := 0; i < iterations; i++ {
fn()
}
return time.Since(start) / time.Duration(iterations)
}
func createVocab(size int) []string {
vocab := make([]string, size)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"",
}
for i, t := range jsonTokens {
if i < size {
vocab[i] = t
}
}
for i := len(jsonTokens); i < size; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}

320
x/grammar/compiled.go Normal file
View File

@@ -0,0 +1,320 @@
//go:build mlx
package grammar
import (
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// Grammar is the compiled form of an EBNF grammar.
// It contains terminals, parse tables, and the start state.
// Use ParseEBNF or JSONGrammar to create a Grammar.
type Grammar struct {
// The underlying pda
pda *pda
// Compiled terminal matcher
matcher *terminalMatcher
}
// ParseEBNF compiles an EBNF grammar string into a Grammar.
// startRule is the name of the start rule (e.g., "root", "json").
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
pda, err := compileString(ebnf, startRule)
if err != nil {
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
}
matcher, err := compileTerminalsStrict(pda)
if err != nil {
return nil, fmt.Errorf("failed to compile terminals: %w", err)
}
return &Grammar{
pda: pda,
matcher: matcher,
}, nil
}
// JSONGrammar returns the compiled JSON grammar.
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
func JSONGrammar() (*Grammar, error) {
return ParseEBNF(JSONGrammarEBNF, "json")
}
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
// Use this when you want to ensure the output is a JSON object (starts with {).
func JSONObjectGrammar() (*Grammar, error) {
return ParseEBNF(JSONObjectGrammarEBNF, "json")
}
// compileTerminalsStrict builds a matcher that properly handles:
// - Escaped literals ("\n", \"", \uXXXX)
// - Unicode ranges (rune-based, not byte-based)
// - Rejects unsupported patterns with an error (no silent fallback)
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
m := &terminalMatcher{
literalTrie: &trieNode{terminalID: -1},
ranges: make([]terminal, 0),
terminals: make([]terminal, 0, len(pda.Terminals)),
patternToID: make(map[string]int),
}
// Track which pattern produced each unescaped value for collision detection
unescapedSource := make(map[string]string) // unescaped -> original pattern
for i, pattern := range pda.Terminals {
terminal, err := parseTerminalPattern(pattern, i)
if err != nil {
return nil, fmt.Errorf("terminal %q: %w", pattern, err)
}
if terminal.Type == terminalLiteral {
// Use the unescaped pattern for trie matching
m.addLiteralToTrie(terminal.Unescaped, i)
// Detect collisions between literals that unescape to the same value
if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
if existingPattern != pattern {
return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
existingPattern, pattern, terminal.Unescaped)
}
} else {
unescapedSource[terminal.Unescaped] = pattern
}
} else if terminal.Type == terminalRange {
m.ranges = append(m.ranges, terminal)
}
m.terminals = append(m.terminals, terminal)
m.patternToID[pattern] = i
}
return m, nil
}
// parseTerminalPattern parses a terminal pattern and returns a terminal.
// Supports:
// - Literal strings (with escape sequences)
// - Character ranges [X-Y] (unicode-aware)
func parseTerminalPattern(pattern string, id int) (terminal, error) {
if len(pattern) == 0 {
return terminal{}, fmt.Errorf("empty pattern")
}
// Check for range pattern: [X-Y]
if isUnicodeRangePattern(pattern) {
lowRune, highRune, err := parseUnicodeRange(pattern)
if err != nil {
return terminal{}, err
}
return terminal{
ID: id,
Type: terminalRange,
Pattern: pattern,
Unescaped: pattern,
LowRune: lowRune,
HighRune: highRune,
}, nil
}
// It's a literal - unescape it
unescaped, err := unescapeLiteral(pattern)
if err != nil {
return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
}
return terminal{
ID: id,
Type: terminalLiteral,
Pattern: pattern,
Unescaped: unescaped,
}, nil
}
// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
func isUnicodeRangePattern(pattern string) bool {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return false
}
// Find the dash that separates low-high
inner := pattern[1 : len(pattern)-1]
dashIdx := strings.Index(inner, "-")
// Handle escaped dash at start
if dashIdx <= 0 {
return false
}
return true
}
// parseUnicodeRange parses [X-Y] into low and high runes
func parseUnicodeRange(pattern string) (rune, rune, error) {
if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
return 0, 0, fmt.Errorf("invalid range pattern")
}
inner := pattern[1 : len(pattern)-1]
// Simple case: [a-z] where a and z are single chars
if len(inner) == 3 && inner[1] == '-' {
return rune(inner[0]), rune(inner[2]), nil
}
// Handle escaped characters like [\u0000-\uFFFF]
dashIdx := findRangeDash(inner)
if dashIdx < 0 {
return 0, 0, fmt.Errorf("no dash in range")
}
lowStr := inner[:dashIdx]
highStr := inner[dashIdx+1:]
lowRune, err := parseRune(lowStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid low bound: %w", err)
}
highRune, err := parseRune(highStr)
if err != nil {
return 0, 0, fmt.Errorf("invalid high bound: %w", err)
}
if lowRune > highRune {
return 0, 0, fmt.Errorf("low bound > high bound")
}
return lowRune, highRune, nil
}
// findRangeDash finds the dash separating low-high in a range pattern
func findRangeDash(inner string) int {
i := 0
for i < len(inner) {
if inner[i] == '\\' && i+1 < len(inner) {
// Skip escape sequence
if inner[i+1] == 'u' && i+6 <= len(inner) {
i += 6 // \uXXXX
} else {
i += 2 // \n, \t, etc.
}
continue
}
if inner[i] == '-' && i > 0 {
return i
}
i++
}
return -1
}
// parseRune parses a single rune from a string (handles escapes)
func parseRune(s string) (rune, error) {
if len(s) == 0 {
return 0, fmt.Errorf("empty rune")
}
// Handle escape sequences
if s[0] == '\\' {
if len(s) < 2 {
return 0, fmt.Errorf("incomplete escape")
}
switch s[1] {
case 'n':
return '\n', nil
case 't':
return '\t', nil
case 'r':
return '\r', nil
case '\\':
return '\\', nil
case '"':
return '"', nil
case '\'':
return '\'', nil
case 'u':
if len(s) < 6 {
return 0, fmt.Errorf("incomplete unicode escape")
}
val, err := strconv.ParseInt(s[2:6], 16, 32)
if err != nil {
return 0, fmt.Errorf("invalid unicode escape: %w", err)
}
return rune(val), nil
default:
return 0, fmt.Errorf("unknown escape: \\%c", s[1])
}
}
// Plain character
r, _ := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return 0, fmt.Errorf("invalid utf8")
}
return r, nil
}
// unescapeLiteral unescapes a literal pattern string
func unescapeLiteral(pattern string) (string, error) {
// Try strconv.Unquote if it looks quoted
if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
unquoted, err := strconv.Unquote(pattern)
if err != nil {
return "", err
}
return unquoted, nil
}
// If no backslashes, return as-is
if !strings.Contains(pattern, "\\") {
return pattern, nil
}
// Manual unescape
var result strings.Builder
i := 0
for i < len(pattern) {
if pattern[i] == '\\' && i+1 < len(pattern) {
switch pattern[i+1] {
case 'n':
result.WriteByte('\n')
i += 2
case 't':
result.WriteByte('\t')
i += 2
case 'r':
result.WriteByte('\r')
i += 2
case '\\':
result.WriteByte('\\')
i += 2
case '"':
result.WriteByte('"')
i += 2
case '\'':
result.WriteByte('\'')
i += 2
case 'u':
if i+6 <= len(pattern) {
val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
if err != nil {
return "", fmt.Errorf("invalid unicode escape at %d", i)
}
result.WriteRune(rune(val))
i += 6
} else {
return "", fmt.Errorf("incomplete unicode escape at %d", i)
}
default:
// Reject unknown escape sequences
return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
}
} else {
result.WriteByte(pattern[i])
i++
}
}
return result.String(), nil
}

329
x/grammar/engine.go Normal file
View File

@@ -0,0 +1,329 @@
//go:build mlx
package grammar
import (
"container/list"
"fmt"
"math"
"sync"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// maskCache provides LRU caching for computed masks.
type maskCache struct {
cache map[uint64]*list.Element
order *list.List
maxSize int
mu sync.Mutex
}
type maskEntry struct {
sig uint64
mask *mlx.Array
}
// newMaskCache creates a new mask cache with the given max size
// If maxSize <= 0, the cache is disabled (Get/Put are no-ops)
func newMaskCache(maxSize int) *maskCache {
if maxSize <= 0 {
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: 0, // Signals disabled
}
}
return &maskCache{
cache: make(map[uint64]*list.Element),
order: list.New(),
maxSize: maxSize,
}
}
// get retrieves a cached mask, returning nil if not found.
// Updates LRU order on cache hit.
func (c *maskCache) get(sig uint64) *mlx.Array {
if c.maxSize <= 0 {
return nil // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, ok := c.cache[sig]; ok {
c.order.MoveToFront(elem)
return elem.Value.(*maskEntry).mask
}
return nil
}
// put stores a mask in the cache with LRU eviction.
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
if c.maxSize <= 0 {
return // Cache disabled
}
c.mu.Lock()
defer c.mu.Unlock()
if elem, exists := c.cache[sig]; exists {
c.order.MoveToFront(elem)
return
}
// Evict oldest if at capacity (safe since maxSize > 0)
if c.order.Len() >= c.maxSize {
oldest := c.order.Back()
if oldest != nil {
entry := oldest.Value.(*maskEntry)
entry.mask.Free()
delete(c.cache, entry.sig)
c.order.Remove(oldest)
}
}
elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
c.cache[sig] = elem
}
// clear frees all cached masks.
func (c *maskCache) clear() {
c.mu.Lock()
defer c.mu.Unlock()
for elem := c.order.Front(); elem != nil; elem = elem.Next() {
elem.Value.(*maskEntry).mask.Free()
}
c.cache = make(map[uint64]*list.Element)
c.order.Init()
}
// size returns the number of cached masks.
func (c *maskCache) size() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.cache)
}
// Engine applies grammar constraints to model outputs using MLX.
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
type Engine struct {
// The compiled grammar
grammar *Grammar
// bridge for token validation
bridge *bridge
analyzer *analyzer
// Current parser state (configSet for nondeterminism)
configSet *configSet
// Token vocabulary from the model
vocab []string
tokenToID map[string]int // O(1) lookup for AcceptString
// Mask cache: configSig → valid token mask (LRU)
maskCache *maskCache
// Cached negative infinity mask for invalid tokens
negInfMask *mlx.Array
// Threshold for comparison (0.5 since mask values are 0 or 1)
threshold *mlx.Array
// Vocabulary size
vocabSize int32
// Reusable buffers for candidate filtering (avoid allocations)
candidateMark []bool // indexed by tokenID, true if in candidate set
touched []int // tokenIDs that were marked (for reset)
dpCandidates []int // candidates requiring DP validation
// Reusable buffer for valid token indices (for GPU scatter)
validTokenIDs []int32
}
// EngineOption configures an Engine
type EngineOption func(*Engine)
// WithMaskCacheSize sets the mask cache size (default 1024)
func WithMaskCacheSize(size int) EngineOption {
return func(e *Engine) {
e.maskCache = newMaskCache(size)
}
}
// NewEngine creates a new constrained decoding engine.
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
// vocab is the list of token strings from the model's tokenizer.
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
if grammar == nil {
return nil, fmt.Errorf("grammar cannot be nil")
}
// Build analyzer and bridge
analyzer := newAnalyzer(vocab, grammar.matcher)
bridge := newBridge(grammar.pda, analyzer)
// Initialize config set from pda initial state
initialConfig := newConfigSet(grammar.pda.StartState, nil)
// Build token lookup map for O(1) AcceptString
tokenToID := make(map[string]int, len(vocab))
for i, tok := range vocab {
tokenToID[tok] = i
}
e := &Engine{
grammar: grammar,
bridge: bridge,
analyzer: analyzer,
configSet: initialConfig,
vocab: vocab,
tokenToID: tokenToID,
maskCache: newMaskCache(1024),
vocabSize: int32(len(vocab)),
candidateMark: make([]bool, len(vocab)),
touched: make([]int, 0, 10000),
validTokenIDs: make([]int32, 0, 10000),
}
// Apply options
for _, opt := range opts {
opt(e)
}
// Create the negative infinity mask and threshold
if e.vocabSize > 0 {
e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
mlx.Keep(e.negInfMask)
e.threshold = mlx.NewScalarArray(0.5)
mlx.Keep(e.threshold)
}
return e, nil
}
// ApplyMask applies grammar constraints to logits.
// Returns logits with invalid tokens set to -inf.
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
sig := e.configSet.signature()
// Check state cache first (exact state match)
if cached := e.maskCache.get(sig); cached != nil {
condition := mlx.GreaterEqual(cached, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Compute valid tokens using candidate filtering:
// 1. Get valid terminal IDs from current grammar state
// 2. Get candidate tokens (those that START with valid terminals)
// 3. Run DP validation only on candidates
// This is O(candidates) instead of O(vocab_size)
validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)
// Use pre-partitioned token groups for fast candidate building
// This eliminates per-token branching - just direct slice appends
e.validTokenIDs = e.validTokenIDs[:0]
e.dpCandidates = e.dpCandidates[:0]
e.touched = e.touched[:0]
for _, tid := range validTerminalIDs {
groups := e.analyzer.terminalGroups(tid)
// Direct append of exact matches (no per-token check needed)
e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)
// Collect DP candidates (may have duplicates across terminals)
for _, tokenID := range groups.DPCandidates {
if !e.candidateMark[tokenID] {
e.candidateMark[tokenID] = true
e.dpCandidates = append(e.dpCandidates, tokenID)
e.touched = append(e.touched, tokenID)
}
}
}
// Reset marks for next call
for _, id := range e.touched {
e.candidateMark[id] = false
}
for _, tokenID := range e.dpCandidates {
if e.bridge.IsTokenValid(tokenID, e.configSet) {
e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
}
}
// Create and cache the mask on GPU using index updates
mask := mlx.Zeros([]int32{e.vocabSize})
if len(e.validTokenIDs) > 0 {
indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
values := mlx.Ones(int32(len(e.validTokenIDs)))
mask = mlx.PutAlongAxis(mask, indices, values, 0)
}
mlx.Keep(mask)
// Cache by state signature
e.maskCache.put(sig, mask)
// Apply mask
condition := mlx.GreaterEqual(mask, e.threshold)
return mlx.Where(condition, logits, e.negInfMask)
}
// Accept processes a token and updates the parser state.
// Returns true if the token was valid and accepted.
func (e *Engine) Accept(tokenID int) bool {
if tokenID < 0 || tokenID >= len(e.vocab) {
return false
}
newConfig := e.bridge.acceptToken(tokenID, e.configSet)
if newConfig == nil {
return false
}
e.configSet = newConfig
return true
}
// AcceptString processes a token string directly.
// Returns true if the token was valid and accepted.
func (e *Engine) AcceptString(token string) bool {
if id, ok := e.tokenToID[token]; ok {
return e.Accept(id)
}
return false
}
// IsComplete returns true if the current state is accepting.
func (e *Engine) IsComplete() bool {
return e.bridge.isAccepting(e.configSet)
}
// Reset resets the engine to initial state.
func (e *Engine) Reset() {
e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
}
// validTokens returns the indices of tokens that are currently valid.
func (e *Engine) validTokens() []int {
return e.bridge.validTokens(e.configSet)
}
// validTerminals returns the valid terminal patterns from the current state.
func (e *Engine) validTerminals() []string {
return e.bridge.validTerminals(e.configSet)
}
// Close releases MLX resources.
func (e *Engine) Close() {
if e.maskCache != nil {
e.maskCache.clear()
}
if e.negInfMask != nil {
e.negInfMask.Free()
}
if e.threshold != nil {
e.threshold.Free()
}
}

View File

@@ -0,0 +1,414 @@
//go:build mlx
package grammar
import (
"fmt"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newBenchEngine creates a JSON engine for benchmarks
func newBenchEngine(b *testing.B, vocab []string) *Engine {
b.Helper()
grammar, err := JSONGrammar()
if err != nil {
b.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
b.Fatalf("failed to create engine: %v", err)
}
return e
}
// Vocabulary sizes to test (matching real models)
var vocabSizes = []int{
32000, // Llama 2
128000, // Llama 3
256000, // Large models
}
// createBenchVocabN creates a vocabulary of size n with realistic token distribution
func createBenchVocabN(n int) []string {
vocab := make([]string, n)
// JSON structural tokens (first 20)
jsonTokens := []string{
"{", "}", "[", "]", ":", ",",
"true", "false", "null",
" ", "\n", "\t", "\r",
"\"", "'",
}
for i, t := range jsonTokens {
if i < n {
vocab[i] = t
}
}
// String tokens (indices 20-1000)
stringIdx := 20
for i := 0; i < 980 && stringIdx+i < n; i++ {
vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
}
// Number tokens (indices 1000-2000)
numberIdx := 1000
for i := 0; i < 1000 && numberIdx+i < n; i++ {
vocab[numberIdx+i] = fmt.Sprintf("%d", i)
}
// Generic tokens (rest)
for i := 2000; i < n; i++ {
vocab[i] = fmt.Sprintf("tok%d", i)
}
return vocab
}
// ============ Core Performance Benchmarks ============
// BenchmarkApplyMask_32k measures mask application with 32k vocab
func BenchmarkApplyMask_32k(b *testing.B) {
benchmarkApplyMask(b, 32000)
}
// BenchmarkApplyMask_128k measures mask application with 128k vocab
func BenchmarkApplyMask_128k(b *testing.B) {
benchmarkApplyMask(b, 128000)
}
// BenchmarkApplyMask_256k measures mask application with 256k vocab
func BenchmarkApplyMask_256k(b *testing.B) {
benchmarkApplyMask(b, 256000)
}
func benchmarkApplyMask(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ReportMetric(float64(vocabSize), "vocab_size")
}
// ============ state-Dependent Benchmarks ============
// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
e.AcceptString("{")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// BenchmarkApplyMaskMidObject measures mask in middle of object
func BenchmarkApplyMaskMidObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
// state: {"key": _value_
e.AcceptString("{")
e.AcceptString("\"key\"")
e.AcceptString(":")
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ Token Sequence Benchmarks ============
// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
func BenchmarkSequence_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
func BenchmarkSequence_NestedObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 100]
func BenchmarkSequence_LargeArray(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Build sequence: [1, 2, 3, ..., 50]
sequence := []string{"["}
for i := 1; i <= 50; i++ {
sequence = append(sequence, fmt.Sprintf("%d", i))
if i < 50 {
sequence = append(sequence, ",")
}
}
sequence = append(sequence, "]")
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// BenchmarkSequence_MixedTypes benchmarks complex mixed-type object
func BenchmarkSequence_MixedTypes(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{
"{",
"\"name\"", ":", "\"test\"", ",",
"\"count\"", ":", "42", ",",
"\"enabled\"", ":", "true", ",",
"\"data\"", ":", "null", ",",
"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
"}",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
e.AcceptString(token)
}
}
b.ReportMetric(float64(len(sequence)), "tokens")
}
// ============ Component Benchmarks ============
// BenchmarkValidInputs measures pda valid input computation
func BenchmarkValidInputs(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.validTerminals()
}
}
// BenchmarkStateTransition measures pda state transition
func BenchmarkStateTransition(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
for _, token := range sequence {
e.AcceptString(token)
}
}
}
// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // Graph only, no eval
}
}
// BenchmarkNewEngine measures one-time engine initialization.
func BenchmarkNewEngine_32k(b *testing.B) {
benchmarkNewEngine(b, 32000)
}
func BenchmarkNewEngine_128k(b *testing.B) {
benchmarkNewEngine(b, 128000)
}
func benchmarkNewEngine(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
e := newBenchEngine(b, vocab)
e.Close()
}
}
// ============ Memory Benchmarks ============
func BenchmarkMemoryAllocs_32k(b *testing.B) {
benchmarkMemoryAllocs(b, 32000)
}
func BenchmarkMemoryAllocs_128k(b *testing.B) {
benchmarkMemoryAllocs(b, 128000)
}
func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
vocab := createBenchVocabN(vocabSize)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(vocabSize))
mlx.Keep(logits)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
}
// ============ No-Eval Benchmarks (simulating LLM graph integration) ============
// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync
// This simulates adding mask to LLM compute graph
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
// Warm up
for i := 0; i < 10; i++ {
masked := e.ApplyMask(logits)
mlx.Eval(masked)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.ApplyMask(logits) // No Eval - just build graph
}
}
// BenchmarkSequenceNoEval simulates real LLM usage - build graph, eval once at end
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
vocab := createBenchVocabN(128000)
e := newBenchEngine(b, vocab)
defer e.Close()
logits := mlx.Ones(int32(128000))
mlx.Keep(logits)
sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
e.Reset()
var lastMasked *mlx.Array
for _, token := range sequence {
lastMasked = e.ApplyMask(logits) // Build graph only
e.AcceptString(token)
}
mlx.Eval(lastMasked) // Single eval at end
}
b.ReportMetric(float64(len(sequence)), "tokens")
}

689
x/grammar/engine_test.go Normal file
View File

@@ -0,0 +1,689 @@
//go:build mlx
package grammar
import (
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// newTestEngine creates a JSON engine for testing
func newTestEngine(t testing.TB, vocab []string) *Engine {
t.Helper()
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
e, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
return e
}
// Mock vocabulary for testing
func testVocab() []string {
return []string{
"{", // 0: object start
"}", // 1: object end
"[", // 2: array start
"]", // 3: array end
":", // 4: colon
",", // 5: comma
"\"key\"", // 6: string (quoted)
"\"val\"", // 7: string (quoted)
"123", // 8: number
"-42.5", // 9: number
"true", // 10: boolean
"false", // 11: boolean
"null", // 12: null
" ", // 13: whitespace (should be ignored)
"\n", // 14: whitespace (should be ignored)
"subword", // 15: bare word (NOT valid JSON - requires quotes)
"hello", // 16: bare word (NOT valid JSON - requires quotes)
}
}
func TestNewEngine(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != int32(len(vocab)) {
t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
}
// Verify grammar is set
if e.grammar == nil {
t.Error("grammar should not be nil")
}
// Verify analyzer is set
if e.analyzer == nil {
t.Error("analyzer should not be nil")
}
}
func TestEngineValidTokens(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// At start, any value type should be valid
validTokens := e.validTokens()
// Should include object start, array start, strings, numbers, booleans, null
// Note: bare words like "subword" and "hello" are NOT valid JSON strings
// (JSON strings must be quoted)
expectedTokens := map[int]bool{
0: true, // {
2: true, // [
6: true, // "key"
7: true, // "val"
8: true, // 123
9: true, // -42.5
10: true, // true
11: true, // false
12: true, // null
}
// Check that expected tokens are present
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
for idx := range expectedTokens {
if !validSet[idx] {
t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
}
}
if validSet[15] || validSet[16] {
t.Error("bare words should not be valid JSON at the start state")
}
}
func TestEngineAccept(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { should work
if !e.Accept(0) { // {
t.Error("should accept {")
}
// After {, valid tokens should be STRING or }
validTokens := e.validTokens()
validSet := make(map[int]bool)
for _, idx := range validTokens {
validSet[idx] = true
}
// STRING tokens (indices 6, 7) and } (index 1) should be valid
if !validSet[1] {
t.Error("} should be valid after {")
}
if !validSet[6] && !validSet[7] {
t.Error("STRING should be valid after { (for keys)")
}
}
func TestEngineAcceptSequence(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept {"key": "val"}
sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }
for i, tokenID := range sequence {
if !e.Accept(tokenID) {
t.Fatalf("failed to accept token %d (%s) at position %d",
tokenID, vocab[tokenID], i)
}
}
if !e.IsComplete() {
t.Error("should be in complete state after valid JSON")
}
}
func TestEngineReset(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept some tokens
e.Accept(0) // {
e.Accept(1) // }
if !e.IsComplete() {
t.Error("should be complete after {}")
}
// Reset
e.Reset()
// Should be back to initial state
if e.IsComplete() {
t.Error("should not be complete after reset")
}
// Should be able to accept new sequence
if !e.Accept(0) { // {
t.Error("should accept { after reset")
}
}
func TestEngineInvalidTokenRejection(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept { first
if !e.Accept(0) {
t.Fatal("should accept {")
}
// Now try to accept [ which is invalid after {
// (After {, only STRING or } are valid)
if e.Accept(2) { // [
t.Error("should not accept [ after { (expecting STRING or })")
}
}
func TestEngineAcceptString(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Accept using string directly
if !e.AcceptString("{") {
t.Error("should accept {")
}
if !e.AcceptString("\"key\"") {
t.Error("should accept string key")
}
if !e.AcceptString(":") {
t.Error("should accept :")
}
if !e.AcceptString("123") {
t.Error("should accept number")
}
if !e.AcceptString("}") {
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after valid JSON")
}
}
func TestJSONBackslashEscape(t *testing.T) {
vocab := []string{`"`, `\`, "n", "a"}
e := newTestEngine(t, vocab)
defer e.Close()
// Valid escape: "\n"
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if !e.AcceptString("n") {
t.Fatal("should accept escape code")
}
if !e.AcceptString(`"`) {
t.Fatal("should accept string end")
}
if !e.IsComplete() {
t.Error("should be complete after escaped string")
}
// Invalid escape: "\a"
e.Reset()
if !e.AcceptString(`"`) {
t.Fatal("should accept string start")
}
if !e.AcceptString(`\`) {
t.Fatal("should accept escape prefix")
}
if e.AcceptString("a") {
t.Error("should reject invalid escape code")
}
}
func TestEngineNegInfMask(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Verify negInfMask exists and has correct shape
if e.negInfMask == nil {
t.Fatal("negInfMask should not be nil")
}
}
func TestEngineMaskCache(t *testing.T) {
vocab := testVocab()
e := newTestEngine(t, vocab)
defer e.Close()
// Create test logits
logits := mlx.Ones(int32(len(vocab)))
// Apply mask - should populate cache
_ = e.ApplyMask(logits)
// Check cache was populated
cacheSize := e.maskCache.size()
if cacheSize == 0 {
t.Error("mask cache should have at least one entry after ApplyMask")
}
}
func TestEngineEmptyVocab(t *testing.T) {
e := newTestEngine(t, []string{})
defer e.Close()
if e.vocabSize != 0 {
t.Errorf("vocabSize = %d, want 0", e.vocabSize)
}
}
func TestEngineLargeVocab(t *testing.T) {
// Create a large vocabulary (simulating real model vocab)
vocab := make([]string, 32000)
for i := range vocab {
vocab[i] = "token"
}
// Add some actual JSON tokens
vocab[0] = "{"
vocab[1] = "}"
vocab[2] = "["
vocab[3] = "]"
vocab[4] = ":"
vocab[5] = ","
vocab[6] = "\"test\""
vocab[7] = "123"
vocab[8] = "true"
vocab[9] = "false"
vocab[10] = "null"
e := newTestEngine(t, vocab)
defer e.Close()
if e.vocabSize != 32000 {
t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
}
// Test that it still works correctly
if !e.Accept(0) { // {
t.Error("should accept {")
}
if !e.Accept(1) { // }
t.Error("should accept }")
}
if !e.IsComplete() {
t.Error("should be complete after {}")
}
}
// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
func TestE2E_JSONDecoding(t *testing.T) {
// Create a realistic vocabulary with JSON tokens
vocab := []string{
// Structural tokens
"{", "}", "[", "]", ":", ",",
// Keywords
"true", "false", "null",
// Quoted strings
`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
`"hello"`, `"world"`, `"test"`,
// Numbers
"0", "1", "2", "3", "42", "123", "-1", "-42",
// Whitespace
" ", "\n", "\t",
// Multi-terminal tokens (span multiple JSON lexemes)
`"key":`, `},`, `],`, `{"`, `["`,
// Partial/invalid tokens (should be rejected)
"invalid", "foo", "bar",
}
grammar, err := JSONGrammar()
if err != nil {
t.Fatalf("failed to create JSON grammar: %v", err)
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
// Simple values
{"empty object", []string{"{", "}"}, true},
{"empty array", []string{"[", "]"}, true},
{"true literal", []string{"true"}, true},
{"null literal", []string{"null"}, true},
{"number", []string{"42"}, true},
{"negative number", []string{"-42"}, true},
{"quoted string", []string{`"hello"`}, true},
// Objects
{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},
// Arrays
{"array of numbers", []string{"[", "42", "]"}, true},
{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
{"nested array", []string{"[", "[", "42", "]", "]"}, true},
// Nested structures
{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},
// Invalid sequences
{"unclosed object", []string{"{", `"name"`, ":"}, false}, // incomplete
{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false}, // invalid
{"missing value", []string{"{", `"name"`, ":", "}"}, false}, // missing value
{"bare word", []string{"invalid"}, false}, // not valid JSON
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
// Process each token
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return // Already reported error
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
// For invalid sequences, we expect either rejection or incomplete
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
// Simple expression grammar: expr = term { ("+" | "-") term }
// term = number | "(" expr ")"
// number = digit { digit }
// digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"
exprGrammar := `
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
`
grammar, err := ParseEBNF(exprGrammar, "expr")
if err != nil {
t.Fatalf("failed to parse expression grammar: %v", err)
}
// Vocabulary for expression tokens
vocab := []string{
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "*", "/",
"(", ")",
// Multi-digit numbers as single tokens
"10", "42", "100", "123",
// Invalid tokens
"x", "y", "invalid",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single digit", []string{"5"}, true},
{"multi-digit", []string{"1", "2", "3"}, true},
{"addition", []string{"1", "+", "2"}, true},
{"subtraction", []string{"5", "-", "3"}, true},
{"multiplication", []string{"2", "*", "3"}, true},
{"division", []string{"8", "/", "2"}, true},
{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
{"nested parens", []string{"(", "(", "1", ")", ")"}, true},
// Invalid
{"just operator", []string{"+"}, false},
{"double operator", []string{"1", "+", "+", "2"}, false},
{"unclosed paren", []string{"(", "1", "+", "2"}, false},
{"variable", []string{"x"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass {
if !allAccepted {
return
}
if !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
} else {
if allAccepted && engine.IsComplete() {
t.Errorf("expected rejection or incomplete, but parse succeeded")
}
}
})
}
}
// TestE2E_IdentifierGrammar tests a grammar with character ranges.
func TestE2E_IdentifierGrammar(t *testing.T) {
// Identifier grammar using character ranges
identGrammar := `
ident = letter { letter | digit } .
letter = "a" … "z" | "A" … "Z" | "_" .
digit = "0" … "9" .
`
grammar, err := ParseEBNF(identGrammar, "ident")
if err != nil {
t.Fatalf("failed to parse identifier grammar: %v", err)
}
// Vocabulary with letters and digits
vocab := []string{
"a", "b", "c", "x", "y", "z",
"A", "B", "C", "X", "Y", "Z",
"_",
"0", "1", "2", "9",
// Multi-char tokens
"foo", "bar", "myVar", "test123",
// Invalid starting chars
"1abc", "123",
}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
tests := []struct {
name string
tokens []string
wantPass bool
}{
{"single letter", []string{"a"}, true},
{"uppercase", []string{"A"}, true},
{"underscore", []string{"_"}, true},
{"multi-letter", []string{"a", "b", "c"}, true},
{"letter then digit", []string{"x", "1"}, true},
{"underscore prefix", []string{"_", "a", "1"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine.Reset()
allAccepted := true
for i, token := range tt.tokens {
if !engine.AcceptString(token) {
if tt.wantPass {
t.Errorf("token %d (%q) rejected unexpectedly", i, token)
}
allAccepted = false
break
}
}
if tt.wantPass && allAccepted && !engine.IsComplete() {
t.Errorf("expected complete parse, but not in accepting state")
}
})
}
}
// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
func TestE2E_UnicodeRange(t *testing.T) {
greekGrammar := `
greek = "α" … "ω" .
`
grammar, err := ParseEBNF(greekGrammar, "greek")
if err != nil {
t.Fatalf("failed to parse unicode grammar: %v", err)
}
vocab := []string{"α", "β", "ω", "a"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
if !engine.AcceptString("β") {
t.Error("should accept beta")
}
if !engine.IsComplete() {
t.Error("should be complete after single rune")
}
engine.Reset()
if engine.AcceptString("a") {
t.Error("should reject ASCII outside unicode range")
}
}
// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
func TestE2E_NondeterminismPreserved(t *testing.T) {
// This grammar has nondeterminism: "ab" could be parsed as
// a single token or as two tokens "a" "b"
ambiguousGrammar := `
start = item item .
item = "a" | "b" | "ab" .
`
grammar, err := ParseEBNF(ambiguousGrammar, "start")
if err != nil {
t.Fatalf("failed to parse grammar: %v", err)
}
// Vocabulary with both single and combined tokens
vocab := []string{"a", "b", "ab"}
engine, err := NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("failed to create engine: %v", err)
}
defer engine.Close()
// Test: "ab" "a" should be valid (ab as first item, a as second)
t.Run("ab then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("ab") {
t.Error("should accept ab")
}
if !engine.AcceptString("a") {
t.Error("should accept a after ab")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then ab", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept a")
}
if !engine.AcceptString("ab") {
t.Error("should accept ab after a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
t.Run("a then a", func(t *testing.T) {
engine.Reset()
if !engine.AcceptString("a") {
t.Error("should accept first a")
}
if !engine.AcceptString("a") {
t.Error("should accept second a")
}
if !engine.IsComplete() {
t.Error("should be complete")
}
})
}

614
x/grammar/grammar.go Normal file
View File

@@ -0,0 +1,614 @@
//go:build mlx
// Package grammar provides GPU-accelerated constrained decoding using MLX.
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
// For JSON Schema conversion, see the grammar/schema subpackage.
package grammar
import (
"encoding/binary"
"fmt"
"io"
"strings"
"golang.org/x/exp/ebnf"
)
// stackSymbol represents a symbol that can be pushed onto the pda stack.
type stackSymbol int
const (
stackEmpty stackSymbol = iota
// Additional stack symbols will be generated per-grammar
)
// state represents a pda state.
type state int
const (
stateError state = -1
stateStart state = 0
stateAccept state = 1
// Additional states will be generated per-grammar
)
// transition represents a pda transition.
// On input matching Pattern, from FromState with stackTop:
// - Move to ToState
// - Pop StackPop symbols, push StackPush symbols
type transition struct {
FromState state
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
Pattern string // Input pattern to match (token or character class)
ToState state
StackPop int // Number of symbols to pop
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
}
// pda represents a compiled pushdown automaton.
type pda struct {
States int // Total number of states
StackSymbols int // Total number of stack symbols
StartState state // Initial state
AcceptStates map[state]bool // Set of accepting states
Transitions map[state][]transition // Transitions indexed by from-state
// For token-level matching
Terminals []string // All terminal symbols (patterns to match)
}
// newPDA creates an empty pda.
func newPDA() *pda {
return &pda{
States: 2, // Error and Start
StackSymbols: 1, // Empty
StartState: stateStart,
AcceptStates: make(map[state]bool),
Transitions: make(map[state][]transition),
Terminals: make([]string, 0),
}
}
// addState adds a new state and returns its ID.
func (p *pda) addState() state {
s := state(p.States)
p.States++
return s
}
// addStackSymbol adds a new stack symbol and returns its ID.
func (p *pda) addStackSymbol() stackSymbol {
s := stackSymbol(p.StackSymbols)
p.StackSymbols++
return s
}
// addTransition adds a transition to the pda.
func (p *pda) addTransition(t transition) {
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
}
// addTerminal registers a terminal pattern and returns its index.
func (p *pda) addTerminal(pattern string) int {
for i, t := range p.Terminals {
if t == pattern {
return i
}
}
p.Terminals = append(p.Terminals, pattern)
return len(p.Terminals) - 1
}
// compiler compiles EBNF grammars to PDAs.
type compiler struct {
grammar ebnf.Grammar
pda *pda
// Maps production names to their entry/exit states
prodEntry map[string]state
prodExit map[string]state
}
// compile parses an EBNF grammar and compiles it to a pda.
func compile(name string, src io.Reader, start string) (*pda, error) {
grammar, err := ebnf.Parse(name, src)
if err != nil {
return nil, fmt.Errorf("parse grammar: %w", err)
}
if err := ebnf.Verify(grammar, start); err != nil {
return nil, fmt.Errorf("verify grammar: %w", err)
}
c := &compiler{
grammar: grammar,
pda: newPDA(),
prodEntry: make(map[string]state),
prodExit: make(map[string]state),
}
// Create entry/exit states for each production
for name := range grammar {
c.prodEntry[name] = c.pda.addState()
c.prodExit[name] = c.pda.addState()
}
// compile each production
for name, prod := range grammar {
if err := c.compileProduction(name, prod); err != nil {
return nil, fmt.Errorf("compile production %q: %w", name, err)
}
}
// Set start state to entry of start production
if entry, ok := c.prodEntry[start]; ok {
// Add epsilon transition from pda start to grammar start
c.pda.addTransition(transition{
FromState: stateStart,
Pattern: "", // epsilon
ToState: entry,
})
} else {
return nil, fmt.Errorf("start production %q not found", start)
}
// Mark exit of start production as accepting
if exit, ok := c.prodExit[start]; ok {
c.pda.AcceptStates[exit] = true
}
return c.pda, nil
}
// compileString is a convenience function to compile from a string.
func compileString(grammar string, start string) (*pda, error) {
return compile("grammar", strings.NewReader(grammar), start)
}
func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
entry := c.prodEntry[name]
exit := c.prodExit[name]
return c.compileExpr(prod.Expr, entry, exit)
}
func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
switch e := expr.(type) {
case *ebnf.Name:
return c.compileName(e, entry, exit)
case *ebnf.Token:
return c.compileToken(e, entry, exit)
case ebnf.Sequence:
return c.compileSequence(e, entry, exit)
case ebnf.Alternative:
return c.compileAlternative(e, entry, exit)
case *ebnf.Option:
return c.compileOption(e, entry, exit)
case *ebnf.Repetition:
return c.compileRepetition(e, entry, exit)
case *ebnf.Group:
return c.compileExpr(e.Body, entry, exit)
case *ebnf.Range:
return c.compileRange(e, entry, exit)
case nil:
// Empty production - direct epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
default:
return fmt.Errorf("unsupported expression type: %T", expr)
}
}
func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
// Reference to another production
prodName := n.String
prodEntry, ok := c.prodEntry[prodName]
if !ok {
return fmt.Errorf("undefined production: %s", prodName)
}
prodExit := c.prodExit[prodName]
// Use a unique stack symbol per call site so returns are unambiguous.
stackSym := c.pda.addStackSymbol()
// Push return address, go to production entry
c.pda.addTransition(transition{
FromState: entry,
Pattern: "", // epsilon
ToState: prodEntry,
StackPush: []stackSymbol{stackSym},
})
// On production exit, pop and return
c.pda.addTransition(transition{
FromState: prodExit,
stackTop: stackSym,
Pattern: "", // epsilon
ToState: exit,
StackPop: 1,
})
return nil
}
func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
// terminal symbol - add transition that consumes this token
pattern := t.String
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
if len(seq) == 0 {
// Empty sequence - epsilon transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
return nil
}
// Chain: entry -> s1 -> s2 -> ... -> exit
current := entry
for i, expr := range seq {
var next state
if i == len(seq)-1 {
next = exit
} else {
next = c.pda.addState()
}
if err := c.compileExpr(expr, current, next); err != nil {
return err
}
current = next
}
return nil
}
func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
// Each alternative goes from entry to exit
for _, expr := range alt {
if err := c.compileExpr(expr, entry, exit); err != nil {
return err
}
}
return nil
}
func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
// Optional: can skip (epsilon) or take the body
// Epsilon transition (skip)
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Or take the body
return c.compileExpr(opt.Body, entry, exit)
}
func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
// Repetition {body}: zero or more
// entry -> exit (skip)
// entry -> body -> entry (loop back)
// Skip transition
c.pda.addTransition(transition{
FromState: entry,
Pattern: "",
ToState: exit,
})
// Loop: entry -> (body) -> entry
return c.compileExpr(rep.Body, entry, entry)
}
func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
// Character range like "a" … "z" or "\u03b1" … "\u03c9"
begin := strings.Trim(r.Begin.String, "\"")
end := strings.Trim(r.End.String, "\"")
// Unescape bounds first (so "\u03b1" works)
beginUnesc, err := unescapeLiteral(begin)
if err != nil {
return fmt.Errorf("invalid range begin: %w", err)
}
endUnesc, err := unescapeLiteral(end)
if err != nil {
return fmt.Errorf("invalid range end: %w", err)
}
// Validate as single runes (not bytes) for Unicode support
beginRunes := []rune(beginUnesc)
endRunes := []rune(endUnesc)
if len(beginRunes) != 1 || len(endRunes) != 1 {
return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
}
// Use unescaped rune strings in pattern (consistent with matcher)
pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
c.pda.addTerminal(pattern)
c.pda.addTransition(transition{
FromState: entry,
Pattern: pattern,
ToState: exit,
})
return nil
}
// runtime represents a pda execution instance.
type runtime struct {
pda *pda
state state
stack []stackSymbol
}
// newRuntime creates a new pda runtime.
func newRuntime(pda *pda) *runtime {
return &runtime{
pda: pda,
state: pda.StartState,
stack: make([]stackSymbol, 0, 32),
}
}
// stackTop returns the top of the stack, or stackEmpty if empty.
func (r *runtime) stackTop() stackSymbol {
if len(r.stack) == 0 {
return stackEmpty
}
return r.stack[len(r.stack)-1]
}
// isAccepting returns true if we can reach an accepting state via epsilon transitions
// with an empty stack.
func (r *runtime) isAccepting() bool {
return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
}
func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
// Check if this state is accepting with empty stack
if r.pda.AcceptStates[state] && len(stack) == 0 {
return true
}
// Avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(stack)}
if visited[key] {
return false
}
visited[key] = true
// Try epsilon transitions
for _, t := range r.pda.Transitions[state] {
if t.Pattern != "" {
continue // Not epsilon
}
// Check stack constraint
stackTop := stackEmpty
if len(stack) > 0 {
stackTop = stack[len(stack)-1]
}
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
// Simulate stack operations
newStack := make([]stackSymbol, len(stack))
copy(newStack, stack)
if t.StackPop > 0 && len(newStack) >= t.StackPop {
newStack = newStack[:len(newStack)-t.StackPop]
}
newStack = append(newStack, t.StackPush...)
if r.canReachAccept(t.ToState, newStack, visited) {
return true
}
}
return false
}
// Reset resets the runtime to initial state.
func (r *runtime) Reset() {
r.state = r.pda.StartState
r.stack = r.stack[:0]
}
// validInputs returns all valid input patterns from current state.
func (r *runtime) validInputs() []string {
var valid []string
seen := make(map[string]bool)
visited := make(map[stateStackKey]bool)
// Make a copy of the stack for simulation
simStack := make([]stackSymbol, len(r.stack))
copy(simStack, r.stack)
r.collectValidInputs(r.state, simStack, seen, visited, &valid)
return valid
}
// stateStackKey is used to detect cycles in epsilon closure
type stateStackKey struct {
state state
stackSig string
}
func stackSignature(stack []stackSymbol) string {
if len(stack) == 0 {
return ""
}
buf := make([]byte, len(stack)*8)
for i, sym := range stack {
binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
}
return string(buf)
}
func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
// Get stack top for comparisons
stackTop := stackEmpty
if len(simStack) > 0 {
stackTop = simStack[len(simStack)-1]
}
// Check for cycles to avoid infinite loops
key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
if visited[key] {
return
}
visited[key] = true
transitions := r.pda.Transitions[state]
for _, t := range transitions {
// Check stack constraint
if t.stackTop != stackEmpty && t.stackTop != stackTop {
continue
}
if t.Pattern == "" {
// Epsilon transition - simulate stack operations
newStack := make([]stackSymbol, len(simStack))
copy(newStack, simStack)
// Pop
if t.StackPop > 0 {
if len(newStack) < t.StackPop {
continue // Can't pop, skip this transition
}
newStack = newStack[:len(newStack)-t.StackPop]
}
// Push
newStack = append(newStack, t.StackPush...)
r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
} else {
// terminal - add if not seen
if !seen[t.Pattern] {
seen[t.Pattern] = true
*valid = append(*valid, t.Pattern)
}
}
}
}
// matchesPattern checks if input matches a pattern.
// Patterns can be:
// - Exact strings: "a", "{", "true"
// - Character ranges: "[a-z]", "[0-9]", "[#-~]"
func matchesPattern(input, pattern string) bool {
// Exact match
if input == pattern {
return true
}
// Check for character range pattern [X-Y]
if len(pattern) == 5 && pattern[0] == '[' && pattern[2] == '-' && pattern[4] == ']' {
if len(input) != 1 {
return false
}
ch := input[0]
low := pattern[1]
high := pattern[3]
return ch >= low && ch <= high
}
return false
}
// Accept tries to accept an input, returning true if successful.
func (r *runtime) Accept(input string) bool {
return r.accept(input, make(map[stateStackKey]bool))
}
func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
if visited[key] {
return false
}
visited[key] = true
transitions := r.pda.Transitions[r.state]
// First, process any epsilon transitions to reach a state that can accept input
// This is a simplified version - full implementation would need epsilon closure
for _, t := range transitions {
if matchesPattern(input, t.Pattern) {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Apply transition
r.applyTransition(t)
return true
}
}
// Try epsilon transitions first
for _, t := range transitions {
if t.Pattern == "" {
if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
continue
}
if t.StackPop > len(r.stack) {
continue
}
// Save state for backtracking
oldState := r.state
oldStack := make([]stackSymbol, len(r.stack))
copy(oldStack, r.stack)
r.applyTransition(t)
if r.accept(input, visited) {
return true
}
// Backtrack
r.state = oldState
r.stack = oldStack
}
}
return false
}
func (r *runtime) applyTransition(t transition) {
// Pop
if t.StackPop > 0 && len(r.stack) >= t.StackPop {
r.stack = r.stack[:len(r.stack)-t.StackPop]
}
// Push
r.stack = append(r.stack, t.StackPush...)
// Move to new state
r.state = t.ToState
}

540
x/grammar/grammar_test.go Normal file
View File

@@ -0,0 +1,540 @@
//go:build mlx
package grammar
import (
"testing"
)
func TestCompileSimpleGrammar(t *testing.T) {
// Simple grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
if pda == nil {
t.Fatal("pda is nil")
}
// Should have terminals "a" and "b"
if len(pda.Terminals) != 2 {
t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
}
// Test runtime
rt := newRuntime(pda)
// Should accept "a" then "b"
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be in accepting state")
}
}
func TestCompileAlternative(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Test accepting "a"
rt := newRuntime(pda)
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'a'")
}
// Test accepting "b"
rt.Reset()
if !rt.Accept("b") {
t.Error("should accept 'b'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// Test rejecting "c"
rt.Reset()
if rt.Accept("c") {
t.Error("should not accept 'c'")
}
}
func TestCompileRepetition(t *testing.T) {
// Grammar: S = {"a"} .
grammar := `S = {"a"} .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// Empty should be accepted (zero repetitions)
rt := newRuntime(pda)
if !rt.isAccepting() {
t.Error("empty should be accepting")
}
// "a" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept first 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after one 'a'")
}
// "aa" should be accepted
if !rt.Accept("a") {
t.Error("should accept second 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after two 'a's")
}
}
func TestCompileOption(t *testing.T) {
// Grammar: S = ["a"] "b" .
grammar := `S = ["a"] "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "b" alone should be accepted
rt := newRuntime(pda)
if !rt.Accept("b") {
t.Error("should accept 'b' alone")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'b'")
}
// "ab" should be accepted
rt.Reset()
if !rt.Accept("a") {
t.Error("should accept 'a'")
}
if !rt.Accept("b") {
t.Error("should accept 'b' after 'a'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'ab'")
}
}
func TestCompileRecursive(t *testing.T) {
// Grammar with recursion: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
// "x" should be accepted
rt := newRuntime(pda)
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.isAccepting() {
t.Error("should be accepting after 'x'")
}
// "(x)" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x' inside parens")
}
if !rt.Accept(")") {
t.Error("should accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x)'")
}
// "((x))" should be accepted
rt.Reset()
if !rt.Accept("(") {
t.Error("should accept first '('")
}
if !rt.Accept("(") {
t.Error("should accept second '('")
}
if !rt.Accept("x") {
t.Error("should accept 'x'")
}
if !rt.Accept(")") {
t.Error("should accept first ')'")
}
if !rt.Accept(")") {
t.Error("should accept second ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '((x))'")
}
}
func TestValidInputs(t *testing.T) {
// Grammar: S = "a" | "b" .
grammar := `S = "a" | "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
valid := rt.validInputs()
// Should have both "a" and "b" as valid
hasA, hasB := false, false
for _, v := range valid {
if v == "a" {
hasA = true
}
if v == "b" {
hasB = true
}
}
if !hasA {
t.Error("'a' should be valid input")
}
if !hasB {
t.Error("'b' should be valid input")
}
}
// TestValidInputsAfterAccept tests that validInputs returns correct values
// after accepting tokens, ensuring proper stack simulation.
func TestValidInputsAfterAccept(t *testing.T) {
// Grammar: S = "a" "b" "c" .
grammar := `S = "a" "b" "c" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "a" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "a" {
t.Errorf("initially expected only 'a', got %v", valid)
}
// After accepting "a", only "b" should be valid
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "b" {
t.Errorf("after 'a', expected only 'b', got %v", valid)
}
// After accepting "b", only "c" should be valid
if !rt.Accept("b") {
t.Fatal("failed to accept 'b'")
}
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "c" {
t.Errorf("after 'ab', expected only 'c', got %v", valid)
}
}
// TestValidInputsWithRepetitionInProduction tests the critical case where
// a repetition exists inside a called production. This requires proper
// stack simulation to determine when closing symbols are valid.
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
// Grammar similar to JSON:
// S = "(" items ")" .
// items = item { "," item } .
// item = "x" .
grammar := `
S = "(" items ")" .
items = item { "," item } .
item = "x" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially only "(" should be valid
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "(" {
t.Errorf("initially expected only '(', got %v", valid)
}
// Accept "("
if !rt.Accept("(") {
t.Fatal("failed to accept '('")
}
// After "(", should be able to accept "x" (item)
valid = rt.validInputs()
hasX := false
for _, v := range valid {
if v == "x" {
hasX = true
}
}
if !hasX {
t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
}
// Accept first item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept 'x'")
}
// After "(x", should be able to accept "," (more items) OR ")" (end)
valid = rt.validInputs()
hasComma, hasClose := false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
}
// Accept comma for another item
if !rt.Accept(",") {
t.Fatal("failed to accept ','")
}
// After "(x,", should only be able to accept "x" (next item)
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "x" {
t.Errorf("after '(x,', expected only 'x', got %v", valid)
}
// Accept second item "x"
if !rt.Accept("x") {
t.Fatal("failed to accept second 'x'")
}
// CRITICAL: After "(x,x", should be able to accept "," OR ")"
// This tests the stack simulation fix - we need to properly
// follow epsilon transitions through the production call stack.
valid = rt.validInputs()
hasComma, hasClose = false, false
for _, v := range valid {
if v == "," {
hasComma = true
}
if v == ")" {
hasClose = true
}
}
if !hasComma {
t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
}
if !hasClose {
t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
}
// Close with ")"
if !rt.Accept(")") {
t.Fatal("failed to accept ')'")
}
if !rt.isAccepting() {
t.Error("should be accepting after '(x,x)'")
}
}
// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
func TestValidInputsNestedCalls(t *testing.T) {
// Grammar: A = "start" B "end" . B = "middle" .
grammar := `
A = "start" B "end" .
B = "middle" .
`
pda, err := compileString(grammar, "A")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// After "start", should accept "middle" (from B)
rt.Accept("start")
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "middle" {
t.Errorf("after 'start', expected 'middle', got %v", valid)
}
// After "start middle", should accept "end"
rt.Accept("middle")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != "end" {
t.Errorf("after 'start middle', expected 'end', got %v", valid)
}
}
func TestReturnAddressDisambiguation(t *testing.T) {
// Grammar where the same production is called from different contexts:
// S = A "x" | "c" A "y" .
// A = "a" .
grammar := `
S = A "x" | "c" A "y" .
A = "a" .
`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
if !rt.Accept("c") {
t.Fatal("failed to accept 'c'")
}
if !rt.Accept("a") {
t.Fatal("failed to accept 'a'")
}
valid := rt.validInputs()
if len(valid) != 1 || valid[0] != "y" {
t.Errorf("after 'ca', expected only 'y', got %v", valid)
}
rt.Reset()
rt.Accept("c")
rt.Accept("a")
if rt.Accept("x") {
t.Error("should not accept 'x' after 'ca'")
}
}
// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
// which heavily exercise the stack simulation.
func TestValidInputsRecursiveWithStack(t *testing.T) {
// Grammar: S = "(" S ")" | "x" .
grammar := `S = "(" S ")" | "x" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// Initially: "(" or "x" should be valid
valid := rt.validInputs()
hasParen, hasX := false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("initially expected '(' and 'x', got %v", valid)
}
// After "(": "(" or "x" should be valid (nested S)
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '(', expected '(' and 'x', got %v", valid)
}
// After "((": "(" or "x" should still be valid
rt.Accept("(")
valid = rt.validInputs()
hasParen, hasX = false, false
for _, v := range valid {
if v == "(" {
hasParen = true
}
if v == "x" {
hasX = true
}
}
if !hasParen || !hasX {
t.Errorf("after '((', expected '(' and 'x', got %v", valid)
}
// After "((x": only ")" should be valid
rt.Accept("x")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x', expected only ')', got %v", valid)
}
// After "((x)": only ")" should be valid (closing outer)
rt.Accept(")")
valid = rt.validInputs()
if len(valid) != 1 || valid[0] != ")" {
t.Errorf("after '((x)', expected only ')', got %v", valid)
}
}
// TestRejectionAfterValid tests that invalid inputs are rejected
// at various points in the grammar.
func TestRejectionAfterValid(t *testing.T) {
// Grammar: S = "a" "b" .
grammar := `S = "a" "b" .`
pda, err := compileString(grammar, "S")
if err != nil {
t.Fatalf("compile failed: %v", err)
}
rt := newRuntime(pda)
// "b" should be rejected initially
if rt.Accept("b") {
t.Error("'b' should be rejected initially")
}
// Accept "a"
rt.Accept("a")
// "a" should be rejected after "a"
if rt.Accept("a") {
t.Error("'a' should be rejected after 'a'")
}
// "c" should be rejected (not in grammar)
if rt.Accept("c") {
t.Error("'c' should be rejected (not in grammar)")
}
}

View File

@@ -0,0 +1,56 @@
# Example Grammars
This directory contains example EBNF grammars for constrained decoding.
## Usage
```bash
go run -tags mlx ./x/imagegen/cmd/engine/ \
-model /path/to/model \
-prompt "Your prompt" \
-grammar x/grammar/grammars/json.ebnf \
-grammar-start value
```
## Available Grammars
| File | Start Rule | Description |
|------|------------|-------------|
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
| `identifier.ebnf` | `ident` | Programming language identifiers |
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
| `list.ebnf` | `list` | Comma-separated word list |
| `yesno.ebnf` | `response` | Simple yes/no responses |
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
| `email.ebnf` | `email` | Basic email addresses |
| `phone.ebnf` | `phone` | US phone numbers |
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
| `url.ebnf` | `url` | HTTP/HTTPS URLs |
## Grammar Syntax
**Note:** Comments are not supported. Grammar files must contain only EBNF productions.
The grammars use EBNF notation:
- `=` defines a production rule
- `|` is alternation (or)
- `{ }` is repetition (zero or more)
- `[ ]` is optional (zero or one)
- `" "` is a literal string
- `…` is a character range (e.g., `"a" … "z"`)
- `.` ends a production
## Writing Custom Grammars
1. Define your grammar in a `.ebnf` file
2. Choose a start rule name
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`
Example custom grammar for RGB colors:
```ebnf
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .
```

View File

@@ -0,0 +1,7 @@
expr = term { " OR " term } .
term = factor { " AND " factor } .
factor = "NOT " factor | atom | "(" expr ")" .
atom = "true" | "false" | ident .
ident = letter { letter | digit } .
letter = "a" "z" | "A" "Z" .
digit = "0" "9" .

View File

@@ -0,0 +1,6 @@
date = year "-" month "-" day .
year = digit digit digit digit .
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,5 @@
email = localpart "@" domain .
localpart = word { "." word } .
domain = word { "." word } .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .

View File

@@ -0,0 +1,7 @@
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = [ "-" ] digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,4 @@
color = "#" ( hex6 | hex3 ) .
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hex3 = hexdigit hexdigit hexdigit .
hexdigit = "0" "9" | "a" "f" | "A" "F" .

View File

@@ -0,0 +1,3 @@
ident = letter { letter | digit | "_" } .
letter = "a" "z" | "A" "Z" | "_" .
digit = "0" "9" .

View File

@@ -0,0 +1,16 @@
value = object | array | string | number | "true" | "false" | "null" .
object = "{" [ members ] "}" .
members = pair { "," pair } .
pair = string ":" value .
array = "[" [ elements ] "]" .
elements = value { "," value } .
string = "\"" { char } "\"" .
char = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
onenine = "1" "9" .
digit = "0" "9" .

View File

@@ -0,0 +1,27 @@
root = array .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,4 @@
list = item { ", " item } .
item = word .
word = letter { letter } .
letter = "a" "z" | "A" "Z" .

View File

@@ -0,0 +1,19 @@
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .
person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" "[" | "]" "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,15 @@
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .
name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
email_field = "\"email\"" ws ":" ws string .
string = "\"" { character } "\"" .
character = " " | "!" | "#" "~" .
number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" "9" .
onenine = "1" "9" .
ws = { " " | "\t" | "\n" | "\r" } .

View File

@@ -0,0 +1,7 @@
phone = parenformat | dashformat .
parenformat = "(" areacode ") " exchange "-" subscriber .
dashformat = areacode "-" exchange "-" subscriber .
areacode = digit digit digit .
exchange = digit digit digit .
subscriber = digit digit digit digit .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .

View File

@@ -0,0 +1,11 @@
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
scheme = "http" | "https" .
host = word { "." word } .
port = digit { digit } .
path = "/" { pathseg } .
pathseg = word [ "/" ] .
query = "?" param { "&" param } .
param = word "=" word .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" "z" | "A" "Z" | "0" "9" .
digit = "0" "9" .

View File

@@ -0,0 +1,3 @@
response = affirmative | negative .
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .

69
x/grammar/json.go Normal file
View File

@@ -0,0 +1,69 @@
//go:build mlx
package grammar
// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
// Based on https://www.json.org/json-en.html
//
// This grammar operates at the character level. The engine validates
// tokens by matching them as sequences of these character-level terminals.
const JSONGrammarEBNF = `
json = value .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`
// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
const JSONObjectGrammarEBNF = `
json = object .
value = object | array | string | number | "true" | "false" | "null" .
object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .
array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .
string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .
ws = { " " | "\t" | "\n" | "\r" } .
`

726
x/grammar/schema/schema.go Normal file
View File

@@ -0,0 +1,726 @@
//go:build mlx
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
package schema
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
"github.com/ollama/ollama/x/grammar"
)
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
// See: https://platform.openai.com/docs/guides/structured-outputs
type schemaNode struct {
// Core types
Type interface{} `json:"type"` // string, []string, or nil
// Object properties
Properties map[string]*schemaNode `json:"properties"`
Required []string `json:"required"`
AdditionalProperties interface{} `json:"additionalProperties"`
// Array properties
Items *schemaNode `json:"items"`
MinItems *int `json:"minItems"`
MaxItems *int `json:"maxItems"`
// String properties
Pattern string `json:"pattern"` // Regex pattern
Format string `json:"format"` // date-time, email, uuid, etc.
// Number properties (noted but not enforced in grammar - validated post-generation)
Minimum *float64 `json:"minimum"`
Maximum *float64 `json:"maximum"`
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
MultipleOf *float64 `json:"multipleOf"`
// Enum and const
Enum []interface{} `json:"enum"`
Const interface{} `json:"const"`
// Composition
AnyOf []*schemaNode `json:"anyOf"`
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
// References and definitions
Ref string `json:"$ref"`
Defs map[string]*schemaNode `json:"$defs"`
// Description (ignored for grammar but useful for docs)
Description string `json:"description"`
}
// converter handles JSON Schema to EBNF conversion with state.
type converter struct {
schema *schemaNode
definitions map[string]*schemaNode // Resolved $defs
usedTypes map[string]bool
rules []string
ruleNum int
definedRefs map[string]bool // Track which refs we've already defined as rules
}
// EBNF converts a JSON Schema to EBNF grammar
func EBNF(schemaJSON string) (string, error) {
var schema schemaNode
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
}
conv := &converter{
schema: &schema,
definitions: schema.Defs,
usedTypes: make(map[string]bool),
definedRefs: make(map[string]bool),
}
return conv.convert()
}
func (c *converter) convert() (string, error) {
var b strings.Builder
// Generate root rule
rootExpr := c.schemaToExpr(c.schema, "root")
b.WriteString("root = ")
b.WriteString(rootExpr)
b.WriteString(" .\n")
// Add generated rules (refs, items, etc.)
for _, rule := range c.rules {
b.WriteString(rule)
b.WriteString("\n")
}
// Add primitives based on usage
c.addPrimitives(&b)
return b.String(), nil
}
func (c *converter) addPrimitives(b *strings.Builder) {
if c.usedTypes["string"] {
b.WriteString(`
string = "\"" { character } "\"" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] {
b.WriteString(`
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
`)
}
if c.usedTypes["number"] {
b.WriteString(`
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
`)
}
if c.usedTypes["integer"] {
b.WriteString(`
int = [ "-" ] ( "0" | onenine { digit } ) .
`)
}
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
b.WriteString(`
digit = "0" … "9" .
`)
}
// onenine only needed for number/integer, not for digit-only formats
if c.usedTypes["number"] || c.usedTypes["integer"] {
b.WriteString(`onenine = "1" … "9" .
`)
}
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
b.WriteString(`
hex = "0" … "9" | "A" … "F" | "a" … "f" .
`)
}
if c.usedTypes["ws"] {
b.WriteString(`
ws = { " " | "\t" | "\n" | "\r" } .
`)
}
}
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
if schema == nil {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
}
// Handle $ref first
if schema.Ref != "" {
return c.resolveRef(schema.Ref)
}
// Handle const
if schema.Const != nil {
return c.constToExpr(schema.Const)
}
// Handle enum
if len(schema.Enum) > 0 {
return c.enumToExpr(schema.Enum)
}
// Handle anyOf / oneOf
if len(schema.AnyOf) > 0 {
return c.anyOfToExpr(schema.AnyOf, name)
}
if len(schema.OneOf) > 0 {
return c.anyOfToExpr(schema.OneOf, name)
}
// Handle type
types := c.getTypes(schema.Type)
if len(types) == 0 {
// No type specified, could be anything
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "( string | number | \"true\" | \"false\" | \"null\" )"
}
if len(types) == 1 {
return c.typeToExpr(types[0], schema, name)
}
// Multiple types (e.g., ["string", "null"])
var parts []string
for _, t := range types {
parts = append(parts, c.typeToExpr(t, schema, name))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
switch typeName {
case "object":
return c.objectToExpr(schema, name)
case "array":
return c.arrayToExpr(schema, name)
case "string":
return c.stringToExpr(schema, name)
case "number":
c.usedTypes["number"] = true
return "number"
case "integer":
c.usedTypes["integer"] = true
c.usedTypes["digit"] = true
return "int"
case "boolean":
return `( "true" | "false" )`
case "null":
return `"null"`
default:
c.usedTypes["string"] = true
c.usedTypes["number"] = true
return "string"
}
}
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
if len(schema.Properties) == 0 {
return `"{" ws "}"`
}
// Sort properties for deterministic output
// Required properties come first, in their required order
var propOrder []string
requiredSet := make(map[string]bool)
for _, r := range schema.Required {
requiredSet[r] = true
propOrder = append(propOrder, r)
}
// Add any non-required properties (though OpenAI requires all to be required)
var optionalProps []string
for propName := range schema.Properties {
if !requiredSet[propName] {
optionalProps = append(optionalProps, propName)
}
}
sort.Strings(optionalProps)
propOrder = append(propOrder, optionalProps...)
var propExprs []string
first := true
for _, propName := range propOrder {
propSchema, exists := schema.Properties[propName]
if !exists {
continue
}
propExpr := c.schemaToExpr(propSchema, propName)
prefix := ""
if !first {
prefix = `"," ws `
}
first = false
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
}
if len(propExprs) == 0 {
return `"{" ws "}"`
}
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
}
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
c.usedTypes["ws"] = true
itemExpr := "value"
if schema.Items != nil {
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
} else {
c.usedTypes["string"] = true
c.usedTypes["number"] = true
}
// Create item rule
c.ruleNum++
itemRule := fmt.Sprintf("item%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
// Handle minItems/maxItems
if schema.MinItems != nil || schema.MaxItems != nil {
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
}
// Default: zero or more items
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
min := 0
max := -1 // unlimited
if minItems != nil {
min = *minItems
}
if maxItems != nil {
max = *maxItems
}
if min == 0 && max < 0 {
// No constraints
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}
if min == 0 && max == 0 {
return `"[" ws "]"`
}
// Build pattern for bounded array
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
var parts []string
// Required items
for i := 0; i < min; i++ {
if i > 0 {
parts = append(parts, `"," ws`)
}
parts = append(parts, itemRule)
}
// Optional items up to max
if max > min {
for i := min; i < max; i++ {
if i == 0 {
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
}
}
// Close all optional brackets
for i := min; i < max; i++ {
parts = append(parts, "]")
}
} else if max < 0 {
// Unlimited after min
if min > 0 {
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
} else {
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
}
}
if min == 0 {
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
}
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
}
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
// Handle format
if schema.Format != "" {
return c.formatToExpr(schema.Format)
}
// Handle pattern (regex)
if schema.Pattern != "" {
return c.patternToExpr(schema.Pattern, name)
}
// Default string
c.usedTypes["string"] = true
if name == "root" {
c.usedTypes["character"] = true
return `"\"" { character } "\""`
}
return "string"
}
func (c *converter) formatToExpr(format string) string {
switch format {
case "date":
// YYYY-MM-DD
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("date%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
return ruleName
case "time":
// HH:MM:SS
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("time%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
return ruleName
case "date-time":
// YYYY-MM-DDTHH:MM:SSZ or with offset
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
return ruleName
case "email":
// Simplified email pattern
c.ruleNum++
ruleName := fmt.Sprintf("email%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
return ruleName
case "uuid":
// 8-4-4-4-12 hex pattern
c.ruleNum++
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
c.usedTypes["hex"] = true
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
return ruleName
case "ipv4":
c.ruleNum++
c.usedTypes["digit"] = true
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
return ruleName
case "uri", "hostname":
// Fallback to general string for complex formats
c.usedTypes["string"] = true
return "string"
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) patternToExpr(pattern string, name string) string {
// Try to convert simple regex patterns to EBNF
// This handles common cases; complex regex falls back to string
// Remove anchors
pattern = strings.TrimPrefix(pattern, "^")
pattern = strings.TrimSuffix(pattern, "$")
// Try to parse and convert
expr, ok := c.regexToEBNF(pattern)
if !ok {
// Fallback to general string
c.usedTypes["string"] = true
return "string"
}
c.ruleNum++
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
return ruleName
}
func (c *converter) regexToEBNF(pattern string) (string, bool) {
// Simple regex to EBNF converter
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
var result strings.Builder
i := 0
for i < len(pattern) {
ch := pattern[i]
switch ch {
case '[':
// Character class
end := strings.Index(pattern[i:], "]")
if end == -1 {
return "", false
}
class := pattern[i+1 : i+end]
ebnfClass, ok := c.charClassToEBNF(class)
if !ok {
return "", false
}
result.WriteString(ebnfClass)
i += end + 1
case '(':
// Group - find matching )
depth := 1
start := i + 1
j := start
for j < len(pattern) && depth > 0 {
if pattern[j] == '(' {
depth++
} else if pattern[j] == ')' {
depth--
}
j++
}
if depth != 0 {
return "", false
}
groupContent := pattern[start : j-1]
groupExpr, ok := c.regexToEBNF(groupContent)
if !ok {
return "", false
}
result.WriteString("( ")
result.WriteString(groupExpr)
result.WriteString(" )")
i = j
case '|':
result.WriteString(" | ")
i++
case '+':
// One or more - wrap previous in { } and add one required
// This is a simplification
return "", false // TODO: handle properly
case '*':
// Zero or more - need to wrap previous
return "", false // TODO: handle properly
case '?':
// Optional - need to wrap previous in [ ]
return "", false // TODO: handle properly
case '\\':
// Escape sequence
if i+1 >= len(pattern) {
return "", false
}
next := pattern[i+1]
switch next {
case 'd':
result.WriteString("digit")
c.usedTypes["digit"] = true
case 'w':
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
case 's':
result.WriteString(`( " " | "\t" )`)
default:
result.WriteString(fmt.Sprintf(`"%c"`, next))
}
i += 2
default:
// Literal character
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
result.WriteString(fmt.Sprintf(`"%c" `, ch))
} else {
// Special char, try to escape
result.WriteString(fmt.Sprintf(`"%c" `, ch))
}
i++
}
}
return strings.TrimSpace(result.String()), true
}
func (c *converter) charClassToEBNF(class string) (string, bool) {
// Handle character classes like a-z, A-Z, 0-9
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
}
if class == "a-zA-Z0-9" {
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
}
if class == "a-z" {
return `"a" … "z"`, true
}
if class == "A-Z" {
return `"A" … "Z"`, true
}
if class == "0-9" {
c.usedTypes["digit"] = true
return "digit", true
}
// Try to parse range patterns
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
}
return "", false
}
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
var parts []string
for i, s := range schemas {
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
parts = append(parts, expr)
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) enumToExpr(values []interface{}) string {
var parts []string
for _, v := range values {
parts = append(parts, c.constToExpr(v))
}
return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) constToExpr(v interface{}) string {
switch val := v.(type) {
case string:
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
case float64:
if val == float64(int(val)) {
return fmt.Sprintf(`"%d"`, int(val))
}
return fmt.Sprintf(`"%v"`, val)
case bool:
if val {
return `"true"`
}
return `"false"`
case nil:
return `"null"`
default:
c.usedTypes["string"] = true
return "string"
}
}
func (c *converter) resolveRef(ref string) string {
// Handle #/$defs/name references
if strings.HasPrefix(ref, "#/$defs/") {
defName := strings.TrimPrefix(ref, "#/$defs/")
return c.resolveDefRef(defName)
}
// Handle root recursion #
if ref == "#" {
return "root"
}
// Unknown ref format
c.usedTypes["string"] = true
return "string"
}
func (c *converter) resolveDefRef(defName string) string {
// Check if we've already defined this as a rule
ruleName := "def_" + defName
if c.definedRefs[defName] {
return ruleName
}
// Mark as defined to prevent infinite recursion
c.definedRefs[defName] = true
// Look up the definition
if c.definitions == nil {
c.usedTypes["string"] = true
return "string"
}
defSchema, ok := c.definitions[defName]
if !ok {
c.usedTypes["string"] = true
return "string"
}
// Generate the rule
expr := c.schemaToExpr(defSchema, ruleName)
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
return ruleName
}
func (c *converter) getTypes(t interface{}) []string {
switch v := t.(type) {
case string:
return []string{v}
case []interface{}:
var types []string
for _, item := range v {
if s, ok := item.(string); ok {
types = append(types, s)
}
}
return types
}
return nil
}
func (c *converter) escapeString(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `"`, `\"`)
return s
}
// Grammar converts a JSON Schema string into a compiled grammar.
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
ebnf, err := EBNF(schemaJSON)
if err != nil {
return nil, err
}
return grammar.ParseEBNF(ebnf, "root")
}

View File

@@ -0,0 +1,336 @@
//go:build mlx
package schema
import (
"testing"
gram "github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestJSONEBNF(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "simple object",
schema: `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`,
},
{
name: "with enum",
schema: `{
"type": "object",
"properties": {
"status": {"enum": ["active", "inactive", "pending"]}
},
"required": ["status"]
}`,
},
{
name: "array of objects",
schema: `{
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "integer"}
},
"required": ["id"]
}
}`,
},
{
name: "nested object",
schema: `{
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"email": {"type": "string"}
},
"required": ["email"]
}
},
"required": ["user"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
// Try to compile it
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}
func TestGrammarEngine(t *testing.T) {
schema := `{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}`
grammar, err := Grammar(schema)
if err != nil {
t.Fatalf("Grammar failed: %v", err)
}
vocab := []string{
"{", "}", "[", "]", ":", ",",
"\"name\"", "\"age\"", "\"test\"",
"\"", "a", "b", "c",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
" ", "\n",
"true", "false", "null",
}
engine, err := gram.NewEngine(grammar, vocab)
if err != nil {
t.Fatalf("grammar.NewEngine failed: %v", err)
}
defer engine.Close()
logits := mlx.Ones(int32(len(vocab)))
mlx.Keep(logits)
// Test that we can apply mask
masked := engine.ApplyMask(logits)
mlx.Eval(masked)
}
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
func TestOpenAIStructuredOutputs(t *testing.T) {
tests := []struct {
name string
schema string
}{
{
name: "anyOf union",
schema: `{
"type": "object",
"properties": {
"value": {
"anyOf": [
{"type": "string"},
{"type": "integer"}
]
}
},
"required": ["value"]
}`,
},
{
name: "nullable string via type array",
schema: `{
"type": "object",
"properties": {
"name": {"type": ["string", "null"]}
},
"required": ["name"]
}`,
},
{
name: "$ref with $defs",
schema: `{
"type": "object",
"properties": {
"person": {"$ref": "#/$defs/Person"}
},
"required": ["person"],
"$defs": {
"Person": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}
}
}`,
},
{
name: "const value",
schema: `{
"type": "object",
"properties": {
"type": {"const": "user"}
},
"required": ["type"]
}`,
},
{
name: "format date-time",
schema: `{
"type": "object",
"properties": {
"created": {"type": "string", "format": "date-time"}
},
"required": ["created"]
}`,
},
{
name: "format date",
schema: `{
"type": "object",
"properties": {
"birthday": {"type": "string", "format": "date"}
},
"required": ["birthday"]
}`,
},
{
name: "format email",
schema: `{
"type": "object",
"properties": {
"email": {"type": "string", "format": "email"}
},
"required": ["email"]
}`,
},
{
name: "format uuid",
schema: `{
"type": "object",
"properties": {
"id": {"type": "string", "format": "uuid"}
},
"required": ["id"]
}`,
},
{
name: "array with minItems maxItems",
schema: `{
"type": "object",
"properties": {
"tags": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
"maxItems": 3
}
},
"required": ["tags"]
}`,
},
{
name: "deeply nested with refs",
schema: `{
"type": "object",
"properties": {
"company": {
"type": "object",
"properties": {
"name": {"type": "string"},
"employees": {
"type": "array",
"items": {"$ref": "#/$defs/Employee"}
}
},
"required": ["name", "employees"]
}
},
"required": ["company"],
"$defs": {
"Employee": {
"type": "object",
"properties": {
"name": {"type": "string"},
"role": {"enum": ["engineer", "manager", "intern"]}
},
"required": ["name", "role"]
}
}
}`,
},
{
name: "multiple refs same def",
schema: `{
"type": "object",
"properties": {
"from": {"$ref": "#/$defs/Address"},
"to": {"$ref": "#/$defs/Address"}
},
"required": ["from", "to"],
"$defs": {
"Address": {
"type": "object",
"properties": {
"city": {"type": "string"},
"zip": {"type": "string"}
},
"required": ["city", "zip"]
}
}
}`,
},
{
name: "oneOf variant",
schema: `{
"type": "object",
"properties": {
"result": {
"oneOf": [
{
"type": "object",
"properties": {"success": {"type": "boolean"}},
"required": ["success"]
},
{
"type": "object",
"properties": {"error": {"type": "string"}},
"required": ["error"]
}
]
}
},
"required": ["result"]
}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ebnf, err := EBNF(tc.schema)
if err != nil {
t.Fatalf("EBNF failed: %v", err)
}
grammar, err := gram.ParseEBNF(ebnf, "root")
if err != nil {
t.Fatalf("ParseEBNF failed: %v", err)
}
if grammar == nil {
t.Fatal("grammar is nil")
}
})
}
}

105
x/grammar/terminal.go Normal file
View File

@@ -0,0 +1,105 @@
//go:build mlx
package grammar
import "unicode/utf8"
// terminalType distinguishes different kinds of grammar terminals
type terminalType int
const (
terminalLiteral terminalType = iota // Exact string: "true", "{"
terminalRange // Character range: [a-z], [0-9]
)
// terminal represents a compiled grammar terminal
type terminal struct {
ID int
Type terminalType
Pattern string // Original pattern from grammar
Unescaped string // Unescaped literal (for terminalLiteral)
LowRune rune // For unicode ranges: low bound
HighRune rune // For unicode ranges: high bound
}
// terminalMatch represents a terminal that matched at a position
type terminalMatch struct {
TerminalID int
Length int // Number of bytes consumed
}
// trieNode is a node in the literal matching trie
type trieNode struct {
children [256]*trieNode // Byte-indexed children
terminalID int // -1 if not accepting, else terminal ID
}
// terminalMatcher tests which terminals match at a position in a byte slice
type terminalMatcher struct {
// Trie for literal matching (fast path)
literalTrie *trieNode
// Range terminals (single-byte matches)
ranges []terminal
// All terminals for enumeration
terminals []terminal
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
patternToID map[string]int
}
// addLiteralToTrie adds a literal pattern to the trie
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
node := m.literalTrie
for i := 0; i < len(pattern); i++ {
c := pattern[i]
if node.children[c] == nil {
node.children[c] = &trieNode{terminalID: -1}
}
node = node.children[c]
}
node.terminalID = terminalID
}
// matchesAt returns all terminals that match at pos in data
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
if pos >= len(data) {
return nil
}
var matches []terminalMatch
// Check literal matches via trie
node := m.literalTrie
for i := pos; i < len(data) && node != nil; i++ {
c := data[i]
node = node.children[c]
if node != nil && node.terminalID >= 0 {
matches = append(matches, terminalMatch{
TerminalID: node.terminalID,
Length: i - pos + 1,
})
}
}
// Check range matches (unicode-aware)
r, runeLen := utf8.DecodeRune(data[pos:])
if r != utf8.RuneError {
for _, rng := range m.ranges {
if r >= rng.LowRune && r <= rng.HighRune {
matches = append(matches, terminalMatch{
TerminalID: rng.ID,
Length: runeLen,
})
}
}
}
return matches
}
// terminalCount returns the number of terminals
func (m *terminalMatcher) terminalCount() int {
return len(m.terminals)
}

View File

@@ -234,17 +234,3 @@ ollama create z-image
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,8 +1,10 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
@@ -99,10 +101,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
imagePath = extractPath(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
@@ -116,14 +118,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
c.SSEvent("done", buildResponse(imagePath, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
var imagePath string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
imagePath = extractPath(resp.Content)
}
})
if err != nil {
@@ -131,7 +133,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
c.JSON(http.StatusOK, buildResponse(imagePath, format))
}
func parseSize(size string) (int32, int32) {
@@ -150,9 +152,9 @@ func parseSize(size string) (int32, int32) {
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
}
return ""
}
@@ -163,21 +165,23 @@ func parseProgress(content string) ImageProgressEvent {
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
func buildResponse(imagePath, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
if imagePath == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
resp.Data[0].URL = "file://" + imagePath
} else {
resp.Data[0].B64JSON = imageBase64
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
}
return resp

View File

@@ -1,197 +0,0 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

View File

@@ -44,66 +44,63 @@ func DefaultOptions() ImageGenOptions {
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
return fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture, parameter count, and quantization
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
architecture = index.Architecture
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
return info, nil
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
}
return fmt.Sprintf("%d", count)
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
@@ -186,7 +183,8 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -205,9 +203,11 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
@@ -218,27 +218,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
}
return nil
@@ -324,7 +306,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
p.Add("", spinner)
var stepBar *progress.StepBar
var imageBase64 string
var imagePath string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -344,9 +326,11 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return nil
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
}
return nil
@@ -358,30 +342,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
continue
}
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
@@ -402,6 +381,24 @@ func sanitizeFilename(s string) string {
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")

View File

@@ -29,10 +29,9 @@ const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
@@ -59,77 +58,18 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
return imagegen.LayerInfo{}, err
}
layer.Name = name
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
@@ -179,7 +119,7 @@ func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) err
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err

View File

@@ -1,120 +0,0 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -1,18 +0,0 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

Binary file not shown.

View File

@@ -8,6 +8,7 @@ import (
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/grammar"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -109,7 +110,11 @@ type input struct {
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
WiredLimitGB int // Metal wired memory limit in GB (default 32)
JSONMode bool // Enable JSON grammar constraint
GrammarEBNF string // Raw EBNF grammar string
GrammarStart string // Start rule name for grammar
Vocab []string // Vocabulary for constrained decoding
}
type output struct {
@@ -127,9 +132,11 @@ type Decoder struct {
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
grammar *grammar.Engine // Optional grammar constraint engine
grammarVocab []string // Vocab for grammar debug
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
@@ -145,6 +152,12 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
}
}
// SetGrammar enables constrained decoding with the given grammar engine.
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
d.grammar = g
d.grammarVocab = vocab
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
@@ -222,6 +235,16 @@ func (d *Decoder) prefill(inputIDs []int32) int {
} else {
logits = d.model.Forward(x, d.caches)
}
// Apply grammar constraints if enabled
if d.grammar != nil {
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
@@ -245,6 +268,15 @@ func (d *Decoder) prefill(inputIDs []int32) int {
func (d *Decoder) step() int32 {
prevToken := d.token
// Sync on previous token FIRST to get its value and update grammar state
// This must happen before computing the next mask
val := prevToken.ItemInt32()
// Update grammar state with the token we just synced
if d.grammar != nil {
d.grammar.Accept(int(val))
}
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
@@ -253,6 +285,18 @@ func (d *Decoder) step() int32 {
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
// Apply grammar constraints if enabled
if d.grammar != nil {
// Get last position logits: [1, 1, vocab] -> [vocab]
shape := logits.Shape()
lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
maskedLogits := d.grammar.ApplyMask(lastLogits)
// Reshape back to [1, 1, vocab] for sample()
logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
@@ -262,9 +306,6 @@ func (d *Decoder) step() int32 {
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
@@ -289,6 +330,48 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Set up grammar constraint if enabled
var grammarEngine *grammar.Engine
var grammarVocab []string
if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
var compiled *grammar.Grammar
var err error
if in.GrammarEBNF != "" {
// Custom EBNF grammar
startRule := in.GrammarStart
if startRule == "" {
startRule = "root"
}
compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
if err != nil {
return fmt.Errorf("failed to parse grammar: %w", err)
}
fmt.Printf("[Grammar mode: start=%s]\n", startRule)
} else {
// JSON object grammar (only allows objects at top level)
compiled, err = grammar.JSONObjectGrammar()
if err != nil {
return fmt.Errorf("failed to create JSON grammar: %w", err)
}
fmt.Println("[JSON object mode enabled]")
}
// Pad vocab to match model's vocab size if needed
grammarVocab = in.Vocab
modelVocabSize := int(m.VocabSize())
if len(grammarVocab) < modelVocabSize {
padded := make([]string, modelVocabSize)
copy(padded, grammarVocab)
grammarVocab = padded
}
grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
if err != nil {
return fmt.Errorf("failed to create grammar engine: %w", err)
}
defer grammarEngine.Close()
}
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
@@ -304,6 +387,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
tokens = tok.Encode(prompt, true)
}
if grammarEngine != nil {
dec.SetGrammar(grammarEngine, grammarVocab)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
@@ -327,6 +414,11 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete after first token
if dec.grammar != nil && dec.grammar.IsComplete() {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
@@ -341,6 +433,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
// Check if grammar is complete (valid JSON document finished)
if dec.grammar != nil && dec.grammar.IsComplete() {
break
}
if n%256 == 0 {
mlx.ClearCache()

View File

@@ -44,6 +44,9 @@ func main() {
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")
// Image generation params
width := flag.Int("width", 1024, "Image width")
@@ -67,9 +70,6 @@ func main() {
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -102,17 +102,13 @@ func main() {
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
@@ -193,6 +189,20 @@ func main() {
}
}
// Get vocab for constrained decoding if needed
var vocab []string
var grammarEBNF string
if *jsonMode || *grammarFile != "" {
vocab = m.Tokenizer().Vocab()
}
if *grammarFile != "" {
data, err := os.ReadFile(*grammarFile)
if err != nil {
log.Fatalf("failed to read grammar file: %v", err)
}
grammarEBNF = string(data)
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
@@ -201,6 +211,10 @@ func main() {
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
JSONMode: *jsonMode,
GrammarEBNF: grammarEBNF,
GrammarStart: *grammarStart,
Vocab: vocab,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)

View File

@@ -40,12 +40,10 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
@@ -76,11 +74,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
@@ -89,30 +83,16 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, newLayers...)
layers = append(layers, layer)
}
extractor.Close()
@@ -142,7 +122,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
var r io.Reader
// For model_index.json, normalize to Ollama format and add metadata
// For model_index.json, normalize to Ollama format
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
@@ -161,16 +141,6 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)

View File

@@ -60,12 +60,9 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
}
// Transform to [H, W, C] for image conversion
// Free intermediate arrays to avoid memory leak
squeezed := mlx.Squeeze(arr, 0)
transposed := mlx.Transpose(squeezed, 1, 2, 0)
squeezed.Free()
img := mlx.Contiguous(transposed)
transposed.Free()
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()

View File

@@ -607,11 +607,6 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
return a != nil && a.kept
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1485,44 +1480,6 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1701,34 +1658,6 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -1800,6 +1729,14 @@ func init() {
// Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread()
// Avoid Metal device init crashes on systems without Metal.
if runtime.GOOS == "darwin" {
if MetalIsAvailable() {
SetDefaultDeviceGPU()
} else {
SetDefaultDeviceCPU()
}
}
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
Keep(RandomState[0]) // Global state should persist
}
@@ -2057,8 +1994,7 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
// mode: "affine" (default) or "mxfp4"
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -2067,21 +2003,14 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
vecSize := int(C.mlx_vector_array_size(res))
// Result is a vector of 3 arrays: [weights, scales, biases]
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
if vecSize >= 3 {
C.mlx_vector_array_get(&w2, res, 2)
}
C.mlx_vector_array_get(&w2, res, 2)
C.mlx_vector_array_free(res)
if vecSize >= 3 {
return newArray(w0), newArray(w1), newArray(w2)
}
return newArray(w0), newArray(w1), nil
return newArray(w0), newArray(w1), newArray(w2)
}
// Dequantize reconstructs weights from quantized form.

View File

@@ -311,8 +311,8 @@ type Model struct {
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))

View File

@@ -222,14 +222,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -272,19 +264,10 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -322,9 +305,6 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {

View File

@@ -241,14 +241,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -299,18 +291,11 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -332,9 +317,6 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()

View File

@@ -28,12 +28,12 @@ type Qwen3Config struct {
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
}
// Forward applies the MLP

View File

@@ -36,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear nn.LinearLayer `weight:"2-1"`
Linear *nn.Linear `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
Linear *nn.Linear `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -100,13 +100,12 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -116,7 +115,6 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -128,69 +126,17 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
}
// Forward computes attention
@@ -200,24 +146,11 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -294,7 +227,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -348,8 +281,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
@@ -417,11 +350,12 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -443,7 +377,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -457,20 +391,6 @@ func (m *Transformer) initComputedFields() {
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -484,7 +404,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -503,8 +423,6 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -538,8 +456,6 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

View File

@@ -104,8 +104,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
@@ -207,10 +205,6 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -740,26 +734,18 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev = h
prev := h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -768,6 +754,7 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h

View File

@@ -26,12 +26,10 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
}
// ProgressFunc is called during generation with step progress.
@@ -44,7 +42,6 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
@@ -199,17 +196,13 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
}
useCFG := cfg.NegativePrompt != ""
@@ -267,54 +260,12 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
@@ -326,7 +277,6 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
@@ -339,77 +289,50 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}
tCurr := scheduler.Timesteps[i]
var noisePred *mlx.Array
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
patches := PatchifyLatents(latents, tcfg.PatchSize)
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
output = mlx.Add(negOutput, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
@@ -438,14 +361,8 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode

View File

@@ -10,13 +10,6 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -56,11 +49,6 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -96,13 +84,6 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

View File

@@ -1,22 +0,0 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}

View File

@@ -13,6 +13,7 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
@@ -33,8 +34,7 @@ type Request struct {
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Content string `json:"content"`
Done bool `json:"done"`
}
@@ -191,10 +191,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
@@ -204,12 +204,11 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response with image data
// Send final response
resp := Response{
Image: imageData,
Done: true,
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)

View File

@@ -8,7 +8,6 @@ import (
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// WeightSource is an interface for loading weights.
@@ -103,22 +102,6 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -193,64 +176,3 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -46,8 +46,7 @@ type completionRequest struct {
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Content string `json:"content"`
Done bool `json:"done"`
}
@@ -251,23 +250,15 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
// Stream responses - use large buffer for base64 image data
// Stream responses
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
}
fn(llm.CompletionResponse{
Content: content,
Content: cresp.Content,
Done: cresp.Done,
})
if cresp.Done {

View File

@@ -1082,6 +1082,12 @@ func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
return id, ok
}
// Vocab returns the vocabulary as a slice of token strings indexed by token ID.
// This is useful for constrained decoding where we need to map tokens to grammar symbols.
func (t *Tokenizer) Vocab() []string {
return t.vocab.Values
}
// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
func LoadVocabMerges(dir string) (*Tokenizer, error) {
vocabPath := dir + "/vocab.json"