Mirror of https://github.com/ollama/ollama.git (synced 2026-01-11 09:00:53 -05:00)

Compare commits: main...mlx-gpu-cd (1 commit, e23ddd84b8)

go.mod (18 lines changed)
@@ -15,8 +15,8 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.17.0
-	golang.org/x/sys v0.37.0
+	golang.org/x/sync v0.19.0
+	golang.org/x/sys v0.39.0
 )
 
 require (

@@ -30,8 +30,8 @@ require (
 	github.com/tkrajina/typescriptify-golang-structs v0.2.0
 	github.com/wk8/go-ordered-map/v2 v2.1.8
 	golang.org/x/image v0.22.0
-	golang.org/x/mod v0.30.0
-	golang.org/x/tools v0.38.0
+	golang.org/x/mod v0.31.0
+	golang.org/x/tools v0.40.0
 	gonum.org/v1/gonum v0.15.0
 )
 

@@ -81,11 +81,11 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.43.0
-	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
-	golang.org/x/net v0.46.0 // indirect
-	golang.org/x/term v0.36.0
-	golang.org/x/text v0.30.0
+	golang.org/x/crypto v0.46.0
+	golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
+	golang.org/x/net v0.48.0 // indirect
+	golang.org/x/term v0.38.0
+	golang.org/x/text v0.32.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
go.sum (36 lines changed)
@@ -233,16 +233,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
-golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
+golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
+golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
-golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
-golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
+golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
+golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
 golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=

@@ -264,8 +264,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
-golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
+golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
+golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=

@@ -278,8 +278,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
-golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
+golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
+golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

@@ -289,8 +289,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
-golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

@@ -306,17 +306,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
-golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
+golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
-golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
+golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
+golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
-golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
+golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
+golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@@ -330,8 +330,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
-golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
+golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
+golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
x/grammar/README.md (new file, 185 lines)

# grammar

Grammar-constrained decoding for LLM outputs using MLX.

## Performance

Performance depends on hardware, vocabulary size, grammar, and whether you
evaluate the MLX graph. See [Benchmarks](#benchmarks) for how to measure on
your setup.

### Design choices that keep masking fast

| Technique | Impact |
|-----------|--------|
| Precomputed token analysis | Terminal matches computed once at startup |
| Mask caching by grammar state signature | Reuse masks for repeated parser states |
| Partitioned tokens | Exact matches separated from DP candidates |

### Comparison Notes

- **llama.cpp**: Decodes each token to UTF-8 and checks it against the PDA. No caching.
- **Outlines**: FSM-based. Compilation can take 40 seconds to 10 minutes for complex schemas. Fast after compile.
- **XGrammar**: PDA with 99% of tokens precomputed as context-independent. State of the art before this work.
- **x/grammar**: Precomputed token analysis plus mask caching by grammar state signature (a sketch follows below).
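A minimal sketch of the mask-caching idea, assuming a cache keyed by the
parser-state signature. The names here (`maskCache`, `newMaskCache`) are
illustrative, not the package's internal API:

```go
package main

import "fmt"

// maskCache memoizes allowed-token masks by parser-state signature.
type maskCache struct {
    masks map[uint64][]bool // signature -> per-token "allowed" mask
}

func newMaskCache() *maskCache {
    return &maskCache{masks: make(map[uint64][]bool)}
}

// mask returns the cached mask for sig, computing it on first use.
func (c *maskCache) mask(sig uint64, compute func() []bool) []bool {
    if m, ok := c.masks[sig]; ok {
        return m // repeated parser state: O(1) reuse
    }
    m := compute() // first visit: validate every vocab token once
    c.masks[sig] = m
    return m
}

func main() {
    c := newMaskCache()
    calls := 0
    compute := func() []bool { calls++; return []bool{true, false, true} }
    c.mask(42, compute)
    c.mask(42, compute) // cache hit: compute runs only once
    fmt.Println(calls)  // prints 1
}
```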
## Usage

```go
import (
    "github.com/ollama/ollama/x/grammar"
    "github.com/ollama/ollama/x/grammar/schema"
)

// Use built-in JSON grammar
g, _ := grammar.JSONGrammar()

// Or from JSON Schema (OpenAI-compatible)
g, _ := schema.Grammar(`{
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "age"]
}`)

// Or parse custom EBNF
g, _ := grammar.ParseEBNF(myGrammar, "root")

// Create engine with model vocabulary
engine, _ := grammar.NewEngine(g, vocab)
defer engine.Close()

// Generation loop
for !engine.IsComplete() {
    logits := model.Forward(tokens)
    masked := engine.ApplyMask(logits) // Invalid tokens → -inf
    nextToken := sample(masked)
    engine.Accept(nextToken)
}
// Output conforms to the grammar when you only sample from masked tokens and call Accept
```
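The loop above assumes a `sample` helper, which the package does not provide.
A minimal greedy stand-in, assuming the masked logits have already been
materialized into a Go slice:

```go
// sample returns the index of the highest-logit token (greedy decoding).
// Masked-out tokens carry -inf logits, so they are never selected.
func sample(logits []float32) int {
    best := 0
    for i, l := range logits {
        if l > logits[best] {
            best = i
        }
    }
    return best
}
```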
## EBNF Syntax

```ebnf
rule = expression .   # Rule definition (ends with .)
"literal"             # Literal string
"a" … "z"             # Character range (inclusive)
( a | b )             # Grouping with alternation
[ optional ]          # Optional (0 or 1)
{ repeated }          # Repetition (0 or more)
```
### Example: JSON Grammar

```ebnf
json = value .

value = object | array | string | number | "true" | "false" | "null" .

object  = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member  = ws string ws ":" element .

array    = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element  = ws value ws .

string    = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped   = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .

number   = [ "-" ] integer [ fraction ] [ exponent ] .
integer  = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit    = "0" … "9" .
onenine  = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
```
### Example: Custom Schema

```ebnf
root = "{" ws name_field "," ws age_field ws "}" .

name_field = "\"name\"" ws ":" ws string .
age_field  = "\"age\"" ws ":" ws number .

string = "\"" { char } "\"" .
char   = " " | "!" | "#" … "~" .

number = [ "-" ] digit { digit } .
digit  = "0" … "9" .

ws = { " " | "\n" } .
```
## JSON Schema Support

OpenAI-compatible JSON Schema support with automatic EBNF generation:

```go
// Named schemaJSON to avoid shadowing the imported schema package
schemaJSON := `{
    "type": "object",
    "properties": {
        "user": {"$ref": "#/$defs/User"}
    },
    "required": ["user"],
    "$defs": {
        "User": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "email": {"type": "string", "format": "email"},
                "role": {"enum": ["admin", "user", "guest"]}
            },
            "required": ["name", "email", "role"]
        }
    }
}`

g, _ := schema.Grammar(schemaJSON)
```
### Supported Features

| Feature | Example |
|---------|---------|
| Basic types | `string`, `integer`, `number`, `boolean`, `null` |
| Objects | `properties`, `required` |
| Arrays | `items`, `minItems`, `maxItems` |
| Enums | `enum: ["a", "b", "c"]` |
| Constants | `const: "value"` |
| Union types | `anyOf`, `oneOf`, `type: ["string", "null"]` |
| References | `$ref: "#/$defs/Name"`, `$defs` |
| Formats | `date`, `time`, `date-time`, `email`, `uuid`, `ipv4` |
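For example, a nullable field can be expressed with a type array; the field
name below is illustrative:

```go
// Union type: "nickname" may be a string or null
g, _ := schema.Grammar(`{
    "type": "object",
    "properties": {
        "nickname": {"type": ["string", "null"]}
    },
    "required": ["nickname"]
}`)
```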
## Benchmarks

```bash
# Run all tests
go test -tags mlx ./x/grammar/...

# Run benchmarks
go test -tags mlx ./x/grammar/ -bench=.

# Compare with llama.cpp (outputs JSON)
go run -tags mlx ./x/grammar/cmd/compare -vocab-size 128000 -iterations 500

# Compare with a more complex schema
go run -tags mlx ./x/grammar/cmd/compare \
    -gbnf x/grammar/cmd/compare/complex.gbnf \
    -schema x/grammar/cmd/compare/complex.schema.json \
    -vocab-size 128000 -iterations 500
```
## References

- [XGrammar Paper](https://arxiv.org/abs/2411.15100) - Flexible and Efficient Structured Generation
- [Outlines](https://github.com/dottxt-ai/outlines) - Structured Text Generation
- [JSONSchemaBench](https://arxiv.org/abs/2501.10868) - Benchmark for Structured Outputs
x/grammar/analyzer.go (new file, 161 lines)
//go:build mlx

package grammar

// terminalTokenGroups contains pre-partitioned tokens for a terminal.
// This enables O(1) lookup of tokens that exactly match vs need DP validation.
type terminalTokenGroups struct {
    // ExactMatches are tokens that exactly match this terminal (O(1) validation)
    ExactMatches []int32

    // DPCandidates are tokens that start with this terminal but need DP validation
    DPCandidates []int
}

// tokenAnalysis contains precomputed terminal matches for a token
type tokenAnalysis struct {
    // The token string
    Token string

    // TokenID in the vocabulary
    TokenID int

    // Matches at each byte position
    // MatchesAtPos[i] = terminals matching at position i with their lengths
    MatchesAtPos [][]terminalMatch

    // Fast path: if token exactly matches one terminal
    // -1 if no exact match
    exactMatch int

    // Whether this token can be consumed at all (has at least one match)
    HasMatches bool
}

// analyzer precomputes terminal matches for a vocabulary
type analyzer struct {
    matcher  *terminalMatcher
    analyses []tokenAnalysis // Indexed by token ID
    vocab    []string

    // Pre-partitioned tokens by terminal (exact match vs DP candidates)
    // This enables direct slice appends instead of per-token branching
    tokensByTerminal []terminalTokenGroups
}

// newAnalyzer creates an analyzer for the given vocabulary and terminals
func newAnalyzer(vocab []string, matcher *terminalMatcher) *analyzer {
    a := &analyzer{
        matcher:  matcher,
        analyses: make([]tokenAnalysis, len(vocab)),
        vocab:    vocab,
    }

    // Precompute analysis for each token
    for i, token := range vocab {
        a.analyses[i] = a.analyze(token, i)
    }

    // Build pre-partitioned token groups for fast ApplyMask
    a.buildTokenPartitions()

    return a
}

// analyze computes terminal matches for a single token
func (a *analyzer) analyze(token string, tokenID int) tokenAnalysis {
    analysis := tokenAnalysis{
        Token:        token,
        TokenID:      tokenID,
        MatchesAtPos: make([][]terminalMatch, len(token)),
        exactMatch:   -1,
        HasMatches:   false,
    }

    if len(token) == 0 {
        return analysis
    }

    // Compute matches at each position
    data := []byte(token)
    for pos := 0; pos < len(data); pos++ {
        matches := a.matcher.matchesAt(data, pos)
        analysis.MatchesAtPos[pos] = matches
        if len(matches) > 0 {
            analysis.HasMatches = true
        }
    }

    // Exact match is only valid when a single terminal spans the entire token
    if len(analysis.MatchesAtPos) > 0 {
        exactID := -1
        for _, match := range analysis.MatchesAtPos[0] {
            if match.Length != len(token) {
                continue
            }
            if exactID >= 0 && exactID != match.TerminalID {
                exactID = -1
                break
            }
            exactID = match.TerminalID
        }
        analysis.exactMatch = exactID
    }

    return analysis
}

// analysis returns the precomputed analysis for a token ID
func (a *analyzer) analysis(tokenID int) tokenAnalysis {
    if tokenID < 0 || tokenID >= len(a.analyses) {
        return tokenAnalysis{exactMatch: -1}
    }
    return a.analyses[tokenID]
}

// vocabSize returns the vocabulary size
func (a *analyzer) vocabSize() int {
    return len(a.vocab)
}

// buildTokenPartitions pre-partitions tokens into exact-match vs needs-DP groups per terminal.
// This enables ApplyMask to use direct slice appends instead of per-token branching.
func (a *analyzer) buildTokenPartitions() {
    numTerminals := a.matcher.terminalCount()
    a.tokensByTerminal = make([]terminalTokenGroups, numTerminals)

    for tokenID, analysis := range a.analyses {
        if !analysis.HasMatches {
            continue
        }

        if analysis.exactMatch >= 0 {
            // Token exactly matches one terminal - fast path (O(1) validation)
            tid := analysis.exactMatch
            a.tokensByTerminal[tid].ExactMatches = append(
                a.tokensByTerminal[tid].ExactMatches, int32(tokenID))
        } else {
            // Token needs DP validation - add to all terminals it can start with
            // This way, when a terminal is valid, we know exactly which tokens need DP
            if len(analysis.MatchesAtPos) > 0 {
                seen := make(map[int]bool)
                for _, match := range analysis.MatchesAtPos[0] {
                    tid := match.TerminalID
                    if !seen[tid] {
                        seen[tid] = true
                        a.tokensByTerminal[tid].DPCandidates = append(
                            a.tokensByTerminal[tid].DPCandidates, tokenID)
                    }
                }
            }
        }
    }
}

// terminalGroups returns the pre-partitioned token groups for a terminal ID
func (a *analyzer) terminalGroups(terminalID int) terminalTokenGroups {
    if terminalID < 0 || terminalID >= len(a.tokensByTerminal) {
        return terminalTokenGroups{}
    }
    return a.tokensByTerminal[terminalID]
}
x/grammar/bridge.go (new file, 648 lines)
//go:build mlx

package grammar

import (
    "encoding/binary"
    "hash/fnv"
    "sort"
    "sync"
)

// visitedMapPool reduces allocations for visited maps in bridge operations
var visitedMapPool = sync.Pool{
    New: func() interface{} {
        return make(map[stateStackKey]bool, 16)
    },
}

// getVisitedMap gets a map from the pool
func getVisitedMap() map[stateStackKey]bool {
    return visitedMapPool.Get().(map[stateStackKey]bool)
}

// putVisitedMap returns a map to the pool after clearing it
func putVisitedMap(m map[stateStackKey]bool) {
    for k := range m {
        delete(m, k)
    }
    visitedMapPool.Put(m)
}

// parserConfig represents a pda state+stack combination
type parserConfig struct {
    state state
    Stack []stackSymbol
}

// clone creates a deep copy of the config
func (c *parserConfig) clone() *parserConfig {
    newStack := make([]stackSymbol, len(c.Stack))
    copy(newStack, c.Stack)
    return &parserConfig{
        state: c.state,
        Stack: newStack,
    }
}

// key returns a unique key for this config for deduplication
func (c *parserConfig) key() uint64 {
    h := fnv.New64a()
    var buf [8]byte
    binary.LittleEndian.PutUint64(buf[:], uint64(c.state))
    h.Write(buf[:])
    for _, sym := range c.Stack {
        binary.LittleEndian.PutUint64(buf[:], uint64(sym))
        h.Write(buf[:])
    }
    return h.Sum64()
}

// configSet represents a set of parser configurations (for nondeterminism)
type configSet struct {
    configs    []*parserConfig
    normalized bool   // true if already deduplicated and sorted
    cachedSig  uint64 // cached signature after normalization
}

// newConfigSet creates a new config set with a single configuration
func newConfigSet(state state, stack []stackSymbol) *configSet {
    return &configSet{
        configs: []*parserConfig{
            {state: state, Stack: stack},
        },
        normalized: true, // single config is already normalized
    }
}

// normalize deduplicates and sorts configs for stable signatures
func (c *configSet) normalize() {
    if c.normalized || len(c.configs) <= 1 {
        c.normalized = true
        return
    }

    // Deduplicate using a map
    seen := make(map[uint64]*parserConfig, len(c.configs))
    for _, cfg := range c.configs {
        key := cfg.key()
        if _, exists := seen[key]; !exists {
            seen[key] = cfg
        }
    }

    // Extract unique configs
    unique := make([]*parserConfig, 0, len(seen))
    for _, cfg := range seen {
        unique = append(unique, cfg)
    }

    // Sort by key for deterministic ordering
    sort.Slice(unique, func(i, j int) bool {
        return unique[i].key() < unique[j].key()
    })

    c.configs = unique
    c.normalized = true
}

// signature returns a hash for cache lookup (normalizes first)
func (c *configSet) signature() uint64 {
    c.normalize()

    // Return cached signature if available
    if c.cachedSig != 0 {
        return c.cachedSig
    }

    h := fnv.New64a()

    // Hash number of configs
    var buf [8]byte
    binary.LittleEndian.PutUint64(buf[:], uint64(len(c.configs)))
    h.Write(buf[:])

    // Hash each config (already sorted)
    for _, cfg := range c.configs {
        binary.LittleEndian.PutUint64(buf[:], uint64(cfg.state))
        h.Write(buf[:])

        binary.LittleEndian.PutUint64(buf[:], uint64(len(cfg.Stack)))
        h.Write(buf[:])

        for _, sym := range cfg.Stack {
            binary.LittleEndian.PutUint64(buf[:], uint64(sym))
            h.Write(buf[:])
        }
    }

    c.cachedSig = h.Sum64()
    return c.cachedSig
}

// isEmpty returns true if there are no configurations
func (c *configSet) isEmpty() bool {
    return len(c.configs) == 0
}

// clone creates a deep copy of the config set
func (c *configSet) clone() *configSet {
    newConfigs := make([]*parserConfig, len(c.configs))
    for i, cfg := range c.configs {
        newConfigs[i] = cfg.clone()
    }
    return &configSet{configs: newConfigs}
}

// bridge connects token analysis to pda validation
type bridge struct {
    pda      *pda
    analyzer *analyzer
}

// newBridge creates a new bridge
func newBridge(pda *pda, analyzer *analyzer) *bridge {
    return &bridge{
        pda:      pda,
        analyzer: analyzer,
    }
}

// IsTokenValid checks if the token can be consumed from the current config.
// This is the main entry point for token validation.
func (b *bridge) IsTokenValid(tokenID int, config *configSet) bool {
    analysis := b.analyzer.analysis(tokenID)

    if !analysis.HasMatches {
        return false
    }

    // Fast path: exact terminal match
    if analysis.exactMatch >= 0 {
        terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
        return b.canAcceptTerminal(config, terminal.Pattern)
    }

    // General path: DP over (pos, config)
    return b.dpValidate(&analysis, config)
}

// canAcceptTerminal checks if any config can accept the terminal
func (b *bridge) canAcceptTerminal(config *configSet, pattern string) bool {
    for _, cfg := range config.configs {
        if b.canConfigAcceptTerminal(cfg, pattern) {
            return true
        }
    }
    return false
}

// canConfigAcceptTerminal checks if a single config can accept the terminal
func (b *bridge) canConfigAcceptTerminal(cfg *parserConfig, pattern string) bool {
    // Use pooled visited map to reduce allocations
    visited := getVisitedMap()
    result := b.tryAcceptTerminal(cfg.state, cfg.Stack, pattern, visited)
    putVisitedMap(visited)
    return result
}

// tryAcceptTerminal recursively tries to accept a terminal from a state
func (b *bridge) tryAcceptTerminal(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool) bool {
    key := stateStackKey{state: state, stackSig: stackSignature(stack)}
    if visited[key] {
        return false
    }
    visited[key] = true

    stackTop := stackEmpty
    if len(stack) > 0 {
        stackTop = stack[len(stack)-1]
    }

    for _, t := range b.pda.Transitions[state] {
        // Check stack constraint
        if t.stackTop != stackEmpty && t.stackTop != stackTop {
            continue
        }

        // Can't pop more than we have
        if t.StackPop > len(stack) {
            continue
        }

        if t.Pattern == pattern {
            // Direct match
            return true
        }

        if t.Pattern == "" {
            // Epsilon transition - follow it
            newStack := make([]stackSymbol, len(stack))
            copy(newStack, stack)

            // Pop
            if t.StackPop > 0 {
                newStack = newStack[:len(newStack)-t.StackPop]
            }

            // Push
            newStack = append(newStack, t.StackPush...)

            if b.tryAcceptTerminal(t.ToState, newStack, pattern, visited) {
                return true
            }
        }
    }

    return false
}

// dpValidate runs DP for multi-terminal tokens
func (b *bridge) dpValidate(analysis *tokenAnalysis, startConfig *configSet) bool {
    // state: (pos, configSet)
    // Memoize by (pos, configSig)
    type dpKey struct {
        pos int
        sig uint64
    }
    memo := make(map[dpKey]bool)

    var dp func(pos int, config *configSet) bool
    dp = func(pos int, config *configSet) bool {
        if pos == len(analysis.Token) {
            return true // Consumed entire token
        }

        if config.isEmpty() {
            return false
        }

        key := dpKey{pos, config.signature()}
        if result, ok := memo[key]; ok {
            return result
        }

        // Try each terminal that matches at this position
        for _, match := range analysis.MatchesAtPos[pos] {
            terminal := b.analyzer.matcher.terminals[match.TerminalID]
            newConfig := b.advanceConfig(config, terminal.Pattern)
            if newConfig != nil && !newConfig.isEmpty() && dp(pos+match.Length, newConfig) {
                memo[key] = true
                return true
            }
        }

        memo[key] = false
        return false
    }

    return dp(0, startConfig)
}

// advanceConfig advances all configs that can accept the terminal
func (b *bridge) advanceConfig(config *configSet, pattern string) *configSet {
    var newConfigs []*parserConfig

    for _, cfg := range config.configs {
        advanced := b.advanceSingleConfig(cfg, pattern)
        newConfigs = append(newConfigs, advanced...)
    }

    if len(newConfigs) == 0 {
        return nil
    }

    return &configSet{configs: newConfigs}
}

// advanceSingleConfig advances a single config by accepting a terminal
func (b *bridge) advanceSingleConfig(cfg *parserConfig, pattern string) []*parserConfig {
    var results []*parserConfig
    visited := getVisitedMap()
    b.collectAdvanced(cfg.state, cfg.Stack, pattern, visited, &results)
    putVisitedMap(visited)
    return results
}

// collectAdvanced collects all configs reachable by accepting the pattern
func (b *bridge) collectAdvanced(state state, stack []stackSymbol, pattern string, visited map[stateStackKey]bool, results *[]*parserConfig) {
    key := stateStackKey{state: state, stackSig: stackSignature(stack)}
    if visited[key] {
        return
    }
    visited[key] = true

    stackTop := stackEmpty
    if len(stack) > 0 {
        stackTop = stack[len(stack)-1]
    }

    for _, t := range b.pda.Transitions[state] {
        // Check stack constraint
        if t.stackTop != stackEmpty && t.stackTop != stackTop {
            continue
        }

        // Can't pop more than we have
        if t.StackPop > len(stack) {
            continue
        }

        if t.Pattern == pattern {
            // Match! Create new config after transition
            newStack := make([]stackSymbol, len(stack))
            copy(newStack, stack)

            if t.StackPop > 0 {
                newStack = newStack[:len(newStack)-t.StackPop]
            }
            newStack = append(newStack, t.StackPush...)

            *results = append(*results, &parserConfig{
                state: t.ToState,
                Stack: newStack,
            })
        }

        if t.Pattern == "" {
            // Epsilon transition - follow it
            newStack := make([]stackSymbol, len(stack))
            copy(newStack, stack)

            if t.StackPop > 0 {
                newStack = newStack[:len(newStack)-t.StackPop]
            }
            newStack = append(newStack, t.StackPush...)

            b.collectAdvanced(t.ToState, newStack, pattern, visited, results)
        }
    }
}

// validTokens returns all token IDs that are valid from the given config
func (b *bridge) validTokens(config *configSet) []int {
    var valid []int
    for tokenID := 0; tokenID < b.analyzer.vocabSize(); tokenID++ {
        if b.IsTokenValid(tokenID, config) {
            valid = append(valid, tokenID)
        }
    }
    return valid
}

// acceptToken attempts to accept a token and returns the new config set.
// Returns nil if the token is not valid from this config.
func (b *bridge) acceptToken(tokenID int, config *configSet) *configSet {
    analysis := b.analyzer.analysis(tokenID)

    if !analysis.HasMatches {
        return nil
    }

    // Fast path: exact terminal match
    if analysis.exactMatch >= 0 {
        terminal := b.analyzer.matcher.terminals[analysis.exactMatch]
        newConfig := b.advanceConfig(config, terminal.Pattern)
        if newConfig != nil && !newConfig.isEmpty() {
            newConfig.normalize()
            return newConfig
        }
        return nil
    }

    // General path: DP to find final config after consuming token
    return b.dpAccept(&analysis, config)
}

// dpAccept runs DP to accept a multi-terminal token and return the final config.
// Returns the union of all possible end configurations (preserves nondeterminism).
func (b *bridge) dpAccept(analysis *tokenAnalysis, startConfig *configSet) *configSet {
    type dpKey struct {
        pos int
        sig uint64
    }
    // Memoize the configs reachable at each (pos, sig)
    memo := make(map[dpKey]*configSet)

    var dp func(pos int, config *configSet) *configSet
    dp = func(pos int, config *configSet) *configSet {
        if pos == len(analysis.Token) {
            return config // Consumed entire token, return final config
        }

        if config.isEmpty() {
            return nil
        }

        key := dpKey{pos, config.signature()}
        if result, ok := memo[key]; ok {
            return result
        }

        // Collect all valid result configs from all possible paths
        var allConfigs []*parserConfig

        // Try each terminal that matches at this position
        for _, match := range analysis.MatchesAtPos[pos] {
            terminal := b.analyzer.matcher.terminals[match.TerminalID]
            newConfig := b.advanceConfig(config, terminal.Pattern)
            if newConfig != nil && !newConfig.isEmpty() {
                finalConfig := dp(pos+match.Length, newConfig)
                if finalConfig != nil {
                    // Collect all configs, don't return early
                    allConfigs = append(allConfigs, finalConfig.configs...)
                }
            }
        }

        // Build result: nil if no valid paths, normalized configSet otherwise
        var result *configSet
        if len(allConfigs) > 0 {
            result = &configSet{configs: allConfigs}
            result.normalize() // Dedup using parserConfig.key(), sort for consistent signature
        }
        memo[key] = result // Cache normalized result
        return result
    }

    return dp(0, startConfig)
}

// isAccepting returns true if any config can reach an accepting state
func (b *bridge) isAccepting(config *configSet) bool {
    visited := getVisitedMap()
    defer putVisitedMap(visited)

    for _, cfg := range config.configs {
        // Clear visited for each config check
        for k := range visited {
            delete(visited, k)
        }
        if b.canReachAccept(cfg.state, cfg.Stack, visited) {
            return true
        }
    }
    return false
}

// canReachAccept checks if we can reach an accepting state via epsilon transitions
func (b *bridge) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
    // Check if this state is accepting with empty stack
    if b.pda.AcceptStates[state] && len(stack) == 0 {
        return true
    }

    key := stateStackKey{state: state, stackSig: stackSignature(stack)}
    if visited[key] {
        return false
    }
    visited[key] = true

    // Try epsilon transitions
    stackTop := stackEmpty
    if len(stack) > 0 {
        stackTop = stack[len(stack)-1]
    }

    for _, t := range b.pda.Transitions[state] {
        if t.Pattern != "" {
            continue // Not epsilon
        }
        if t.stackTop != stackEmpty && t.stackTop != stackTop {
            continue
        }
        if t.StackPop > len(stack) {
            continue
        }

        newStack := make([]stackSymbol, len(stack))
        copy(newStack, stack)
        if t.StackPop > 0 {
            newStack = newStack[:len(newStack)-t.StackPop]
        }
        newStack = append(newStack, t.StackPush...)

        if b.canReachAccept(t.ToState, newStack, visited) {
            return true
        }
    }

    return false
}

// validTerminals returns the valid terminal patterns from the given config
func (b *bridge) validTerminals(config *configSet) []string {
    seen := make(map[string]bool)
    var terminals []string

    visited := getVisitedMap()
    defer putVisitedMap(visited)

    for _, cfg := range config.configs {
        // Clear visited for each config
        for k := range visited {
            delete(visited, k)
        }
        b.collectValidTerminals(cfg.state, cfg.Stack, visited, seen, &terminals)
    }

    return terminals
}

// collectValidTerminals collects all reachable terminals
func (b *bridge) collectValidTerminals(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[string]bool, terminals *[]string) {
    key := stateStackKey{state: state, stackSig: stackSignature(stack)}
    if visited[key] {
        return
    }
    visited[key] = true

    stackTop := stackEmpty
    if len(stack) > 0 {
        stackTop = stack[len(stack)-1]
    }

    for _, t := range b.pda.Transitions[state] {
        if t.stackTop != stackEmpty && t.stackTop != stackTop {
            continue
        }
        if t.StackPop > len(stack) {
            continue
        }

        if t.Pattern != "" && !seen[t.Pattern] {
            seen[t.Pattern] = true
            *terminals = append(*terminals, t.Pattern)
        }

        if t.Pattern == "" {
            newStack := make([]stackSymbol, len(stack))
            copy(newStack, stack)
            if t.StackPop > 0 {
                newStack = newStack[:len(newStack)-t.StackPop]
            }
            newStack = append(newStack, t.StackPush...)
            b.collectValidTerminals(t.ToState, newStack, visited, seen, terminals)
        }
    }
}

// validTerminalIDs returns the IDs of valid terminals from the given config
func (b *bridge) validTerminalIDs(config *configSet) []int {
    seen := make(map[int]bool)
    var terminalIDs []int

    visited := getVisitedMap()
    defer putVisitedMap(visited)

    for _, cfg := range config.configs {
        // Clear visited for each config
        for k := range visited {
            delete(visited, k)
        }
        b.collectValidTerminalIDs(cfg.state, cfg.Stack, visited, seen, &terminalIDs)
    }

    return terminalIDs
}

// collectValidTerminalIDs collects IDs of all reachable terminals
func (b *bridge) collectValidTerminalIDs(state state, stack []stackSymbol, visited map[stateStackKey]bool, seen map[int]bool, terminalIDs *[]int) {
    key := stateStackKey{state: state, stackSig: stackSignature(stack)}
    if visited[key] {
        return
    }
    visited[key] = true

    stackTop := stackEmpty
    if len(stack) > 0 {
        stackTop = stack[len(stack)-1]
    }

    for _, t := range b.pda.Transitions[state] {
        if t.stackTop != stackEmpty && t.stackTop != stackTop {
            continue
        }
        if t.StackPop > len(stack) {
            continue
        }

        if t.Pattern != "" {
            // Look up terminal ID from pattern
            if tid, ok := b.analyzer.matcher.patternToID[t.Pattern]; ok && !seen[tid] {
                seen[tid] = true
                *terminalIDs = append(*terminalIDs, tid)
            }
        }

        if t.Pattern == "" {
            newStack := make([]stackSymbol, len(stack))
            copy(newStack, stack)
            if t.StackPop > 0 {
                newStack = newStack[:len(newStack)-t.StackPop]
            }
            newStack = append(newStack, t.StackPush...)
            b.collectValidTerminalIDs(t.ToState, newStack, visited, seen, terminalIDs)
        }
    }
}
x/grammar/cmd/compare/complex.gbnf (new file, 45 lines)
root ::= ws "{" ws id-field "," ws kind-field "," ws items-field "," ws alt-field "," ws flags-field "," ws meta-field "," ws priority-field ws "}" ws

id-field ::= "\"id\"" ws ":" ws uuid
kind-field ::= "\"kind\"" ws ":" ws kind
items-field ::= "\"items\"" ws ":" ws items
alt-field ::= "\"alt\"" ws ":" ws alt
flags-field ::= "\"flags\"" ws ":" ws flags
meta-field ::= "\"meta\"" ws ":" ws meta
priority-field ::= "\"priority\"" ws ":" ws int

kind ::= "\"order\"" | "\"invoice\"" | "\"shipment\""
status ::= "\"new\"" | "\"backorder\"" | "\"shipped\""
flag ::= "\"fragile\"" | "\"gift\"" | "\"priority\"" | "\"insured\""
source ::= "\"api\"" | "\"batch\"" | "\"import\""

items ::= "[" ws item ( "," ws item )? ( "," ws item )? ws "]"
flags ::= "[" ws "]" | "[" ws flag ( "," ws flag )? ( "," ws flag )? ( "," ws flag )? ws "]"

item ::= "{" ws item-sku "," ws item-qty "," ws item-status "," ws item-notes ws "}"
item-sku ::= "\"sku\"" ws ":" ws string
item-qty ::= "\"qty\"" ws ":" ws int
item-status ::= "\"status\"" ws ":" ws status
item-notes ::= "\"notes\"" ws ":" ws string

meta ::= "{" ws meta-created "," ws meta-source "," ws meta-ip ws "}"
meta-created ::= "\"created\"" ws ":" ws date-time
meta-source ::= "\"source\"" ws ":" ws source
meta-ip ::= "\"ip\"" ws ":" ws ipv4

alt ::= string | int | "null"

uuid ::= "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\""
date-time ::= "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\""
ipv4 ::= "\"" digit+ "." digit+ "." digit+ "." digit+ "\""

string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]

int ::= "-"? digit+
digit ::= [0-9]
hex ::= [0-9a-fA-F]

ws ::= [ \t\n\r]*
x/grammar/cmd/compare/complex.schema.json (new file, 46 lines)
{
  "type": "object",
  "properties": {
    "id": { "type": "string", "format": "uuid" },
    "kind": { "enum": ["order", "invoice", "shipment"] },
    "items": {
      "type": "array",
      "minItems": 1,
      "maxItems": 3,
      "items": {
        "type": "object",
        "properties": {
          "sku": { "type": "string" },
          "qty": { "type": "integer" },
          "status": { "enum": ["new", "backorder", "shipped"] },
          "notes": { "type": "string" }
        },
        "required": ["sku", "qty", "status", "notes"]
      }
    },
    "alt": {
      "oneOf": [
        { "type": "string" },
        { "type": "null" },
        { "type": "integer" }
      ]
    },
    "flags": {
      "type": "array",
      "minItems": 0,
      "maxItems": 4,
      "items": { "enum": ["fragile", "gift", "priority", "insured"] }
    },
    "meta": {
      "type": "object",
      "properties": {
        "created": { "type": "string", "format": "date-time" },
        "source": { "enum": ["api", "batch", "import"] },
        "ip": { "type": "string", "format": "ipv4" }
      },
      "required": ["created", "source", "ip"]
    },
    "priority": { "type": "integer" }
  },
  "required": ["id", "kind", "items", "alt", "flags", "meta", "priority"]
}
x/grammar/cmd/compare/main.go (new file, 235 lines)
//go:build mlx

package main

import (
    "encoding/json"
    "flag"
    "fmt"
    "os"
    "time"

    "github.com/ollama/ollama/llama"
    "github.com/ollama/ollama/x/grammar"
    "github.com/ollama/ollama/x/grammar/schema"
    "github.com/ollama/ollama/x/imagegen/mlx"
)

const jsonGBNF = `
root ::= value
value ::= object | array | string | number | "true" | "false" | "null"
object ::= "{" ws "}" | "{" members "}"
members ::= member ("," member)*
member ::= ws string ws ":" element
array ::= "[" ws "]" | "[" elements "]"
elements ::= element ("," element)*
element ::= ws value ws
string ::= "\"" characters "\""
characters ::= character*
character ::= [^"\\] | "\\" escape
escape ::= ["\\bfnrt]
number ::= "-"? integer fraction? exponent?
integer ::= "0" | [1-9] [0-9]*
fraction ::= "." [0-9]+
exponent ::= [eE] [+-]? [0-9]+
ws ::= [ \t\n\r]*
`

type result struct {
    // VocabSize must be exported; encoding/json ignores unexported fields.
    VocabSize           int    `json:"vocab_size"`
    Iterations          int    `json:"iterations"`
    Warmup              int    `json:"warmup"`
    ConstrainedSource   string `json:"constrained_source"`
    LlamaSource         string `json:"llama_source"`
    LlamaApply          string `json:"llama_apply"`
    ConstrainedGraph    string `json:"constrained_graph"`
    ConstrainedWithEval string `json:"constrained_with_eval,omitempty"`
    EvalOnly            string `json:"eval_only,omitempty"`
    ConstrainedEvalNet  string `json:"constrained_eval_net,omitempty"`
}

func main() {
    var (
        vocabSize  = flag.Int("vocab-size", 128000, "Vocabulary size")
        iterations = flag.Int("iterations", 500, "Benchmark iterations")
        warmup     = flag.Int("warmup", 50, "Warmup iterations")
        withEval   = flag.Bool("eval", true, "Measure ApplyMask with mlx.Eval")
        gbnfPath   = flag.String("gbnf", "", "GBNF grammar file for llama.cpp")
        schemaPath = flag.String("schema", "", "JSON Schema file for grammar constraints")
        ebnfPath   = flag.String("ebnf", "", "EBNF grammar file for grammar constraints")
        startRule  = flag.String("start", "root", "Start rule for EBNF")
    )
    flag.Parse()

    if *vocabSize <= 0 || *iterations <= 0 || *warmup < 0 {
        fmt.Fprintln(os.Stderr, "invalid flags")
        os.Exit(2)
    }

    vocab := createVocab(*vocabSize)

    if *schemaPath != "" && *ebnfPath != "" {
        fmt.Fprintln(os.Stderr, "only one of -schema or -ebnf may be set")
        os.Exit(2)
    }

    var constrainedSource string
    var compiled *grammar.Grammar
    var err error
    switch {
    case *schemaPath != "":
        data, readErr := os.ReadFile(*schemaPath)
        if readErr != nil {
            fmt.Fprintf(os.Stderr, "read schema: %v\n", readErr)
            os.Exit(1)
        }
        compiled, err = schema.Grammar(string(data))
        constrainedSource = "schema:" + *schemaPath
    case *ebnfPath != "":
        data, readErr := os.ReadFile(*ebnfPath)
        if readErr != nil {
            fmt.Fprintf(os.Stderr, "read ebnf: %v\n", readErr)
            os.Exit(1)
        }
        compiled, err = grammar.ParseEBNF(string(data), *startRule)
        constrainedSource = "ebnf:" + *ebnfPath
    default:
        compiled, err = grammar.JSONGrammar()
        constrainedSource = "json"
    }
    if err != nil {
        fmt.Fprintf(os.Stderr, "grammar: %v\n", err)
        os.Exit(1)
    }
    engine, err := grammar.NewEngine(compiled, vocab)
    if err != nil {
        fmt.Fprintf(os.Stderr, "engine: %v\n", err)
        os.Exit(1)
    }
    defer engine.Close()

    logits := mlx.Ones(int32(*vocabSize))
    mlx.Keep(logits)

    for i := 0; i < *warmup; i++ {
        masked := engine.ApplyMask(logits)
        if *withEval {
            mlx.Eval(masked)
        }
    }

    graphAvg := measure(*iterations, func() {
        _ = engine.ApplyMask(logits)
    })

    var evalAvg time.Duration
    var evalOnlyAvg time.Duration
    if *withEval {
        evalOnlyAvg = measure(*iterations, func() {
            baseline := mlx.MulScalar(logits, 1)
            mlx.Eval(baseline)
            baseline.Free()
        })

        evalAvg = measure(*iterations, func() {
            masked := engine.ApplyMask(logits)
            mlx.Eval(masked)
        })
    }

    vocabIDs := make([]uint32, *vocabSize)
    for i := range vocabIDs {
        vocabIDs[i] = uint32(i)
    }
    eogTokens := []int32{0}

    gbnf := jsonGBNF
    llamaSource := "json"
    if *gbnfPath != "" {
        data, readErr := os.ReadFile(*gbnfPath)
        if readErr != nil {
            fmt.Fprintf(os.Stderr, "read gbnf: %v\n", readErr)
            os.Exit(1)
        }
        gbnf = string(data)
        llamaSource = *gbnfPath
    }

    llamaGrammar := llama.NewGrammar(gbnf, vocabIDs, vocab, eogTokens)
    if llamaGrammar == nil {
        fmt.Fprintln(os.Stderr, "llama grammar initialization failed")
        os.Exit(1)
    }
    defer llamaGrammar.Free()

    llamaTokens := make([]llama.TokenData, *vocabSize)

    for i := 0; i < *warmup; i++ {
        for j := range llamaTokens {
            llamaTokens[j].Logit = 1.0
        }
        llamaGrammar.Apply(llamaTokens)
    }

    llamaAvg := measure(*iterations, func() {
        for j := range llamaTokens {
            llamaTokens[j].Logit = 1.0
        }
        llamaGrammar.Apply(llamaTokens)
    })

    out := result{
        VocabSize:         *vocabSize,
        Iterations:        *iterations,
        Warmup:            *warmup,
        LlamaApply:        llamaAvg.String(),
        ConstrainedGraph:  graphAvg.String(),
        ConstrainedSource: constrainedSource,
        LlamaSource:       llamaSource,
    }
    if *withEval {
        out.ConstrainedWithEval = evalAvg.String()
        out.EvalOnly = evalOnlyAvg.String()
        if evalAvg > evalOnlyAvg {
            out.ConstrainedEvalNet = (evalAvg - evalOnlyAvg).String()
        } else {
            out.ConstrainedEvalNet = "0s"
        }
    }

    enc := json.NewEncoder(os.Stdout)
    if err := enc.Encode(out); err != nil {
        fmt.Fprintf(os.Stderr, "encode: %v\n", err)
        os.Exit(1)
    }
}

func measure(iterations int, fn func()) time.Duration {
    start := time.Now()
    for i := 0; i < iterations; i++ {
        fn()
    }
    return time.Since(start) / time.Duration(iterations)
}

func createVocab(size int) []string {
    vocab := make([]string, size)

    jsonTokens := []string{
        "{", "}", "[", "]", ":", ",",
        "true", "false", "null",
        " ", "\n", "\t", "\r",
        "\"",
    }
    for i, t := range jsonTokens {
        if i < size {
            vocab[i] = t
        }
    }

    for i := len(jsonTokens); i < size; i++ {
        vocab[i] = fmt.Sprintf("tok%d", i)
    }

    return vocab
}
x/grammar/compiled.go (new file, 320 lines; listing truncated below)
//go:build mlx
|
||||
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Grammar is the compiled form of an EBNF grammar.
|
||||
// It contains terminals, parse tables, and the start state.
|
||||
// Use ParseEBNF or JSONGrammar to create a Grammar.
|
||||
type Grammar struct {
|
||||
// The underlying pda
|
||||
pda *pda
|
||||
|
||||
// Compiled terminal matcher
|
||||
matcher *terminalMatcher
|
||||
}
|
||||
|
||||
// ParseEBNF compiles an EBNF grammar string into a Grammar.
|
||||
// startRule is the name of the start rule (e.g., "root", "json").
|
||||
func ParseEBNF(ebnf string, startRule string) (*Grammar, error) {
|
||||
pda, err := compileString(ebnf, startRule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile EBNF: %w", err)
|
||||
}
|
||||
|
||||
matcher, err := compileTerminalsStrict(pda)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile terminals: %w", err)
|
||||
}
|
||||
|
||||
return &Grammar{
|
||||
pda: pda,
|
||||
matcher: matcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JSONGrammar returns the compiled JSON grammar.
|
||||
// This is a convenience wrapper for ParseEBNF(JSONGrammarEBNF, "json").
|
||||
func JSONGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONGrammarEBNF, "json")
|
||||
}
|
||||
|
||||
// JSONObjectGrammar returns a JSON grammar that only allows objects at the top level.
|
||||
// Use this when you want to ensure the output is a JSON object (starts with {).
|
||||
func JSONObjectGrammar() (*Grammar, error) {
|
||||
return ParseEBNF(JSONObjectGrammarEBNF, "json")
|
||||
}
// compileTerminalsStrict builds a matcher that properly handles:
//   - Escaped literals (\n, \", \uXXXX)
//   - Unicode ranges (rune-based, not byte-based)
//   - Unsupported patterns, which are rejected with an error (no silent fallback)
func compileTerminalsStrict(pda *pda) (*terminalMatcher, error) {
	m := &terminalMatcher{
		literalTrie: &trieNode{terminalID: -1},
		ranges:      make([]terminal, 0),
		terminals:   make([]terminal, 0, len(pda.Terminals)),
		patternToID: make(map[string]int),
	}

	// Track which pattern produced each unescaped value for collision detection
	unescapedSource := make(map[string]string) // unescaped -> original pattern

	for i, pattern := range pda.Terminals {
		terminal, err := parseTerminalPattern(pattern, i)
		if err != nil {
			return nil, fmt.Errorf("terminal %q: %w", pattern, err)
		}

		if terminal.Type == terminalLiteral {
			// Use the unescaped pattern for trie matching
			m.addLiteralToTrie(terminal.Unescaped, i)

			// Detect collisions between literals that unescape to the same value
			if existingPattern, exists := unescapedSource[terminal.Unescaped]; exists {
				if existingPattern != pattern {
					return nil, fmt.Errorf("collision: patterns %q and %q both unescape to %q",
						existingPattern, pattern, terminal.Unescaped)
				}
			} else {
				unescapedSource[terminal.Unescaped] = pattern
			}
		} else if terminal.Type == terminalRange {
			m.ranges = append(m.ranges, terminal)
		}

		m.terminals = append(m.terminals, terminal)
		m.patternToID[pattern] = i
	}

	return m, nil
}
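// Collision example (sketch, following the code above): the two patterns
// `\n` (backslash-escaped) and `"\n"` (quoted) both unescape to a single
// newline, so compileTerminalsStrict returns a "collision" error rather
// than silently mapping two distinct patterns onto one trie entry.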
// parseTerminalPattern parses a terminal pattern and returns a terminal.
// Supports:
//   - Literal strings (with escape sequences)
//   - Character ranges [X-Y] (Unicode-aware)
func parseTerminalPattern(pattern string, id int) (terminal, error) {
	if len(pattern) == 0 {
		return terminal{}, fmt.Errorf("empty pattern")
	}

	// Check for range pattern: [X-Y]
	if isUnicodeRangePattern(pattern) {
		lowRune, highRune, err := parseUnicodeRange(pattern)
		if err != nil {
			return terminal{}, err
		}
		return terminal{
			ID:        id,
			Type:      terminalRange,
			Pattern:   pattern,
			Unescaped: pattern,
			LowRune:   lowRune,
			HighRune:  highRune,
		}, nil
	}

	// It's a literal - unescape it
	unescaped, err := unescapeLiteral(pattern)
	if err != nil {
		return terminal{}, fmt.Errorf("invalid escape sequence: %w", err)
	}

	return terminal{
		ID:        id,
		Type:      terminalLiteral,
		Pattern:   pattern,
		Unescaped: unescaped,
	}, nil
}

// isUnicodeRangePattern checks if pattern is a character range like [a-z] or [\u0000-\uFFFF]
func isUnicodeRangePattern(pattern string) bool {
	if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
		return false
	}
	// Find the dash that separates low-high; reject a dash in the first
	// position, which would leave the range without a low bound.
	inner := pattern[1 : len(pattern)-1]
	dashIdx := strings.Index(inner, "-")
	if dashIdx <= 0 {
		return false
	}
	return true
}

// parseUnicodeRange parses [X-Y] into low and high runes
func parseUnicodeRange(pattern string) (rune, rune, error) {
	if len(pattern) < 5 || pattern[0] != '[' || pattern[len(pattern)-1] != ']' {
		return 0, 0, fmt.Errorf("invalid range pattern")
	}

	inner := pattern[1 : len(pattern)-1]

	// Simple case: [a-z] where a and z are single chars
	if len(inner) == 3 && inner[1] == '-' {
		return rune(inner[0]), rune(inner[2]), nil
	}

	// Handle escaped characters like [\u0000-\uFFFF]
	dashIdx := findRangeDash(inner)
	if dashIdx < 0 {
		return 0, 0, fmt.Errorf("no dash in range")
	}

	lowStr := inner[:dashIdx]
	highStr := inner[dashIdx+1:]

	lowRune, err := parseRune(lowStr)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid low bound: %w", err)
	}

	highRune, err := parseRune(highStr)
	if err != nil {
		return 0, 0, fmt.Errorf("invalid high bound: %w", err)
	}

	if lowRune > highRune {
		return 0, 0, fmt.Errorf("low bound > high bound")
	}

	return lowRune, highRune, nil
}
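// Illustrative inputs for parseUnicodeRange (sketch, using the functions
// above exactly as written):
//
//	parseUnicodeRange("[a-z]")             // 'a', 'z', nil (simple 3-byte case)
//	parseUnicodeRange("[\\u0391-\\u03A9]") // 'Α', 'Ω', nil (escapes resolved by parseRune)
//	parseUnicodeRange("[z-a]")             // error: low bound > high bound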
// findRangeDash finds the dash separating low-high in a range pattern
func findRangeDash(inner string) int {
	i := 0
	for i < len(inner) {
		if inner[i] == '\\' && i+1 < len(inner) {
			// Skip escape sequence
			if inner[i+1] == 'u' && i+6 <= len(inner) {
				i += 6 // \uXXXX
			} else {
				i += 2 // \n, \t, etc.
			}
			continue
		}
		if inner[i] == '-' && i > 0 {
			return i
		}
		i++
	}
	return -1
}

// parseRune parses a single rune from a string (handles escapes)
func parseRune(s string) (rune, error) {
	if len(s) == 0 {
		return 0, fmt.Errorf("empty rune")
	}

	// Handle escape sequences
	if s[0] == '\\' {
		if len(s) < 2 {
			return 0, fmt.Errorf("incomplete escape")
		}
		switch s[1] {
		case 'n':
			return '\n', nil
		case 't':
			return '\t', nil
		case 'r':
			return '\r', nil
		case '\\':
			return '\\', nil
		case '"':
			return '"', nil
		case '\'':
			return '\'', nil
		case 'u':
			if len(s) < 6 {
				return 0, fmt.Errorf("incomplete unicode escape")
			}
			val, err := strconv.ParseInt(s[2:6], 16, 32)
			if err != nil {
				return 0, fmt.Errorf("invalid unicode escape: %w", err)
			}
			return rune(val), nil
		default:
			return 0, fmt.Errorf("unknown escape: \\%c", s[1])
		}
	}

	// Plain character
	r, _ := utf8.DecodeRuneInString(s)
	if r == utf8.RuneError {
		return 0, fmt.Errorf("invalid utf8")
	}
	return r, nil
}

// unescapeLiteral unescapes a literal pattern string
func unescapeLiteral(pattern string) (string, error) {
	// Try strconv.Unquote if it looks quoted
	if len(pattern) >= 2 && pattern[0] == '"' && pattern[len(pattern)-1] == '"' {
		unquoted, err := strconv.Unquote(pattern)
		if err != nil {
			return "", err
		}
		return unquoted, nil
	}

	// If no backslashes, return as-is
	if !strings.Contains(pattern, "\\") {
		return pattern, nil
	}

	// Manual unescape
	var result strings.Builder
	i := 0
	for i < len(pattern) {
		if pattern[i] == '\\' && i+1 < len(pattern) {
			switch pattern[i+1] {
			case 'n':
				result.WriteByte('\n')
				i += 2
			case 't':
				result.WriteByte('\t')
				i += 2
			case 'r':
				result.WriteByte('\r')
				i += 2
			case '\\':
				result.WriteByte('\\')
				i += 2
			case '"':
				result.WriteByte('"')
				i += 2
			case '\'':
				result.WriteByte('\'')
				i += 2
			case 'u':
				if i+6 <= len(pattern) {
					val, err := strconv.ParseInt(pattern[i+2:i+6], 16, 32)
					if err != nil {
						return "", fmt.Errorf("invalid unicode escape at %d", i)
					}
					result.WriteRune(rune(val))
					i += 6
				} else {
					return "", fmt.Errorf("incomplete unicode escape at %d", i)
				}
			default:
				// Reject unknown escape sequences
				return "", fmt.Errorf("unknown escape sequence: \\%c at position %d", pattern[i+1], i)
			}
		} else {
			result.WriteByte(pattern[i])
			i++
		}
	}
	return result.String(), nil
}
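// Illustrative behavior of unescapeLiteral (sketch):
//
//	unescapeLiteral(`"a\nb"`) // "a" + newline + "b", via strconv.Unquote
//	unescapeLiteral(`\u03b1`) // "α", via the manual unescape path
//	unescapeLiteral(`\q`)     // error: unknown escape sequence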
329
x/grammar/engine.go
Normal file
@@ -0,0 +1,329 @@
//go:build mlx

package grammar

import (
	"container/list"
	"fmt"
	"math"
	"sync"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// maskCache provides LRU caching for computed masks.
type maskCache struct {
	cache   map[uint64]*list.Element
	order   *list.List
	maxSize int
	mu      sync.Mutex
}

type maskEntry struct {
	sig  uint64
	mask *mlx.Array
}

// newMaskCache creates a new mask cache with the given max size.
// If maxSize <= 0, the cache is disabled (get/put are no-ops).
func newMaskCache(maxSize int) *maskCache {
	if maxSize <= 0 {
		return &maskCache{
			cache:   make(map[uint64]*list.Element),
			order:   list.New(),
			maxSize: 0, // Signals disabled
		}
	}
	return &maskCache{
		cache:   make(map[uint64]*list.Element),
		order:   list.New(),
		maxSize: maxSize,
	}
}

// get retrieves a cached mask, returning nil if not found.
// Updates LRU order on cache hit.
func (c *maskCache) get(sig uint64) *mlx.Array {
	if c.maxSize <= 0 {
		return nil // Cache disabled
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if elem, ok := c.cache[sig]; ok {
		c.order.MoveToFront(elem)
		return elem.Value.(*maskEntry).mask
	}
	return nil
}

// put stores a mask in the cache with LRU eviction.
func (c *maskCache) put(sig uint64, mask *mlx.Array) {
	if c.maxSize <= 0 {
		return // Cache disabled
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	if elem, exists := c.cache[sig]; exists {
		c.order.MoveToFront(elem)
		return
	}

	// Evict oldest if at capacity (safe since maxSize > 0)
	if c.order.Len() >= c.maxSize {
		oldest := c.order.Back()
		if oldest != nil {
			entry := oldest.Value.(*maskEntry)
			entry.mask.Free()
			delete(c.cache, entry.sig)
			c.order.Remove(oldest)
		}
	}

	elem := c.order.PushFront(&maskEntry{sig: sig, mask: mask})
	c.cache[sig] = elem
}

// clear frees all cached masks.
func (c *maskCache) clear() {
	c.mu.Lock()
	defer c.mu.Unlock()
	for elem := c.order.Front(); elem != nil; elem = elem.Next() {
		elem.Value.(*maskEntry).mask.Free()
	}
	c.cache = make(map[uint64]*list.Element)
	c.order.Init()
}

// size returns the number of cached masks.
func (c *maskCache) size() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return len(c.cache)
}
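// Illustrative LRU behavior (sketch; a, b, and x stand for *mlx.Array masks,
// maxSize = 2):
//
//	c := newMaskCache(2)
//	c.put(1, a)
//	c.put(2, b)
//	c.get(1)    // hit: entry 1 moves to the front of the LRU order
//	c.put(3, x) // at capacity: evicts entry 2 and frees its mask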
// Engine applies grammar constraints to model outputs using MLX.
// It uses a token→pda bridge for strict correctness with arbitrary BPE tokens.
type Engine struct {
	// The compiled grammar
	grammar *Grammar

	// bridge for token validation
	bridge   *bridge
	analyzer *analyzer

	// Current parser state (configSet for nondeterminism)
	configSet *configSet

	// Token vocabulary from the model
	vocab     []string
	tokenToID map[string]int // O(1) lookup for AcceptString

	// Mask cache: configSig → valid token mask (LRU)
	maskCache *maskCache

	// Cached negative infinity mask for invalid tokens
	negInfMask *mlx.Array

	// Threshold for comparison (0.5 since mask values are 0 or 1)
	threshold *mlx.Array

	// Vocabulary size
	vocabSize int32

	// Reusable buffers for candidate filtering (avoid allocations)
	candidateMark []bool // indexed by tokenID, true if in candidate set
	touched       []int  // tokenIDs that were marked (for reset)
	dpCandidates  []int  // candidates requiring DP validation

	// Reusable buffer for valid token indices (for GPU scatter)
	validTokenIDs []int32
}

// EngineOption configures an Engine.
type EngineOption func(*Engine)

// WithMaskCacheSize sets the mask cache size (default 1024).
func WithMaskCacheSize(size int) EngineOption {
	return func(e *Engine) {
		e.maskCache = newMaskCache(size)
	}
}

// NewEngine creates a new constrained decoding engine.
// grammar is the compiled grammar (use JSONGrammar() or ParseEBNF()).
// vocab is the list of token strings from the model's tokenizer.
func NewEngine(grammar *Grammar, vocab []string, opts ...EngineOption) (*Engine, error) {
	if grammar == nil {
		return nil, fmt.Errorf("grammar cannot be nil")
	}

	// Build analyzer and bridge
	analyzer := newAnalyzer(vocab, grammar.matcher)
	bridge := newBridge(grammar.pda, analyzer)

	// Initialize config set from pda initial state
	initialConfig := newConfigSet(grammar.pda.StartState, nil)

	// Build token lookup map for O(1) AcceptString
	tokenToID := make(map[string]int, len(vocab))
	for i, tok := range vocab {
		tokenToID[tok] = i
	}

	e := &Engine{
		grammar:       grammar,
		bridge:        bridge,
		analyzer:      analyzer,
		configSet:     initialConfig,
		vocab:         vocab,
		tokenToID:     tokenToID,
		maskCache:     newMaskCache(1024),
		vocabSize:     int32(len(vocab)),
		candidateMark: make([]bool, len(vocab)),
		touched:       make([]int, 0, 10000),
		validTokenIDs: make([]int32, 0, 10000),
	}

	// Apply options
	for _, opt := range opts {
		opt(e)
	}

	// Create the negative infinity mask and threshold
	if e.vocabSize > 0 {
		e.negInfMask = mlx.FullDtype(float32(math.Inf(-1)), mlx.DtypeFloat32, e.vocabSize)
		mlx.Keep(e.negInfMask)

		e.threshold = mlx.NewScalarArray(0.5)
		mlx.Keep(e.threshold)
	}

	return e, nil
}
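// Construction sketch (assumes a tokenizer exposing its vocabulary as a
// []string named tokenizerVocab, which is hypothetical here):
//
//	g, err := grammar.JSONGrammar()
//	if err != nil {
//		return err
//	}
//	eng, err := grammar.NewEngine(g, tokenizerVocab, grammar.WithMaskCacheSize(4096))
//	if err != nil {
//		return err
//	}
//	defer eng.Close()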
// ApplyMask applies grammar constraints to logits.
// Returns logits with invalid tokens set to -inf.
func (e *Engine) ApplyMask(logits *mlx.Array) *mlx.Array {
	sig := e.configSet.signature()

	// Check state cache first (exact state match)
	if cached := e.maskCache.get(sig); cached != nil {
		condition := mlx.GreaterEqual(cached, e.threshold)
		return mlx.Where(condition, logits, e.negInfMask)
	}

	// Compute valid tokens using candidate filtering:
	// 1. Get valid terminal IDs from the current grammar state
	// 2. Get candidate tokens (those that START with valid terminals)
	// 3. Run DP validation only on candidates
	// This is O(candidates) instead of O(vocab_size)

	validTerminalIDs := e.bridge.validTerminalIDs(e.configSet)

	// Use pre-partitioned token groups for fast candidate building.
	// This eliminates per-token branching - just direct slice appends.
	e.validTokenIDs = e.validTokenIDs[:0]
	e.dpCandidates = e.dpCandidates[:0]
	e.touched = e.touched[:0]

	for _, tid := range validTerminalIDs {
		groups := e.analyzer.terminalGroups(tid)

		// Direct append of exact matches (no per-token check needed)
		e.validTokenIDs = append(e.validTokenIDs, groups.ExactMatches...)

		// Collect DP candidates (may have duplicates across terminals)
		for _, tokenID := range groups.DPCandidates {
			if !e.candidateMark[tokenID] {
				e.candidateMark[tokenID] = true
				e.dpCandidates = append(e.dpCandidates, tokenID)
				e.touched = append(e.touched, tokenID)
			}
		}
	}

	// Reset marks for next call
	for _, id := range e.touched {
		e.candidateMark[id] = false
	}

	for _, tokenID := range e.dpCandidates {
		if e.bridge.IsTokenValid(tokenID, e.configSet) {
			e.validTokenIDs = append(e.validTokenIDs, int32(tokenID))
		}
	}

	// Create and cache the mask on GPU using index updates
	mask := mlx.Zeros([]int32{e.vocabSize})
	if len(e.validTokenIDs) > 0 {
		indices := mlx.NewArrayInt32(e.validTokenIDs, []int32{int32(len(e.validTokenIDs))})
		values := mlx.Ones(int32(len(e.validTokenIDs)))
		mask = mlx.PutAlongAxis(mask, indices, values, 0)
	}
	mlx.Keep(mask)

	// Cache by state signature
	e.maskCache.put(sig, mask)

	// Apply mask
	condition := mlx.GreaterEqual(mask, e.threshold)
	return mlx.Where(condition, logits, e.negInfMask)
}
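// Decoding-loop sketch showing how ApplyMask and Accept fit together.
// model.Forward and sampleArgmax are hypothetical stand-ins, not part of
// this package:
//
//	for !eng.IsComplete() {
//		logits := model.Forward(ctx)    // *mlx.Array of shape [vocab]
//		masked := eng.ApplyMask(logits) // invalid tokens -> -inf
//		tok := sampleArgmax(masked)     // pick a token the mask allows
//		if !eng.Accept(tok) {
//			break // should not happen for a token allowed by the mask
//		}
//		ctx = append(ctx, tok)
//	}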
// Accept processes a token and updates the parser state.
// Returns true if the token was valid and accepted.
func (e *Engine) Accept(tokenID int) bool {
	if tokenID < 0 || tokenID >= len(e.vocab) {
		return false
	}

	newConfig := e.bridge.acceptToken(tokenID, e.configSet)
	if newConfig == nil {
		return false
	}
	e.configSet = newConfig
	return true
}

// AcceptString processes a token string directly.
// Returns true if the token was valid and accepted.
func (e *Engine) AcceptString(token string) bool {
	if id, ok := e.tokenToID[token]; ok {
		return e.Accept(id)
	}
	return false
}

// IsComplete returns true if the current state is accepting.
func (e *Engine) IsComplete() bool {
	return e.bridge.isAccepting(e.configSet)
}

// Reset resets the engine to its initial state.
func (e *Engine) Reset() {
	e.configSet = newConfigSet(e.grammar.pda.StartState, nil)
}

// validTokens returns the indices of tokens that are currently valid.
func (e *Engine) validTokens() []int {
	return e.bridge.validTokens(e.configSet)
}

// validTerminals returns the valid terminal patterns from the current state.
func (e *Engine) validTerminals() []string {
	return e.bridge.validTerminals(e.configSet)
}

// Close releases MLX resources.
func (e *Engine) Close() {
	if e.maskCache != nil {
		e.maskCache.clear()
	}
	if e.negInfMask != nil {
		e.negInfMask.Free()
	}
	if e.threshold != nil {
		e.threshold.Free()
	}
}
414
x/grammar/engine_benchmark_test.go
Normal file
@@ -0,0 +1,414 @@
//go:build mlx

package grammar

import (
	"fmt"
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// newBenchEngine creates a JSON engine for benchmarks
func newBenchEngine(b *testing.B, vocab []string) *Engine {
	b.Helper()
	grammar, err := JSONGrammar()
	if err != nil {
		b.Fatalf("failed to create JSON grammar: %v", err)
	}
	e, err := NewEngine(grammar, vocab)
	if err != nil {
		b.Fatalf("failed to create engine: %v", err)
	}
	return e
}

// Vocabulary sizes to test (matching real models)
var vocabSizes = []int{
	32000,  // Llama 2
	128000, // Llama 3
	256000, // Large models
}

// createBenchVocabN creates a vocabulary of size n with a realistic token distribution
func createBenchVocabN(n int) []string {
	vocab := make([]string, n)

	// JSON structural tokens (first 20)
	jsonTokens := []string{
		"{", "}", "[", "]", ":", ",",
		"true", "false", "null",
		" ", "\n", "\t", "\r",
		"\"", "'",
	}
	for i, t := range jsonTokens {
		if i < n {
			vocab[i] = t
		}
	}

	// String tokens (indices 20-1000)
	stringIdx := 20
	for i := 0; i < 980 && stringIdx+i < n; i++ {
		vocab[stringIdx+i] = fmt.Sprintf("\"token%d\"", i)
	}

	// Number tokens (indices 1000-2000)
	numberIdx := 1000
	for i := 0; i < 1000 && numberIdx+i < n; i++ {
		vocab[numberIdx+i] = fmt.Sprintf("%d", i)
	}

	// Generic tokens (rest)
	for i := 2000; i < n; i++ {
		vocab[i] = fmt.Sprintf("tok%d", i)
	}

	return vocab
}
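// To run the benchmarks below (sketch; assumes the mlx build tag is usable
// on this machine):
//
//	go test -tags mlx -bench 'ApplyMask' -benchmem ./x/grammar/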
// ============ Core Performance Benchmarks ============

// BenchmarkApplyMask_32k measures mask application with 32k vocab
func BenchmarkApplyMask_32k(b *testing.B) {
	benchmarkApplyMask(b, 32000)
}

// BenchmarkApplyMask_128k measures mask application with 128k vocab
func BenchmarkApplyMask_128k(b *testing.B) {
	benchmarkApplyMask(b, 128000)
}

// BenchmarkApplyMask_256k measures mask application with 256k vocab
func BenchmarkApplyMask_256k(b *testing.B) {
	benchmarkApplyMask(b, 256000)
}

func benchmarkApplyMask(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(vocabSize))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ReportMetric(float64(vocabSize), "vocab_size")
}

// ============ State-Dependent Benchmarks ============

// BenchmarkApplyMaskAfterBrace measures mask after { (STRING or } valid)
func BenchmarkApplyMaskAfterBrace(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	e.AcceptString("{")

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// BenchmarkApplyMaskMidObject measures mask in the middle of an object
func BenchmarkApplyMaskMidObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	// State: {"key": _value_
	e.AcceptString("{")
	e.AcceptString("\"key\"")
	e.AcceptString(":")

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// ============ Token Sequence Benchmarks ============

// BenchmarkSequence_SimpleObject benchmarks {"key": "value"}
func BenchmarkSequence_SimpleObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_NestedObject benchmarks {"a": {"b": {"c": 1}}}
func BenchmarkSequence_NestedObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{
		"{", "\"a\"", ":", "{", "\"b\"", ":", "{", "\"c\"", ":", "1", "}", "}", "}",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_LargeArray benchmarks [1, 2, 3, ..., 50]
func BenchmarkSequence_LargeArray(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Build sequence: [1, 2, 3, ..., 50]
	sequence := []string{"["}
	for i := 1; i <= 50; i++ {
		sequence = append(sequence, fmt.Sprintf("%d", i))
		if i < 50 {
			sequence = append(sequence, ",")
		}
	}
	sequence = append(sequence, "]")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// BenchmarkSequence_MixedTypes benchmarks a complex mixed-type object
func BenchmarkSequence_MixedTypes(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{
		"{",
		"\"name\"", ":", "\"test\"", ",",
		"\"count\"", ":", "42", ",",
		"\"enabled\"", ":", "true", ",",
		"\"data\"", ":", "null", ",",
		"\"items\"", ":", "[", "1", ",", "2", ",", "3", "]",
		"}",
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			masked := e.ApplyMask(logits)
			mlx.Eval(masked)
			e.AcceptString(token)
		}
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}

// ============ Component Benchmarks ============

// BenchmarkValidInputs measures pda valid input computation
func BenchmarkValidInputs(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.validTerminals()
	}
}

// BenchmarkStateTransition measures pda state transitions
func BenchmarkStateTransition(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		for _, token := range sequence {
			e.AcceptString(token)
		}
	}
}

// BenchmarkConstrainedGrammar_128k benchmarks x/grammar (graph only, no eval).
func BenchmarkConstrainedGrammar_128k(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.ApplyMask(logits) // Graph only, no eval
	}
}

// BenchmarkNewEngine_* measure one-time engine initialization.
func BenchmarkNewEngine_32k(b *testing.B) {
	benchmarkNewEngine(b, 32000)
}

func BenchmarkNewEngine_128k(b *testing.B) {
	benchmarkNewEngine(b, 128000)
}

func benchmarkNewEngine(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e := newBenchEngine(b, vocab)
		e.Close()
	}
}

// ============ Memory Benchmarks ============

func BenchmarkMemoryAllocs_32k(b *testing.B) {
	benchmarkMemoryAllocs(b, 32000)
}

func BenchmarkMemoryAllocs_128k(b *testing.B) {
	benchmarkMemoryAllocs(b, 128000)
}

func benchmarkMemoryAllocs(b *testing.B, vocabSize int) {
	vocab := createBenchVocabN(vocabSize)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(vocabSize))
	mlx.Keep(logits)

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}
}

// ============ No-Eval Benchmarks (simulating LLM graph integration) ============

// BenchmarkApplyMaskNoEval_128k measures mask generation WITHOUT GPU sync.
// This simulates adding the mask to the LLM compute graph.
func BenchmarkApplyMaskNoEval_128k(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	// Warm up
	for i := 0; i < 10; i++ {
		masked := e.ApplyMask(logits)
		mlx.Eval(masked)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.ApplyMask(logits) // No Eval - just build graph
	}
}

// BenchmarkSequenceNoEval_SimpleObject simulates real LLM usage - build the graph, eval once at the end
func BenchmarkSequenceNoEval_SimpleObject(b *testing.B) {
	vocab := createBenchVocabN(128000)
	e := newBenchEngine(b, vocab)
	defer e.Close()

	logits := mlx.Ones(int32(128000))
	mlx.Keep(logits)

	sequence := []string{"{", "\"key\"", ":", "\"value\"", "}"}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		e.Reset()
		var lastMasked *mlx.Array
		for _, token := range sequence {
			lastMasked = e.ApplyMask(logits) // Build graph only
			e.AcceptString(token)
		}
		mlx.Eval(lastMasked) // Single eval at end
	}

	b.ReportMetric(float64(len(sequence)), "tokens")
}
689
x/grammar/engine_test.go
Normal file
@@ -0,0 +1,689 @@
//go:build mlx

package grammar

import (
	"testing"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// newTestEngine creates a JSON engine for testing
func newTestEngine(t testing.TB, vocab []string) *Engine {
	t.Helper()
	grammar, err := JSONGrammar()
	if err != nil {
		t.Fatalf("failed to create JSON grammar: %v", err)
	}
	e, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	return e
}

// Mock vocabulary for testing
func testVocab() []string {
	return []string{
		"{",       // 0: object start
		"}",       // 1: object end
		"[",       // 2: array start
		"]",       // 3: array end
		":",       // 4: colon
		",",       // 5: comma
		"\"key\"", // 6: string (quoted)
		"\"val\"", // 7: string (quoted)
		"123",     // 8: number
		"-42.5",   // 9: number
		"true",    // 10: boolean
		"false",   // 11: boolean
		"null",    // 12: null
		" ",       // 13: whitespace (should be ignored)
		"\n",      // 14: whitespace (should be ignored)
		"subword", // 15: bare word (NOT valid JSON - requires quotes)
		"hello",   // 16: bare word (NOT valid JSON - requires quotes)
	}
}

func TestNewEngine(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	if e.vocabSize != int32(len(vocab)) {
		t.Errorf("vocabSize = %d, want %d", e.vocabSize, len(vocab))
	}

	// Verify grammar is set
	if e.grammar == nil {
		t.Error("grammar should not be nil")
	}

	// Verify analyzer is set
	if e.analyzer == nil {
		t.Error("analyzer should not be nil")
	}
}

func TestEngineValidTokens(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// At start, any value type should be valid
	validTokens := e.validTokens()

	// Should include object start, array start, strings, numbers, booleans, null.
	// Note: bare words like "subword" and "hello" are NOT valid JSON strings
	// (JSON strings must be quoted).
	expectedTokens := map[int]bool{
		0:  true, // {
		2:  true, // [
		6:  true, // "key"
		7:  true, // "val"
		8:  true, // 123
		9:  true, // -42.5
		10: true, // true
		11: true, // false
		12: true, // null
	}

	// Check that expected tokens are present
	validSet := make(map[int]bool)
	for _, idx := range validTokens {
		validSet[idx] = true
	}

	for idx := range expectedTokens {
		if !validSet[idx] {
			t.Errorf("expected token %d (%s) to be valid", idx, vocab[idx])
		}
	}

	if validSet[15] || validSet[16] {
		t.Error("bare words should not be valid JSON at the start state")
	}
}

func TestEngineAccept(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accepting { should work
	if !e.Accept(0) { // {
		t.Error("should accept {")
	}

	// After {, valid tokens should be STRING or }
	validTokens := e.validTokens()

	validSet := make(map[int]bool)
	for _, idx := range validTokens {
		validSet[idx] = true
	}

	// STRING tokens (indices 6, 7) and } (index 1) should be valid
	if !validSet[1] {
		t.Error("} should be valid after {")
	}
	if !validSet[6] && !validSet[7] {
		t.Error("STRING should be valid after { (for keys)")
	}
}

func TestEngineAcceptSequence(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept {"key": "val"}
	sequence := []int{0, 6, 4, 7, 1} // {, "key", :, "val", }

	for i, tokenID := range sequence {
		if !e.Accept(tokenID) {
			t.Fatalf("failed to accept token %d (%s) at position %d",
				tokenID, vocab[tokenID], i)
		}
	}

	if !e.IsComplete() {
		t.Error("should be in complete state after valid JSON")
	}
}

func TestEngineReset(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept some tokens
	e.Accept(0) // {
	e.Accept(1) // }

	if !e.IsComplete() {
		t.Error("should be complete after {}")
	}

	// Reset
	e.Reset()

	// Should be back to the initial state
	if e.IsComplete() {
		t.Error("should not be complete after reset")
	}

	// Should be able to accept a new sequence
	if !e.Accept(0) { // {
		t.Error("should accept { after reset")
	}
}

func TestEngineInvalidTokenRejection(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept { first
	if !e.Accept(0) {
		t.Fatal("should accept {")
	}

	// Now try to accept [ which is invalid after {
	// (After {, only STRING or } are valid)
	if e.Accept(2) { // [
		t.Error("should not accept [ after { (expecting STRING or })")
	}
}

func TestEngineAcceptString(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Accept using strings directly
	if !e.AcceptString("{") {
		t.Error("should accept {")
	}
	if !e.AcceptString("\"key\"") {
		t.Error("should accept string key")
	}
	if !e.AcceptString(":") {
		t.Error("should accept :")
	}
	if !e.AcceptString("123") {
		t.Error("should accept number")
	}
	if !e.AcceptString("}") {
		t.Error("should accept }")
	}

	if !e.IsComplete() {
		t.Error("should be complete after valid JSON")
	}
}

func TestJSONBackslashEscape(t *testing.T) {
	vocab := []string{`"`, `\`, "n", "a"}
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Valid escape: "\n"
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string start")
	}
	if !e.AcceptString(`\`) {
		t.Fatal("should accept escape prefix")
	}
	if !e.AcceptString("n") {
		t.Fatal("should accept escape code")
	}
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string end")
	}
	if !e.IsComplete() {
		t.Error("should be complete after escaped string")
	}

	// Invalid escape: "\a"
	e.Reset()
	if !e.AcceptString(`"`) {
		t.Fatal("should accept string start")
	}
	if !e.AcceptString(`\`) {
		t.Fatal("should accept escape prefix")
	}
	if e.AcceptString("a") {
		t.Error("should reject invalid escape code")
	}
}

func TestEngineNegInfMask(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Verify negInfMask exists
	if e.negInfMask == nil {
		t.Fatal("negInfMask should not be nil")
	}
}

func TestEngineMaskCache(t *testing.T) {
	vocab := testVocab()
	e := newTestEngine(t, vocab)
	defer e.Close()

	// Create test logits
	logits := mlx.Ones(int32(len(vocab)))

	// Apply mask - should populate the cache
	_ = e.ApplyMask(logits)

	// Check cache was populated
	cacheSize := e.maskCache.size()
	if cacheSize == 0 {
		t.Error("mask cache should have at least one entry after ApplyMask")
	}
}

func TestEngineEmptyVocab(t *testing.T) {
	e := newTestEngine(t, []string{})
	defer e.Close()

	if e.vocabSize != 0 {
		t.Errorf("vocabSize = %d, want 0", e.vocabSize)
	}
}

func TestEngineLargeVocab(t *testing.T) {
	// Create a large vocabulary (simulating a real model vocab)
	vocab := make([]string, 32000)
	for i := range vocab {
		vocab[i] = "token"
	}
	// Add some actual JSON tokens
	vocab[0] = "{"
	vocab[1] = "}"
	vocab[2] = "["
	vocab[3] = "]"
	vocab[4] = ":"
	vocab[5] = ","
	vocab[6] = "\"test\""
	vocab[7] = "123"
	vocab[8] = "true"
	vocab[9] = "false"
	vocab[10] = "null"

	e := newTestEngine(t, vocab)
	defer e.Close()

	if e.vocabSize != 32000 {
		t.Errorf("vocabSize = %d, want 32000", e.vocabSize)
	}

	// Test that it still works correctly
	if !e.Accept(0) { // {
		t.Error("should accept {")
	}
	if !e.Accept(1) { // }
		t.Error("should accept }")
	}
	if !e.IsComplete() {
		t.Error("should be complete after {}")
	}
}

// TestE2E_JSONDecoding tests end-to-end JSON constrained decoding.
func TestE2E_JSONDecoding(t *testing.T) {
	// Create a realistic vocabulary with JSON tokens
	vocab := []string{
		// Structural tokens
		"{", "}", "[", "]", ":", ",",
		// Keywords
		"true", "false", "null",
		// Quoted strings
		`"name"`, `"value"`, `"items"`, `"count"`, `"enabled"`,
		`"hello"`, `"world"`, `"test"`,
		// Numbers
		"0", "1", "2", "3", "42", "123", "-1", "-42",
		// Whitespace
		" ", "\n", "\t",
		// Multi-terminal tokens (span multiple JSON lexemes)
		`"key":`, `},`, `],`, `{"`, `["`,
		// Partial/invalid tokens (should be rejected)
		"invalid", "foo", "bar",
	}

	grammar, err := JSONGrammar()
	if err != nil {
		t.Fatalf("failed to create JSON grammar: %v", err)
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		// Simple values
		{"empty object", []string{"{", "}"}, true},
		{"empty array", []string{"[", "]"}, true},
		{"true literal", []string{"true"}, true},
		{"null literal", []string{"null"}, true},
		{"number", []string{"42"}, true},
		{"negative number", []string{"-42"}, true},
		{"quoted string", []string{`"hello"`}, true},

		// Objects
		{"simple object", []string{"{", `"name"`, ":", `"value"`, "}"}, true},
		{"object with single-digit numbers", []string{"{", `"count"`, ":", "1", ",", `"value"`, ":", "2", "}"}, true},
		{"multi-terminal key", []string{"{", `"key":`, `"value"`, "}"}, true},

		// Arrays
		{"array of numbers", []string{"[", "42", "]"}, true},
		{"array of single digits", []string{"[", "1", ",", "2", "]"}, true},
		{"array of strings", []string{"[", `"hello"`, ",", `"world"`, "]"}, true},
		{"nested array", []string{"[", "[", "42", "]", "]"}, true},

		// Nested structures
		{"nested object", []string{"{", `"items"`, ":", "{", `"count"`, ":", "42", "}", "}"}, true},
		{"object with array", []string{"{", `"items"`, ":", "[", "42", "]", "}"}, true},

		// Invalid sequences
		{"unclosed object", []string{"{", `"name"`, ":"}, false},           // incomplete
		{"double comma", []string{"[", "42", ",", ",", "42", "]"}, false},  // invalid
		{"missing value", []string{"{", `"name"`, ":", "}"}, false},        // missing value
		{"bare word", []string{"invalid"}, false},                          // not valid JSON
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			// Process each token
			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass {
				if !allAccepted {
					return // Already reported the error
				}
				if !engine.IsComplete() {
					t.Errorf("expected complete parse, but not in accepting state")
				}
			} else {
				// For invalid sequences, we expect either rejection or an incomplete parse
				if allAccepted && engine.IsComplete() {
					t.Errorf("expected rejection or incomplete, but parse succeeded")
				}
			}
		})
	}
}

// TestE2E_SimpleExpressionGrammar tests a custom expression grammar.
func TestE2E_SimpleExpressionGrammar(t *testing.T) {
	// Simple expression grammar:
	//   expr   = term { addop term }
	//   term   = factor { mulop factor }
	//   factor = number | "(" expr ")"
	//   number = digit { digit }
	exprGrammar := `
	expr = term { addop term } .
	addop = "+" | "-" .
	term = factor { mulop factor } .
	mulop = "*" | "/" .
	factor = number | "(" expr ")" .
	number = digit { digit } .
	digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
	`

	grammar, err := ParseEBNF(exprGrammar, "expr")
	if err != nil {
		t.Fatalf("failed to parse expression grammar: %v", err)
	}

	// Vocabulary for expression tokens
	vocab := []string{
		"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
		"+", "-", "*", "/",
		"(", ")",
		// Multi-digit numbers as single tokens
		"10", "42", "100", "123",
		// Invalid tokens
		"x", "y", "invalid",
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		{"single digit", []string{"5"}, true},
		{"multi-digit", []string{"1", "2", "3"}, true},
		{"addition", []string{"1", "+", "2"}, true},
		{"subtraction", []string{"5", "-", "3"}, true},
		{"multiplication", []string{"2", "*", "3"}, true},
		{"division", []string{"8", "/", "2"}, true},
		{"complex expr", []string{"1", "+", "2", "*", "3"}, true},
		{"parentheses", []string{"(", "1", "+", "2", ")", "*", "3"}, true},
		{"nested parens", []string{"(", "(", "1", ")", ")"}, true},

		// Invalid
		{"just operator", []string{"+"}, false},
		{"double operator", []string{"1", "+", "+", "2"}, false},
		{"unclosed paren", []string{"(", "1", "+", "2"}, false},
		{"variable", []string{"x"}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass {
				if !allAccepted {
					return
				}
				if !engine.IsComplete() {
					t.Errorf("expected complete parse, but not in accepting state")
				}
			} else {
				if allAccepted && engine.IsComplete() {
					t.Errorf("expected rejection or incomplete, but parse succeeded")
				}
			}
		})
	}
}

// TestE2E_IdentifierGrammar tests a grammar with character ranges.
func TestE2E_IdentifierGrammar(t *testing.T) {
	// Identifier grammar using character ranges
	identGrammar := `
	ident = letter { letter | digit } .
	letter = "a" … "z" | "A" … "Z" | "_" .
	digit = "0" … "9" .
	`

	grammar, err := ParseEBNF(identGrammar, "ident")
	if err != nil {
		t.Fatalf("failed to parse identifier grammar: %v", err)
	}

	// Vocabulary with letters and digits
	vocab := []string{
		"a", "b", "c", "x", "y", "z",
		"A", "B", "C", "X", "Y", "Z",
		"_",
		"0", "1", "2", "9",
		// Multi-char tokens
		"foo", "bar", "myVar", "test123",
		// Invalid starting chars
		"1abc", "123",
	}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	tests := []struct {
		name     string
		tokens   []string
		wantPass bool
	}{
		{"single letter", []string{"a"}, true},
		{"uppercase", []string{"A"}, true},
		{"underscore", []string{"_"}, true},
		{"multi-letter", []string{"a", "b", "c"}, true},
		{"letter then digit", []string{"x", "1"}, true},
		{"underscore prefix", []string{"_", "a", "1"}, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			engine.Reset()

			allAccepted := true
			for i, token := range tt.tokens {
				if !engine.AcceptString(token) {
					if tt.wantPass {
						t.Errorf("token %d (%q) rejected unexpectedly", i, token)
					}
					allAccepted = false
					break
				}
			}

			if tt.wantPass && allAccepted && !engine.IsComplete() {
				t.Errorf("expected complete parse, but not in accepting state")
			}
		})
	}
}

// TestE2E_UnicodeRange ensures unicode ranges compile and match tokens.
func TestE2E_UnicodeRange(t *testing.T) {
	greekGrammar := `
	greek = "α" … "ω" .
	`

	grammar, err := ParseEBNF(greekGrammar, "greek")
	if err != nil {
		t.Fatalf("failed to parse unicode grammar: %v", err)
	}

	vocab := []string{"α", "β", "ω", "a"}
	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	if !engine.AcceptString("β") {
		t.Error("should accept beta")
	}
	if !engine.IsComplete() {
		t.Error("should be complete after single rune")
	}

	engine.Reset()
	if engine.AcceptString("a") {
		t.Error("should reject ASCII outside unicode range")
	}
}

// TestE2E_NondeterminismPreserved tests that nondeterministic paths are preserved.
func TestE2E_NondeterminismPreserved(t *testing.T) {
	// This grammar is nondeterministic: the input "ab" could be a single
	// "ab" token or the two tokens "a" "b".
	ambiguousGrammar := `
	start = item item .
	item = "a" | "b" | "ab" .
	`

	grammar, err := ParseEBNF(ambiguousGrammar, "start")
	if err != nil {
		t.Fatalf("failed to parse grammar: %v", err)
	}

	// Vocabulary with both single and combined tokens
	vocab := []string{"a", "b", "ab"}

	engine, err := NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("failed to create engine: %v", err)
	}
	defer engine.Close()

	// Test: "ab" "a" should be valid (ab as first item, a as second)
	t.Run("ab then a", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("ab") {
			t.Error("should accept ab")
		}
		if !engine.AcceptString("a") {
			t.Error("should accept a after ab")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})

	t.Run("a then ab", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("a") {
			t.Error("should accept a")
		}
		if !engine.AcceptString("ab") {
			t.Error("should accept ab after a")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})

	t.Run("a then a", func(t *testing.T) {
		engine.Reset()
		if !engine.AcceptString("a") {
			t.Error("should accept first a")
		}
		if !engine.AcceptString("a") {
			t.Error("should accept second a")
		}
		if !engine.IsComplete() {
			t.Error("should be complete")
		}
	})
}
614
x/grammar/grammar.go
Normal file
@@ -0,0 +1,614 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package grammar provides GPU-accelerated constrained decoding using MLX.
|
||||
// It compiles EBNF grammars to pushdown automata (pda) with precomputed token masks.
|
||||
// For JSON Schema conversion, see the grammar/schema subpackage.
|
||||
package grammar
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/ebnf"
|
||||
)
|
||||
|
||||
// stackSymbol represents a symbol that can be pushed onto the pda stack.
|
||||
type stackSymbol int
|
||||
|
||||
const (
|
||||
stackEmpty stackSymbol = iota
|
||||
// Additional stack symbols will be generated per-grammar
|
||||
)
|
||||
|
||||
// state represents a pda state.
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateError state = -1
|
||||
stateStart state = 0
|
||||
stateAccept state = 1
|
||||
// Additional states will be generated per-grammar
|
||||
)
|
||||
|
||||
// transition represents a pda transition.
|
||||
// On input matching Pattern, from FromState with stackTop:
|
||||
// - Move to ToState
|
||||
// - Pop StackPop symbols, push StackPush symbols
|
||||
type transition struct {
|
||||
FromState state
|
||||
stackTop stackSymbol // What must be on stack top (stackEmpty = don't care)
|
||||
Pattern string // Input pattern to match (token or character class)
|
||||
ToState state
|
||||
StackPop int // Number of symbols to pop
|
||||
StackPush []stackSymbol // Symbols to push (in order, first pushed first)
|
||||
}
|
||||
|
||||
// pda represents a compiled pushdown automaton.
|
||||
type pda struct {
|
||||
States int // Total number of states
|
||||
StackSymbols int // Total number of stack symbols
|
||||
StartState state // Initial state
|
||||
AcceptStates map[state]bool // Set of accepting states
|
||||
Transitions map[state][]transition // Transitions indexed by from-state
|
||||
|
||||
// For token-level matching
|
||||
Terminals []string // All terminal symbols (patterns to match)
|
||||
}
|
||||
|
||||
// newPDA creates an empty pda.
|
||||
func newPDA() *pda {
|
||||
return &pda{
|
||||
States: 2, // Error and Start
|
||||
StackSymbols: 1, // Empty
|
||||
StartState: stateStart,
|
||||
AcceptStates: make(map[state]bool),
|
||||
Transitions: make(map[state][]transition),
|
||||
Terminals: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// addState adds a new state and returns its ID.
|
||||
func (p *pda) addState() state {
|
||||
s := state(p.States)
|
||||
p.States++
|
||||
return s
|
||||
}
|
||||
|
||||
// addStackSymbol adds a new stack symbol and returns its ID.
|
||||
func (p *pda) addStackSymbol() stackSymbol {
|
||||
s := stackSymbol(p.StackSymbols)
|
||||
p.StackSymbols++
|
||||
return s
|
||||
}
|
||||
|
||||
// addTransition adds a transition to the pda.
|
||||
func (p *pda) addTransition(t transition) {
|
||||
p.Transitions[t.FromState] = append(p.Transitions[t.FromState], t)
|
||||
}
|
||||
|
||||
// addTerminal registers a terminal pattern and returns its index.
|
||||
func (p *pda) addTerminal(pattern string) int {
|
||||
for i, t := range p.Terminals {
|
||||
if t == pattern {
|
||||
return i
|
||||
}
|
||||
}
|
||||
p.Terminals = append(p.Terminals, pattern)
|
||||
return len(p.Terminals) - 1
|
||||
}
|
||||
|
||||
// compiler compiles EBNF grammars to PDAs.
|
||||
type compiler struct {
|
||||
grammar ebnf.Grammar
|
||||
pda *pda
|
||||
|
||||
// Maps production names to their entry/exit states
|
||||
prodEntry map[string]state
|
||||
prodExit map[string]state
|
||||
}
|
||||
|
||||
// compile parses an EBNF grammar and compiles it to a pda.
|
||||
func compile(name string, src io.Reader, start string) (*pda, error) {
|
||||
grammar, err := ebnf.Parse(name, src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse grammar: %w", err)
|
||||
}
|
||||
|
||||
if err := ebnf.Verify(grammar, start); err != nil {
|
||||
return nil, fmt.Errorf("verify grammar: %w", err)
|
||||
}
|
||||
|
||||
c := &compiler{
|
||||
grammar: grammar,
|
||||
pda: newPDA(),
|
||||
prodEntry: make(map[string]state),
|
||||
prodExit: make(map[string]state),
|
||||
}
|
||||
|
||||
// Create entry/exit states for each production
|
||||
for name := range grammar {
|
||||
c.prodEntry[name] = c.pda.addState()
|
||||
c.prodExit[name] = c.pda.addState()
|
||||
}
|
||||
|
||||
// compile each production
|
||||
for name, prod := range grammar {
|
||||
if err := c.compileProduction(name, prod); err != nil {
|
||||
return nil, fmt.Errorf("compile production %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set start state to entry of start production
|
||||
if entry, ok := c.prodEntry[start]; ok {
|
||||
// Add epsilon transition from pda start to grammar start
|
||||
c.pda.addTransition(transition{
|
||||
FromState: stateStart,
|
||||
Pattern: "", // epsilon
|
||||
ToState: entry,
|
||||
})
|
||||
} else {
|
||||
return nil, fmt.Errorf("start production %q not found", start)
|
||||
}
|
||||
|
||||
// Mark exit of start production as accepting
|
||||
if exit, ok := c.prodExit[start]; ok {
|
||||
c.pda.AcceptStates[exit] = true
|
||||
}
|
||||
|
||||
return c.pda, nil
|
||||
}
|
||||
|
||||
// compileString is a convenience function to compile from a string.
|
||||
func compileString(grammar string, start string) (*pda, error) {
|
||||
return compile("grammar", strings.NewReader(grammar), start)
|
||||
}

func (c *compiler) compileProduction(name string, prod *ebnf.Production) error {
	entry := c.prodEntry[name]
	exit := c.prodExit[name]

	return c.compileExpr(prod.Expr, entry, exit)
}

func (c *compiler) compileExpr(expr ebnf.Expression, entry, exit state) error {
	switch e := expr.(type) {
	case *ebnf.Name:
		return c.compileName(e, entry, exit)
	case *ebnf.Token:
		return c.compileToken(e, entry, exit)
	case ebnf.Sequence:
		return c.compileSequence(e, entry, exit)
	case ebnf.Alternative:
		return c.compileAlternative(e, entry, exit)
	case *ebnf.Option:
		return c.compileOption(e, entry, exit)
	case *ebnf.Repetition:
		return c.compileRepetition(e, entry, exit)
	case *ebnf.Group:
		return c.compileExpr(e.Body, entry, exit)
	case *ebnf.Range:
		return c.compileRange(e, entry, exit)
	case nil:
		// Empty production - direct epsilon transition
		c.pda.addTransition(transition{
			FromState: entry,
			Pattern:   "",
			ToState:   exit,
		})
		return nil
	default:
		return fmt.Errorf("unsupported expression type: %T", expr)
	}
}

func (c *compiler) compileName(n *ebnf.Name, entry, exit state) error {
	// Reference to another production
	prodName := n.String

	prodEntry, ok := c.prodEntry[prodName]
	if !ok {
		return fmt.Errorf("undefined production: %s", prodName)
	}
	prodExit := c.prodExit[prodName]

	// Use a unique stack symbol per call site so returns are unambiguous.
	stackSym := c.pda.addStackSymbol()

	// Push return address, go to production entry
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "", // epsilon
		ToState:   prodEntry,
		StackPush: []stackSymbol{stackSym},
	})

	// On production exit, pop and return
	c.pda.addTransition(transition{
		FromState: prodExit,
		stackTop:  stackSym,
		Pattern:   "", // epsilon
		ToState:   exit,
		StackPop:  1,
	})

	return nil
}
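
One design note: because every call site pushes its own freshly allocated stack symbol, two references to the same production from different contexts return to different exit states when that symbol is popped. TestReturnAddressDisambiguation in the test file below exercises exactly this case.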

func (c *compiler) compileToken(t *ebnf.Token, entry, exit state) error {
	// Terminal symbol - add a transition that consumes this token
	pattern := t.String
	c.pda.addTerminal(pattern)

	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   pattern,
		ToState:   exit,
	})

	return nil
}

func (c *compiler) compileSequence(seq ebnf.Sequence, entry, exit state) error {
	if len(seq) == 0 {
		// Empty sequence - epsilon transition
		c.pda.addTransition(transition{
			FromState: entry,
			Pattern:   "",
			ToState:   exit,
		})
		return nil
	}

	// Chain: entry -> s1 -> s2 -> ... -> exit
	current := entry
	for i, expr := range seq {
		var next state
		if i == len(seq)-1 {
			next = exit
		} else {
			next = c.pda.addState()
		}

		if err := c.compileExpr(expr, current, next); err != nil {
			return err
		}
		current = next
	}

	return nil
}

func (c *compiler) compileAlternative(alt ebnf.Alternative, entry, exit state) error {
	// Each alternative goes from entry to exit
	for _, expr := range alt {
		if err := c.compileExpr(expr, entry, exit); err != nil {
			return err
		}
	}
	return nil
}

func (c *compiler) compileOption(opt *ebnf.Option, entry, exit state) error {
	// Optional: can skip (epsilon) or take the body

	// Epsilon transition (skip)
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "",
		ToState:   exit,
	})

	// Or take the body
	return c.compileExpr(opt.Body, entry, exit)
}

func (c *compiler) compileRepetition(rep *ebnf.Repetition, entry, exit state) error {
	// Repetition {body}: zero or more
	// entry -> exit (skip)
	// entry -> body -> entry (loop back)

	// Skip transition
	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   "",
		ToState:   exit,
	})

	// Loop: entry -> (body) -> entry
	return c.compileExpr(rep.Body, entry, entry)
}

func (c *compiler) compileRange(r *ebnf.Range, entry, exit state) error {
	// Character range like "a" … "z" or "\u03b1" … "\u03c9"
	begin := strings.Trim(r.Begin.String, "\"")
	end := strings.Trim(r.End.String, "\"")

	// Unescape bounds first (so "\u03b1" works)
	beginUnesc, err := unescapeLiteral(begin)
	if err != nil {
		return fmt.Errorf("invalid range begin: %w", err)
	}
	endUnesc, err := unescapeLiteral(end)
	if err != nil {
		return fmt.Errorf("invalid range end: %w", err)
	}

	// Validate as single runes (not bytes) for Unicode support
	beginRunes := []rune(beginUnesc)
	endRunes := []rune(endUnesc)
	if len(beginRunes) != 1 || len(endRunes) != 1 {
		return fmt.Errorf("range bounds must be single characters: %q..%q", r.Begin.String, r.End.String)
	}

	// Use unescaped rune strings in the pattern (consistent with the matcher)
	pattern := fmt.Sprintf("[%s-%s]", string(beginRunes[0]), string(endRunes[0]))
	c.pda.addTerminal(pattern)

	c.pda.addTransition(transition{
		FromState: entry,
		Pattern:   pattern,
		ToState:   exit,
	})

	return nil
}

// runtime represents a pda execution instance.
type runtime struct {
	pda   *pda
	state state
	stack []stackSymbol
}

// newRuntime creates a new pda runtime.
func newRuntime(pda *pda) *runtime {
	return &runtime{
		pda:   pda,
		state: pda.StartState,
		stack: make([]stackSymbol, 0, 32),
	}
}

// stackTop returns the top of the stack, or stackEmpty if empty.
func (r *runtime) stackTop() stackSymbol {
	if len(r.stack) == 0 {
		return stackEmpty
	}
	return r.stack[len(r.stack)-1]
}

// isAccepting returns true if we can reach an accepting state via epsilon
// transitions with an empty stack.
func (r *runtime) isAccepting() bool {
	return r.canReachAccept(r.state, r.stack, make(map[stateStackKey]bool))
}

func (r *runtime) canReachAccept(state state, stack []stackSymbol, visited map[stateStackKey]bool) bool {
	// Check if this state is accepting with empty stack
	if r.pda.AcceptStates[state] && len(stack) == 0 {
		return true
	}

	// Avoid infinite loops
	key := stateStackKey{state: state, stackSig: stackSignature(stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	// Try epsilon transitions
	for _, t := range r.pda.Transitions[state] {
		if t.Pattern != "" {
			continue // Not epsilon
		}

		// Check stack constraint
		stackTop := stackEmpty
		if len(stack) > 0 {
			stackTop = stack[len(stack)-1]
		}
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		// Simulate stack operations
		newStack := make([]stackSymbol, len(stack))
		copy(newStack, stack)

		if t.StackPop > 0 {
			if len(newStack) < t.StackPop {
				continue // Can't pop, skip this transition
			}
			newStack = newStack[:len(newStack)-t.StackPop]
		}
		newStack = append(newStack, t.StackPush...)

		if r.canReachAccept(t.ToState, newStack, visited) {
			return true
		}
	}

	return false
}

// Reset resets the runtime to its initial state.
func (r *runtime) Reset() {
	r.state = r.pda.StartState
	r.stack = r.stack[:0]
}

// validInputs returns all valid input patterns from the current state.
func (r *runtime) validInputs() []string {
	var valid []string
	seen := make(map[string]bool)
	visited := make(map[stateStackKey]bool)

	// Make a copy of the stack for simulation
	simStack := make([]stackSymbol, len(r.stack))
	copy(simStack, r.stack)

	r.collectValidInputs(r.state, simStack, seen, visited, &valid)
	return valid
}

// stateStackKey is used to detect cycles in epsilon closure.
type stateStackKey struct {
	state    state
	stackSig string
}

func stackSignature(stack []stackSymbol) string {
	if len(stack) == 0 {
		return ""
	}
	buf := make([]byte, len(stack)*8)
	for i, sym := range stack {
		binary.LittleEndian.PutUint64(buf[i*8:], uint64(sym))
	}
	return string(buf)
}
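
The signature packs each symbol into a fixed 8-byte little-endian slot, so two stacks collide only when they are equal element for element. That lets the epsilon-closure walks key their visited sets on whole (state, stack) configurations rather than bare states, which is what keeps them terminating on recursive grammars where the same state recurs with different stacks.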

func (r *runtime) collectValidInputs(state state, simStack []stackSymbol, seen map[string]bool, visited map[stateStackKey]bool, valid *[]string) {
	// Get stack top for comparisons
	stackTop := stackEmpty
	if len(simStack) > 0 {
		stackTop = simStack[len(simStack)-1]
	}

	// Check for cycles to avoid infinite loops
	key := stateStackKey{state: state, stackSig: stackSignature(simStack)}
	if visited[key] {
		return
	}
	visited[key] = true

	transitions := r.pda.Transitions[state]

	for _, t := range transitions {
		// Check stack constraint
		if t.stackTop != stackEmpty && t.stackTop != stackTop {
			continue
		}

		if t.Pattern == "" {
			// Epsilon transition - simulate stack operations
			newStack := make([]stackSymbol, len(simStack))
			copy(newStack, simStack)

			// Pop
			if t.StackPop > 0 {
				if len(newStack) < t.StackPop {
					continue // Can't pop, skip this transition
				}
				newStack = newStack[:len(newStack)-t.StackPop]
			}

			// Push
			newStack = append(newStack, t.StackPush...)

			r.collectValidInputs(t.ToState, newStack, seen, visited, valid)
		} else {
			// Terminal - add if not seen
			if !seen[t.Pattern] {
				seen[t.Pattern] = true
				*valid = append(*valid, t.Pattern)
			}
		}
	}
}

// matchesPattern checks if input matches a pattern.
// Patterns can be:
//   - Exact strings: "a", "{", "true"
//   - Character ranges: "[a-z]", "[0-9]", "[#-~]"
func matchesPattern(input, pattern string) bool {
	// Exact match
	if input == pattern {
		return true
	}

	// Check for a character range pattern [X-Y]. Compare runes rather than
	// bytes so ranges built from multi-byte characters (see compileRange)
	// also match.
	pr := []rune(pattern)
	if len(pr) == 5 && pr[0] == '[' && pr[2] == '-' && pr[4] == ']' {
		in := []rune(input)
		if len(in) != 1 {
			return false
		}
		return in[0] >= pr[1] && in[0] <= pr[3]
	}

	return false
}
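
A table-driven test like the following (a sketch with a hypothetical name; matchesPattern is unexported, so it would live in this package and build with `-tags mlx`) pins down both the exact-match and range cases:

```go
func TestMatchesPattern(t *testing.T) {
	cases := []struct {
		input, pattern string
		want           bool
	}{
		{"{", "{", true},       // exact terminal
		{"q", "[a-z]", true},   // rune inside the range
		{"Q", "[a-z]", false},  // outside the range
		{"ab", "[a-z]", false}, // ranges match a single character
	}
	for _, c := range cases {
		if got := matchesPattern(c.input, c.pattern); got != c.want {
			t.Errorf("matchesPattern(%q, %q) = %v, want %v", c.input, c.pattern, got, c.want)
		}
	}
}
```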

// Accept tries to accept an input, returning true if successful.
func (r *runtime) Accept(input string) bool {
	return r.accept(input, make(map[stateStackKey]bool))
}

func (r *runtime) accept(input string, visited map[stateStackKey]bool) bool {
	key := stateStackKey{state: r.state, stackSig: stackSignature(r.stack)}
	if visited[key] {
		return false
	}
	visited[key] = true

	transitions := r.pda.Transitions[r.state]

	// First, try transitions that consume the input directly from this state.
	for _, t := range transitions {
		if matchesPattern(input, t.Pattern) {
			if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
				continue
			}
			if t.StackPop > len(r.stack) {
				continue
			}

			// Apply transition
			r.applyTransition(t)
			return true
		}
	}

	// Otherwise, follow epsilon transitions and recurse, backtracking on
	// failure. This is a simplified search rather than a precomputed
	// epsilon closure.
	for _, t := range transitions {
		if t.Pattern == "" {
			if t.stackTop != stackEmpty && t.stackTop != r.stackTop() {
				continue
			}
			if t.StackPop > len(r.stack) {
				continue
			}

			// Save state for backtracking
			oldState := r.state
			oldStack := make([]stackSymbol, len(r.stack))
			copy(oldStack, r.stack)

			r.applyTransition(t)

			if r.accept(input, visited) {
				return true
			}

			// Backtrack
			r.state = oldState
			r.stack = oldStack
		}
	}

	return false
}

func (r *runtime) applyTransition(t transition) {
	// Pop (callers have already verified the pop is possible)
	if t.StackPop > 0 && len(r.stack) >= t.StackPop {
		r.stack = r.stack[:len(r.stack)-t.StackPop]
	}

	// Push
	r.stack = append(r.stack, t.StackPush...)

	// Move to the new state
	r.state = t.ToState
}
540
x/grammar/grammar_test.go
Normal file
@@ -0,0 +1,540 @@
//go:build mlx

package grammar

import (
	"testing"
)

func TestCompileSimpleGrammar(t *testing.T) {
	// Simple grammar: S = "a" "b" .
	grammar := `S = "a" "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	if pda == nil {
		t.Fatal("pda is nil")
	}

	// Should have terminals "a" and "b"
	if len(pda.Terminals) != 2 {
		t.Errorf("expected 2 terminals, got %d: %v", len(pda.Terminals), pda.Terminals)
	}

	// Test runtime
	rt := newRuntime(pda)

	// Should accept "a" then "b"
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.Accept("b") {
		t.Error("should accept 'b'")
	}
	if !rt.isAccepting() {
		t.Error("should be in accepting state")
	}
}

func TestCompileAlternative(t *testing.T) {
	// Grammar: S = "a" | "b" .
	grammar := `S = "a" | "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// Test accepting "a"
	rt := newRuntime(pda)
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'a'")
	}

	// Test accepting "b"
	rt.Reset()
	if !rt.Accept("b") {
		t.Error("should accept 'b'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'b'")
	}

	// Test rejecting "c"
	rt.Reset()
	if rt.Accept("c") {
		t.Error("should not accept 'c'")
	}
}

func TestCompileRepetition(t *testing.T) {
	// Grammar: S = {"a"} .
	grammar := `S = {"a"} .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// Empty should be accepted (zero repetitions)
	rt := newRuntime(pda)
	if !rt.isAccepting() {
		t.Error("empty should be accepting")
	}

	// "a" should be accepted
	rt.Reset()
	if !rt.Accept("a") {
		t.Error("should accept first 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after one 'a'")
	}

	// "aa" should be accepted
	if !rt.Accept("a") {
		t.Error("should accept second 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after two 'a's")
	}
}

func TestCompileOption(t *testing.T) {
	// Grammar: S = ["a"] "b" .
	grammar := `S = ["a"] "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// "b" alone should be accepted
	rt := newRuntime(pda)
	if !rt.Accept("b") {
		t.Error("should accept 'b' alone")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'b'")
	}

	// "ab" should be accepted
	rt.Reset()
	if !rt.Accept("a") {
		t.Error("should accept 'a'")
	}
	if !rt.Accept("b") {
		t.Error("should accept 'b' after 'a'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'ab'")
	}
}

func TestCompileRecursive(t *testing.T) {
	// Grammar with recursion: S = "(" S ")" | "x" .
	grammar := `S = "(" S ")" | "x" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	// "x" should be accepted
	rt := newRuntime(pda)
	if !rt.Accept("x") {
		t.Error("should accept 'x'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after 'x'")
	}

	// "(x)" should be accepted
	rt.Reset()
	if !rt.Accept("(") {
		t.Error("should accept '('")
	}
	if !rt.Accept("x") {
		t.Error("should accept 'x' inside parens")
	}
	if !rt.Accept(")") {
		t.Error("should accept ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '(x)'")
	}

	// "((x))" should be accepted
	rt.Reset()
	if !rt.Accept("(") {
		t.Error("should accept first '('")
	}
	if !rt.Accept("(") {
		t.Error("should accept second '('")
	}
	if !rt.Accept("x") {
		t.Error("should accept 'x'")
	}
	if !rt.Accept(")") {
		t.Error("should accept first ')'")
	}
	if !rt.Accept(")") {
		t.Error("should accept second ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '((x))'")
	}
}

func TestValidInputs(t *testing.T) {
	// Grammar: S = "a" | "b" .
	grammar := `S = "a" | "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)
	valid := rt.validInputs()

	// Should have both "a" and "b" as valid
	hasA, hasB := false, false
	for _, v := range valid {
		if v == "a" {
			hasA = true
		}
		if v == "b" {
			hasB = true
		}
	}

	if !hasA {
		t.Error("'a' should be valid input")
	}
	if !hasB {
		t.Error("'b' should be valid input")
	}
}

// TestValidInputsAfterAccept tests that validInputs returns correct values
// after accepting tokens, ensuring proper stack simulation.
func TestValidInputsAfterAccept(t *testing.T) {
	// Grammar: S = "a" "b" "c" .
	grammar := `S = "a" "b" "c" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially only "a" should be valid
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "a" {
		t.Errorf("initially expected only 'a', got %v", valid)
	}

	// After accepting "a", only "b" should be valid
	if !rt.Accept("a") {
		t.Fatal("failed to accept 'a'")
	}
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "b" {
		t.Errorf("after 'a', expected only 'b', got %v", valid)
	}

	// After accepting "b", only "c" should be valid
	if !rt.Accept("b") {
		t.Fatal("failed to accept 'b'")
	}
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "c" {
		t.Errorf("after 'ab', expected only 'c', got %v", valid)
	}
}

// TestValidInputsWithRepetitionInProduction tests the critical case where
// a repetition exists inside a called production. This requires proper
// stack simulation to determine when closing symbols are valid.
func TestValidInputsWithRepetitionInProduction(t *testing.T) {
	// Grammar similar to JSON:
	//   S = "(" items ")" .
	//   items = item { "," item } .
	//   item = "x" .
	grammar := `
S = "(" items ")" .
items = item { "," item } .
item = "x" .
`
	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially only "(" should be valid
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "(" {
		t.Errorf("initially expected only '(', got %v", valid)
	}

	// Accept "("
	if !rt.Accept("(") {
		t.Fatal("failed to accept '('")
	}
	// After "(", should be able to accept "x" (item)
	valid = rt.validInputs()
	hasX := false
	for _, v := range valid {
		if v == "x" {
			hasX = true
		}
	}
	if !hasX {
		t.Errorf("after '(', expected 'x' to be valid, got %v", valid)
	}

	// Accept first item "x"
	if !rt.Accept("x") {
		t.Fatal("failed to accept 'x'")
	}
	// After "(x", should be able to accept "," (more items) OR ")" (end)
	valid = rt.validInputs()
	hasComma, hasClose := false, false
	for _, v := range valid {
		if v == "," {
			hasComma = true
		}
		if v == ")" {
			hasClose = true
		}
	}
	if !hasComma {
		t.Errorf("after '(x', expected ',' to be valid, got %v", valid)
	}
	if !hasClose {
		t.Errorf("after '(x', expected ')' to be valid, got %v", valid)
	}

	// Accept comma for another item
	if !rt.Accept(",") {
		t.Fatal("failed to accept ','")
	}
	// After "(x,", should only be able to accept "x" (next item)
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "x" {
		t.Errorf("after '(x,', expected only 'x', got %v", valid)
	}

	// Accept second item "x"
	if !rt.Accept("x") {
		t.Fatal("failed to accept second 'x'")
	}
	// CRITICAL: After "(x,x", should be able to accept "," OR ")".
	// This tests the stack simulation fix - we need to properly
	// follow epsilon transitions through the production call stack.
	valid = rt.validInputs()
	hasComma, hasClose = false, false
	for _, v := range valid {
		if v == "," {
			hasComma = true
		}
		if v == ")" {
			hasClose = true
		}
	}
	if !hasComma {
		t.Errorf("after '(x,x', expected ',' to be valid, got %v", valid)
	}
	if !hasClose {
		t.Errorf("after '(x,x', expected ')' to be valid, got %v", valid)
	}

	// Close with ")"
	if !rt.Accept(")") {
		t.Fatal("failed to accept ')'")
	}
	if !rt.isAccepting() {
		t.Error("should be accepting after '(x,x)'")
	}
}

// TestValidInputsNestedCalls tests validInputs with deeply nested production calls.
func TestValidInputsNestedCalls(t *testing.T) {
	// Grammar: A = "start" B "end" . B = "middle" .
	grammar := `
A = "start" B "end" .
B = "middle" .
`
	pda, err := compileString(grammar, "A")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// After "start", should accept "middle" (from B)
	rt.Accept("start")
	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "middle" {
		t.Errorf("after 'start', expected 'middle', got %v", valid)
	}

	// After "start middle", should accept "end"
	rt.Accept("middle")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != "end" {
		t.Errorf("after 'start middle', expected 'end', got %v", valid)
	}
}

func TestReturnAddressDisambiguation(t *testing.T) {
	// Grammar where the same production is called from different contexts:
	//   S = A "x" | "c" A "y" .
	//   A = "a" .
	grammar := `
S = A "x" | "c" A "y" .
A = "a" .
`
	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	if !rt.Accept("c") {
		t.Fatal("failed to accept 'c'")
	}
	if !rt.Accept("a") {
		t.Fatal("failed to accept 'a'")
	}

	valid := rt.validInputs()
	if len(valid) != 1 || valid[0] != "y" {
		t.Errorf("after 'ca', expected only 'y', got %v", valid)
	}

	rt.Reset()
	rt.Accept("c")
	rt.Accept("a")
	if rt.Accept("x") {
		t.Error("should not accept 'x' after 'ca'")
	}
}

// TestValidInputsRecursiveWithStack tests validInputs with recursive grammars
// which heavily exercise the stack simulation.
func TestValidInputsRecursiveWithStack(t *testing.T) {
	// Grammar: S = "(" S ")" | "x" .
	grammar := `S = "(" S ")" | "x" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// Initially: "(" or "x" should be valid
	valid := rt.validInputs()
	hasParen, hasX := false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("initially expected '(' and 'x', got %v", valid)
	}

	// After "(": "(" or "x" should be valid (nested S)
	rt.Accept("(")
	valid = rt.validInputs()
	hasParen, hasX = false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("after '(', expected '(' and 'x', got %v", valid)
	}

	// After "((": "(" or "x" should still be valid
	rt.Accept("(")
	valid = rt.validInputs()
	hasParen, hasX = false, false
	for _, v := range valid {
		if v == "(" {
			hasParen = true
		}
		if v == "x" {
			hasX = true
		}
	}
	if !hasParen || !hasX {
		t.Errorf("after '((', expected '(' and 'x', got %v", valid)
	}

	// After "((x": only ")" should be valid
	rt.Accept("x")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != ")" {
		t.Errorf("after '((x', expected only ')', got %v", valid)
	}

	// After "((x)": only ")" should be valid (closing outer)
	rt.Accept(")")
	valid = rt.validInputs()
	if len(valid) != 1 || valid[0] != ")" {
		t.Errorf("after '((x)', expected only ')', got %v", valid)
	}
}

// TestRejectionAfterValid tests that invalid inputs are rejected
// at various points in the grammar.
func TestRejectionAfterValid(t *testing.T) {
	// Grammar: S = "a" "b" .
	grammar := `S = "a" "b" .`

	pda, err := compileString(grammar, "S")
	if err != nil {
		t.Fatalf("compile failed: %v", err)
	}

	rt := newRuntime(pda)

	// "b" should be rejected initially
	if rt.Accept("b") {
		t.Error("'b' should be rejected initially")
	}

	// Accept "a"
	rt.Accept("a")

	// "a" should be rejected after "a"
	if rt.Accept("a") {
		t.Error("'a' should be rejected after 'a'")
	}

	// "c" should be rejected (not in grammar)
	if rt.Accept("c") {
		t.Error("'c' should be rejected (not in grammar)")
	}
}
56
x/grammar/grammars/README.md
Normal file
@@ -0,0 +1,56 @@
# Example Grammars

This directory contains example EBNF grammars for constrained decoding.

## Usage

```bash
go run -tags mlx ./x/imagegen/cmd/engine/ \
  -model /path/to/model \
  -prompt "Your prompt" \
  -grammar x/grammar/grammars/json.ebnf \
  -grammar-start value
```

## Available Grammars

| File | Start Rule | Description |
|------|------------|-------------|
| `json.ebnf` | `value` | Standard JSON (RFC 8259) |
| `expression.ebnf` | `expr` | Arithmetic expressions (+, -, *, /, parens) |
| `identifier.ebnf` | `ident` | Programming language identifiers |
| `boolean.ebnf` | `expr` | Boolean expressions (AND, OR, NOT) |
| `list.ebnf` | `list` | Comma-separated word list |
| `yesno.ebnf` | `response` | Simple yes/no responses |
| `date.ebnf` | `date` | Dates in YYYY-MM-DD format |
| `email.ebnf` | `email` | Basic email addresses |
| `phone.ebnf` | `phone` | US phone numbers |
| `hexcolor.ebnf` | `color` | CSS hex colors (#RGB or #RRGGBB) |
| `url.ebnf` | `url` | HTTP/HTTPS URLs |

## Grammar Syntax

**Note:** Comments are not supported. Grammar files must contain only EBNF productions.

The grammars use EBNF notation:

- `=` defines a production rule
- `|` is alternation (or)
- `{ }` is repetition (zero or more)
- `[ ]` is optional (zero or one)
- `" "` is a literal string
- `…` is a character range (e.g., `"a" … "z"`)
- `.` ends a production

## Writing Custom Grammars

1. Define your grammar in a `.ebnf` file
2. Choose a start rule name
3. Pass `-grammar path/to/grammar.ebnf -grammar-start rulename`

Example custom grammar for RGB colors:

```ebnf
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
```
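
A custom grammar can also be sanity-checked from a Go test before wiring it into the engine. A minimal sketch (hypothetical test name; `compileString`, `newRuntime`, and friends are unexported in the `x/grammar` package, which builds with `-tags mlx`):

```go
func TestHexColorGrammar(t *testing.T) {
	src := `
color = "#" hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
`
	p, err := compileString(src, "color")
	if err != nil {
		t.Fatal(err)
	}
	rt := newRuntime(p)
	for _, ch := range "#1a2B3c" {
		if !rt.Accept(string(ch)) {
			t.Fatalf("rejected %q", ch)
		}
	}
	if !rt.isAccepting() {
		t.Fatal("expected a complete color")
	}
}
```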
7
x/grammar/grammars/boolean.ebnf
Normal file
@@ -0,0 +1,7 @@
expr = term { " OR " term } .
term = factor { " AND " factor } .
factor = "NOT " factor | atom | "(" expr ")" .
atom = "true" | "false" | ident .
ident = letter { letter | digit } .
letter = "a" … "z" | "A" … "Z" .
digit = "0" … "9" .
6
x/grammar/grammars/date.ebnf
Normal file
@@ -0,0 +1,6 @@
date = year "-" month "-" day .
year = digit digit digit digit .
month = ( "0" digit1to9 ) | ( "1" ( "0" | "1" | "2" ) ) .
day = ( "0" digit1to9 ) | ( ( "1" | "2" ) digit ) | ( "3" ( "0" | "1" ) ) .
digit1to9 = "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
5
x/grammar/grammars/email.ebnf
Normal file
@@ -0,0 +1,5 @@
email = localpart "@" domain .
localpart = word { "." word } .
domain = word { "." word } .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
7
x/grammar/grammars/expression.ebnf
Normal file
@@ -0,0 +1,7 @@
expr = term { addop term } .
addop = "+" | "-" .
term = factor { mulop factor } .
mulop = "*" | "/" .
factor = number | "(" expr ")" .
number = [ "-" ] digit { digit } .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
4
x/grammar/grammars/hexcolor.ebnf
Normal file
@@ -0,0 +1,4 @@
color = "#" ( hex6 | hex3 ) .
hex6 = hexdigit hexdigit hexdigit hexdigit hexdigit hexdigit .
hex3 = hexdigit hexdigit hexdigit .
hexdigit = "0" … "9" | "a" … "f" | "A" … "F" .
3
x/grammar/grammars/identifier.ebnf
Normal file
@@ -0,0 +1,3 @@
ident = letter { letter | digit | "_" } .
letter = "a" … "z" | "A" … "Z" | "_" .
digit = "0" … "9" .
16
x/grammar/grammars/json.ebnf
Normal file
@@ -0,0 +1,16 @@
value = object | array | string | number | "true" | "false" | "null" .
object = "{" [ members ] "}" .
members = pair { "," pair } .
pair = string ":" value .
array = "[" [ elements ] "]" .
elements = value { "," value } .
string = "\"" { char } "\"" .
char = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
onenine = "1" … "9" .
digit = "0" … "9" .
27
x/grammar/grammars/json_array.ebnf
Normal file
@@ -0,0 +1,27 @@
root = array .

value = object | array | string | number | "true" | "false" | "null" .

object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .

array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .

string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .

number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
4
x/grammar/grammars/list.ebnf
Normal file
@@ -0,0 +1,4 @@
list = item { ", " item } .
item = word .
word = letter { letter } .
letter = "a" … "z" | "A" … "Z" .
19
x/grammar/grammars/people20.ebnf
Normal file
@@ -0,0 +1,19 @@
root = "[" ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person "," ws person { "," ws person } ws "]" .

person = "{" ws name_field "," ws age_field "," ws email_field ws "}" .

name_field = "\"" "n" "a" "m" "e" "\"" ws ":" ws string .
age_field = "\"" "a" "g" "e" "\"" ws ":" ws number .
email_field = "\"" "e" "m" "a" "i" "l" "\"" ws ":" ws string .

string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" ) .

number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
15
x/grammar/grammars/person.ebnf
Normal file
@@ -0,0 +1,15 @@
root = "{" ws name_field "," ws age_field "," ws email_field ws "}" .

name_field = "\"name\"" ws ":" ws string .
age_field = "\"age\"" ws ":" ws number .
email_field = "\"email\"" ws ":" ws string .

string = "\"" { character } "\"" .
character = " " | "!" | "#" … "~" .

number = [ "-" ] integer .
integer = "0" | onenine { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
7
x/grammar/grammars/phone.ebnf
Normal file
@@ -0,0 +1,7 @@
phone = parenformat | dashformat .
parenformat = "(" areacode ") " exchange "-" subscriber .
dashformat = areacode "-" exchange "-" subscriber .
areacode = digit digit digit .
exchange = digit digit digit .
subscriber = digit digit digit digit .
digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" .
11
x/grammar/grammars/url.ebnf
Normal file
@@ -0,0 +1,11 @@
url = scheme "://" host [ ":" port ] [ path ] [ query ] .
scheme = "http" | "https" .
host = word { "." word } .
port = digit { digit } .
path = "/" { pathseg } .
pathseg = word [ "/" ] .
query = "?" param { "&" param } .
param = word "=" word .
word = alphanum { alphanum | "-" | "_" } .
alphanum = "a" … "z" | "A" … "Z" | "0" … "9" .
digit = "0" … "9" .
3
x/grammar/grammars/yesno.ebnf
Normal file
@@ -0,0 +1,3 @@
response = affirmative | negative .
affirmative = "yes" | "Yes" | "YES" | "y" | "Y" | "true" | "True" .
negative = "no" | "No" | "NO" | "n" | "N" | "false" | "False" .
69
x/grammar/json.go
Normal file
@@ -0,0 +1,69 @@
//go:build mlx

package grammar

// JSONGrammarEBNF is the EBNF grammar for JSON (character-level).
// Based on https://www.json.org/json-en.html
//
// This grammar operates at the character level. The engine validates
// tokens by matching them as sequences of these character-level terminals.
const JSONGrammarEBNF = `
json = value .

value = object | array | string | number | "true" | "false" | "null" .

object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .

array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .

string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .

number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
`

// JSONObjectGrammarEBNF is like JSONGrammarEBNF but only allows objects at the top level.
const JSONObjectGrammarEBNF = `
json = object .

value = object | array | string | number | "true" | "false" | "null" .

object = "{" ws "}" | "{" members "}" .
members = member { "," member } .
member = ws string ws ":" element .

array = "[" ws "]" | "[" elements "]" .
elements = element { "," element } .
element = ws value ws .

string = "\"" { character } "\"" .
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
hex = "0" … "9" | "A" … "F" | "a" … "f" .

number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
digit = "0" … "9" .
onenine = "1" … "9" .

ws = { " " | "\t" | "\n" | "\r" } .
`
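
Character-level means the runtime can be fed one character at a time; a quick sketch of that end to end (hypothetical test name, same package, `-tags mlx`):

```go
func TestJSONCharByChar(t *testing.T) {
	p, err := compileString(JSONGrammarEBNF, "json")
	if err != nil {
		t.Fatal(err)
	}
	rt := newRuntime(p)
	// Every character of this input is itself a terminal; inputs containing
	// the multi-character keyword terminals ("true", "false", "null") need
	// the engine's token-to-terminal matching instead.
	for _, ch := range `{"ok":1}` {
		if !rt.Accept(string(ch)) {
			t.Fatalf("rejected %q", ch)
		}
	}
	if !rt.isAccepting() {
		t.Fatal("expected complete JSON")
	}
}
```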
726
x/grammar/schema/schema.go
Normal file
@@ -0,0 +1,726 @@
//go:build mlx

// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
package schema

import (
	"encoding/json"
	"fmt"
	"regexp"
	"sort"
	"strings"

	"github.com/ollama/ollama/x/grammar"
)

// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
// See: https://platform.openai.com/docs/guides/structured-outputs
type schemaNode struct {
	// Core types
	Type interface{} `json:"type"` // string, []string, or nil

	// Object properties
	Properties           map[string]*schemaNode `json:"properties"`
	Required             []string               `json:"required"`
	AdditionalProperties interface{}            `json:"additionalProperties"`

	// Array properties
	Items    *schemaNode `json:"items"`
	MinItems *int        `json:"minItems"`
	MaxItems *int        `json:"maxItems"`

	// String properties
	Pattern string `json:"pattern"` // Regex pattern
	Format  string `json:"format"`  // date-time, email, uuid, etc.

	// Number properties (noted but not enforced in grammar - validated post-generation)
	Minimum          *float64 `json:"minimum"`
	Maximum          *float64 `json:"maximum"`
	ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
	ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
	MultipleOf       *float64 `json:"multipleOf"`

	// Enum and const
	Enum  []interface{} `json:"enum"`
	Const interface{}   `json:"const"`

	// Composition
	AnyOf []*schemaNode `json:"anyOf"`
	OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar

	// References and definitions
	Ref  string                 `json:"$ref"`
	Defs map[string]*schemaNode `json:"$defs"`

	// Description (ignored for grammar but useful for docs)
	Description string `json:"description"`
}

// converter handles JSON Schema to EBNF conversion with state.
type converter struct {
	schema      *schemaNode
	definitions map[string]*schemaNode // Resolved $defs
	usedTypes   map[string]bool
	rules       []string
	ruleNum     int
	definedRefs map[string]bool // Track which refs we've already defined as rules
}

// EBNF converts a JSON Schema to an EBNF grammar.
func EBNF(schemaJSON string) (string, error) {
	var schema schemaNode
	if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
		return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
	}

	conv := &converter{
		schema:      &schema,
		definitions: schema.Defs,
		usedTypes:   make(map[string]bool),
		definedRefs: make(map[string]bool),
	}

	return conv.convert()
}
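
A sketch of the conversion in use (illustrative schema literal; `EBNF` is the exported entry point above, and the import path follows from the module layout, building with `-tags mlx`):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/grammar/schema"
)

func main() {
	// EBNF emits a grammar whose start rule is "root", followed by
	// whichever primitive rules (string, int, ws, ...) the schema used.
	src, err := schema.EBNF(`{
		"type": "object",
		"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
		"required": ["name", "age"]
	}`)
	if err != nil {
		panic(err)
	}
	fmt.Println(src)
}
```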

func (c *converter) convert() (string, error) {
	var b strings.Builder

	// Generate root rule
	rootExpr := c.schemaToExpr(c.schema, "root")
	b.WriteString("root = ")
	b.WriteString(rootExpr)
	b.WriteString(" .\n")

	// Add generated rules (refs, items, etc.)
	for _, rule := range c.rules {
		b.WriteString(rule)
		b.WriteString("\n")
	}

	// Add primitives based on usage
	c.addPrimitives(&b)

	return b.String(), nil
}

func (c *converter) addPrimitives(b *strings.Builder) {
	if c.usedTypes["string"] {
		b.WriteString(`
string = "\"" { character } "\"" .
`)
	}

	if c.usedTypes["string"] || c.usedTypes["character"] {
		b.WriteString(`
character = unescaped | escaped .
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
unicode = "u" hex hex hex hex .
`)
	}

	if c.usedTypes["number"] {
		b.WriteString(`
number = [ "-" ] integer [ fraction ] [ exponent ] .
integer = "0" | onenine { digit } .
fraction = "." digit { digit } .
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
`)
	}

	if c.usedTypes["integer"] {
		b.WriteString(`
int = [ "-" ] ( "0" | onenine { digit } ) .
`)
	}

	if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
		b.WriteString(`
digit = "0" … "9" .
`)
	}

	// onenine only needed for number/integer, not for digit-only formats
	if c.usedTypes["number"] || c.usedTypes["integer"] {
		b.WriteString(`onenine = "1" … "9" .
`)
	}

	if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
		b.WriteString(`
hex = "0" … "9" | "A" … "F" | "a" … "f" .
`)
	}

	if c.usedTypes["ws"] {
		b.WriteString(`
ws = { " " | "\t" | "\n" | "\r" } .
`)
	}
}

func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
	if schema == nil {
		c.usedTypes["string"] = true
		c.usedTypes["number"] = true
		return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
	}

	// Handle $ref first
	if schema.Ref != "" {
		return c.resolveRef(schema.Ref)
	}

	// Handle const
	if schema.Const != nil {
		return c.constToExpr(schema.Const)
	}

	// Handle enum
	if len(schema.Enum) > 0 {
		return c.enumToExpr(schema.Enum)
	}

	// Handle anyOf / oneOf
	if len(schema.AnyOf) > 0 {
		return c.anyOfToExpr(schema.AnyOf, name)
	}
	if len(schema.OneOf) > 0 {
		return c.anyOfToExpr(schema.OneOf, name)
	}

	// Handle type
	types := c.getTypes(schema.Type)
	if len(types) == 0 {
		// No type specified, could be anything
		c.usedTypes["string"] = true
		c.usedTypes["number"] = true
		return "( string | number | \"true\" | \"false\" | \"null\" )"
	}

	if len(types) == 1 {
		return c.typeToExpr(types[0], schema, name)
	}

	// Multiple types (e.g., ["string", "null"])
	var parts []string
	for _, t := range types {
		parts = append(parts, c.typeToExpr(t, schema, name))
	}
	return "( " + strings.Join(parts, " | ") + " )"
}

func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
	switch typeName {
	case "object":
		return c.objectToExpr(schema, name)
	case "array":
		return c.arrayToExpr(schema, name)
	case "string":
		return c.stringToExpr(schema, name)
	case "number":
		c.usedTypes["number"] = true
		return "number"
	case "integer":
		c.usedTypes["integer"] = true
		c.usedTypes["digit"] = true
		return "int"
	case "boolean":
		return `( "true" | "false" )`
	case "null":
		return `"null"`
	default:
		c.usedTypes["string"] = true
		c.usedTypes["number"] = true
		return "string"
	}
}

func (c *converter) objectToExpr(schema *schemaNode, name string) string {
	c.usedTypes["ws"] = true

	if len(schema.Properties) == 0 {
		return `"{" ws "}"`
	}

	// Sort properties for deterministic output.
	// Required properties come first, in their required order.
	var propOrder []string
	requiredSet := make(map[string]bool)
	for _, r := range schema.Required {
		requiredSet[r] = true
		propOrder = append(propOrder, r)
	}

	// Add any non-required properties (though OpenAI requires all to be required)
	var optionalProps []string
	for propName := range schema.Properties {
		if !requiredSet[propName] {
			optionalProps = append(optionalProps, propName)
		}
	}
	sort.Strings(optionalProps)
	propOrder = append(propOrder, optionalProps...)

	var propExprs []string
	first := true

	for _, propName := range propOrder {
		propSchema, exists := schema.Properties[propName]
		if !exists {
			continue
		}

		propExpr := c.schemaToExpr(propSchema, propName)

		prefix := ""
		if !first {
			prefix = `"," ws `
		}
		first = false

		propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
	}

	if len(propExprs) == 0 {
		return `"{" ws "}"`
	}

	return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
}

func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
	c.usedTypes["ws"] = true

	itemExpr := "value"
	if schema.Items != nil {
		itemExpr = c.schemaToExpr(schema.Items, name+"_item")
	} else {
		c.usedTypes["string"] = true
		c.usedTypes["number"] = true
	}

	// Create item rule
	c.ruleNum++
	itemRule := fmt.Sprintf("item%d", c.ruleNum)
	c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))

	// Handle minItems/maxItems
	if schema.MinItems != nil || schema.MaxItems != nil {
		return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
	}

	// Default: zero or more items
	return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
}

func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
	min := 0
	max := -1 // unlimited

	if minItems != nil {
		min = *minItems
	}
	if maxItems != nil {
		max = *maxItems
	}

	if min == 0 && max < 0 {
		// No constraints
		return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
	}

	if min == 0 && max == 0 {
		return `"[" ws "]"`
	}

	// Build the pattern for a bounded array.
	// For min=2, max=4: item "," item [ "," item ] [ "," item ]
	var parts []string

	// Required items
	for i := 0; i < min; i++ {
		if i > 0 {
			parts = append(parts, `"," ws`)
		}
		parts = append(parts, itemRule)
	}

	// Optional items up to max
	if max > min {
		for i := min; i < max; i++ {
			if i == 0 {
				parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
			} else {
				parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
			}
		}
		// Close all optional brackets
		for i := min; i < max; i++ {
			parts = append(parts, "]")
		}
	} else if max < 0 {
		// Unlimited after min
		if min > 0 {
			parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
		} else {
			parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
		}
	}

	if min == 0 {
		return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
	}
	return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
}
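
For instance, with `minItems: 1` and `maxItems: 3` (and `item1` as the generated item rule), the loops above yield `"[" ws item1 [ "," ws item1 [ "," ws item1 ] ] ws "]"`: one required item followed by two nested optional tails, so between one and three items can appear.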

func (c *converter) stringToExpr(schema *schemaNode, name string) string {
	// Handle format
	if schema.Format != "" {
		return c.formatToExpr(schema.Format)
	}

	// Handle pattern (regex)
	if schema.Pattern != "" {
		return c.patternToExpr(schema.Pattern, name)
	}

	// Default string
	c.usedTypes["string"] = true
	if name == "root" {
		c.usedTypes["character"] = true
		return `"\"" { character } "\""`
	}
	return "string"
}

func (c *converter) formatToExpr(format string) string {
	switch format {
	case "date":
		// YYYY-MM-DD
		c.ruleNum++
		c.usedTypes["digit"] = true
		ruleName := fmt.Sprintf("date%d", c.ruleNum)
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
		return ruleName

	case "time":
		// HH:MM:SS
		c.ruleNum++
		c.usedTypes["digit"] = true
		ruleName := fmt.Sprintf("time%d", c.ruleNum)
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
		return ruleName

	case "date-time":
		// YYYY-MM-DDTHH:MM:SSZ or with offset
		c.ruleNum++
		c.usedTypes["digit"] = true
		ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
		return ruleName

	case "email":
		// Simplified email pattern
		c.ruleNum++
		ruleName := fmt.Sprintf("email%d", c.ruleNum)
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
		c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
		return ruleName

	case "uuid":
		// 8-4-4-4-12 hex pattern
		c.ruleNum++
		ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
		c.usedTypes["hex"] = true
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
		return ruleName

	case "ipv4":
		c.ruleNum++
		c.usedTypes["digit"] = true
		ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
		c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
		return ruleName

	case "uri", "hostname":
		// Fall back to a general string for complex formats
		c.usedTypes["string"] = true
		return "string"

	default:
		c.usedTypes["string"] = true
		return "string"
	}
}

func (c *converter) patternToExpr(pattern string, name string) string {
	// Try to convert simple regex patterns to EBNF.
	// This handles common cases; complex regex falls back to string.

	// Remove anchors
	pattern = strings.TrimPrefix(pattern, "^")
	pattern = strings.TrimSuffix(pattern, "$")

	// Try to parse and convert
	expr, ok := c.regexToEBNF(pattern)
	if !ok {
		// Fall back to a general string
		c.usedTypes["string"] = true
		return "string"
	}

	c.ruleNum++
	ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
	c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
	return ruleName
}

func (c *converter) regexToEBNF(pattern string) (string, bool) {
	// Simple regex to EBNF converter.
	// Handles: literals, [a-z], [A-Z], [0-9], alternation, and basic groups.

	var result strings.Builder
	i := 0

	for i < len(pattern) {
		ch := pattern[i]

		switch ch {
		case '[':
			// Character class
			end := strings.Index(pattern[i:], "]")
			if end == -1 {
				return "", false
			}
			class := pattern[i+1 : i+end]
			ebnfClass, ok := c.charClassToEBNF(class)
			if !ok {
				return "", false
			}
			result.WriteString(ebnfClass)
			i += end + 1

		case '(':
			// Group - find the matching )
			depth := 1
			start := i + 1
			j := start
			for j < len(pattern) && depth > 0 {
				if pattern[j] == '(' {
					depth++
				} else if pattern[j] == ')' {
					depth--
				}
				j++
			}
			if depth != 0 {
				return "", false
			}
			groupContent := pattern[start : j-1]
			groupExpr, ok := c.regexToEBNF(groupContent)
			if !ok {
				return "", false
			}
			result.WriteString("( ")
			result.WriteString(groupExpr)
			result.WriteString(" )")
			i = j

		case '|':
			result.WriteString(" | ")
			i++

		case '+':
			// One or more - would need to wrap the previous term in { }
			// plus one required occurrence
			return "", false // TODO: handle properly

		case '*':
			// Zero or more - would need to wrap the previous term in { }
			return "", false // TODO: handle properly

		case '?':
			// Optional - would need to wrap the previous term in [ ]
			return "", false // TODO: handle properly

		case '\\':
			// Escape sequence
			if i+1 >= len(pattern) {
				return "", false
			}
			next := pattern[i+1]
			switch next {
			case 'd':
				result.WriteString("digit")
				c.usedTypes["digit"] = true
			case 'w':
				result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
			case 's':
				result.WriteString(`( " " | "\t" )`)
			default:
				result.WriteString(fmt.Sprintf(`"%c"`, next))
			}
			i += 2

		default:
			// Literal character (special characters are quoted as-is)
			result.WriteString(fmt.Sprintf(`"%c" `, ch))
			i++
		}
	}

	return strings.TrimSpace(result.String()), true
}

func (c *converter) charClassToEBNF(class string) (string, bool) {
	// Handle common character classes like a-z, A-Z, 0-9
	if class == "a-zA-Z0-9_" {
		return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
	}
	if class == "a-zA-Z_" {
		return `( "a" … "z" | "A" … "Z" | "_" )`, true
	}
	if class == "a-zA-Z0-9" {
		return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
	}
	if class == "a-z" {
		return `"a" … "z"`, true
	}
	if class == "A-Z" {
		return `"A" … "Z"`, true
	}
	if class == "0-9" {
		c.usedTypes["digit"] = true
		return "digit", true
	}

	// Try to parse single range patterns
	if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
		return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
	}
	if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
		return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
	}

	return "", false
}
|
||||
|
||||
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
	var parts []string
	for i, s := range schemas {
		expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
		parts = append(parts, expr)
	}
	return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) enumToExpr(values []interface{}) string {
	var parts []string
	for _, v := range values {
		parts = append(parts, c.constToExpr(v))
	}
	return "( " + strings.Join(parts, " | ") + " )"
}
func (c *converter) constToExpr(v interface{}) string {
	switch val := v.(type) {
	case string:
		return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
	case float64:
		if val == float64(int(val)) {
			return fmt.Sprintf(`"%d"`, int(val))
		}
		return fmt.Sprintf(`"%v"`, val)
	case bool:
		if val {
			return `"true"`
		}
		return `"false"`
	case nil:
		return `"null"`
	default:
		c.usedTypes["string"] = true
		return "string"
	}
}
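Hypothetical examples of how const and enum values map to terminals (each result is a quoted JSON literal; not part of the diff):

// Hypothetical traces:
//
//   c.constToExpr("user") -> `"\"user\""`  (matches the JSON string "user")
//   c.constToExpr(3.0)    -> `"3"`         (float64 with an integral value)
//   c.constToExpr(1.5)    -> `"1.5"`
//   c.constToExpr(true)   -> `"true"`
//   c.constToExpr(nil)    -> `"null"`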
func (c *converter) resolveRef(ref string) string {
	// Handle #/$defs/name references
	if strings.HasPrefix(ref, "#/$defs/") {
		defName := strings.TrimPrefix(ref, "#/$defs/")
		return c.resolveDefRef(defName)
	}

	// Handle root recursion #
	if ref == "#" {
		return "root"
	}

	// Unknown ref format
	c.usedTypes["string"] = true
	return "string"
}
func (c *converter) resolveDefRef(defName string) string {
	// Check if we've already defined this as a rule
	ruleName := "def_" + defName
	if c.definedRefs[defName] {
		return ruleName
	}

	// Mark as defined to prevent infinite recursion
	c.definedRefs[defName] = true

	// Look up the definition
	if c.definitions == nil {
		c.usedTypes["string"] = true
		return "string"
	}

	defSchema, ok := c.definitions[defName]
	if !ok {
		c.usedTypes["string"] = true
		return "string"
	}

	// Generate the rule
	expr := c.schemaToExpr(defSchema, ruleName)
	c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))

	return ruleName
}
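A hypothetical call sequence showing the recursion guard: repeated references to one definition emit a single rule and then reuse its name (illustrative, not part of the diff):

// Hypothetical trace:
//
//   c.resolveRef("#/$defs/Address") -> "def_Address"  (rule emitted once)
//   c.resolveRef("#/$defs/Address") -> "def_Address"  (already defined; name reused)
//   c.resolveRef("#")               -> "root"         (self-recursive schema)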
func (c *converter) getTypes(t interface{}) []string {
	switch v := t.(type) {
	case string:
		return []string{v}
	case []interface{}:
		var types []string
		for _, item := range v {
			if s, ok := item.(string); ok {
				types = append(types, s)
			}
		}
		return types
	}
	return nil
}
func (c *converter) escapeString(s string) string {
	s = strings.ReplaceAll(s, `\`, `\\`)
	s = strings.ReplaceAll(s, `"`, `\"`)
	return s
}
// Grammar converts a JSON Schema string into a compiled grammar.
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
	ebnf, err := EBNF(schemaJSON)
	if err != nil {
		return nil, err
	}
	return grammar.ParseEBNF(ebnf, "root")
}
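A minimal usage sketch of the two entry points above (hypothetical main package; the schema package carries the mlx build tag, so this would be built with -tags mlx):

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/x/grammar/schema"
)

func main() {
	s := `{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`

	// EBNF returns the grammar as text, rooted at the "root" rule.
	ebnf, err := schema.EBNF(s)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(ebnf)

	// Grammar compiles the same schema for use with a grammar.Engine.
	if _, err := schema.Grammar(s); err != nil {
		log.Fatal(err)
	}
}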
336
x/grammar/schema/schema_test.go
Normal file
@@ -0,0 +1,336 @@
//go:build mlx

package schema

import (
	"testing"

	gram "github.com/ollama/ollama/x/grammar"
	"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestJSONEBNF(t *testing.T) {
	tests := []struct {
		name   string
		schema string
	}{
		{
			name: "simple object",
			schema: `{
				"type": "object",
				"properties": {
					"name": {"type": "string"},
					"age": {"type": "integer"}
				},
				"required": ["name", "age"]
			}`,
		},
		{
			name: "with enum",
			schema: `{
				"type": "object",
				"properties": {
					"status": {"enum": ["active", "inactive", "pending"]}
				},
				"required": ["status"]
			}`,
		},
		{
			name: "array of objects",
			schema: `{
				"type": "array",
				"items": {
					"type": "object",
					"properties": {
						"id": {"type": "integer"}
					},
					"required": ["id"]
				}
			}`,
		},
		{
			name: "nested object",
			schema: `{
				"type": "object",
				"properties": {
					"user": {
						"type": "object",
						"properties": {
							"email": {"type": "string"}
						},
						"required": ["email"]
					}
				},
				"required": ["user"]
			}`,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ebnf, err := EBNF(tc.schema)
			if err != nil {
				t.Fatalf("EBNF failed: %v", err)
			}

			// Try to compile it
			grammar, err := gram.ParseEBNF(ebnf, "root")
			if err != nil {
				t.Fatalf("ParseEBNF failed: %v", err)
			}

			if grammar == nil {
				t.Fatal("grammar is nil")
			}
		})
	}
}
func TestGrammarEngine(t *testing.T) {
	schema := `{
		"type": "object",
		"properties": {
			"name": {"type": "string"},
			"age": {"type": "integer"}
		},
		"required": ["name", "age"]
	}`

	grammar, err := Grammar(schema)
	if err != nil {
		t.Fatalf("Grammar failed: %v", err)
	}

	vocab := []string{
		"{", "}", "[", "]", ":", ",",
		"\"name\"", "\"age\"", "\"test\"",
		"\"", "a", "b", "c",
		"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
		" ", "\n",
		"true", "false", "null",
	}

	engine, err := gram.NewEngine(grammar, vocab)
	if err != nil {
		t.Fatalf("grammar.NewEngine failed: %v", err)
	}
	defer engine.Close()

	logits := mlx.Ones(int32(len(vocab)))
	mlx.Keep(logits)

	// Test that we can apply mask
	masked := engine.ApplyMask(logits)
	mlx.Eval(masked)
}
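Beyond the single ApplyMask call above, the intended decode loop alternates masking, sampling, and advancing the grammar. A minimal sketch, where sampleToken and nextLogits are hypothetical stand-ins for the sampler and the model step:

func decodeConstrained(engine *gram.Engine, logits *mlx.Array) {
	for {
		masked := engine.ApplyMask(logits) // suppress tokens the grammar forbids
		tok := sampleToken(masked)         // hypothetical: pick a token ID from masked logits
		engine.Accept(tok)                 // advance the grammar state
		if engine.IsComplete() {
			return // a complete document has been generated
		}
		logits = nextLogits(tok) // hypothetical: next model forward pass
	}
}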
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
func TestOpenAIStructuredOutputs(t *testing.T) {
	tests := []struct {
		name   string
		schema string
	}{
		{
			name: "anyOf union",
			schema: `{
				"type": "object",
				"properties": {
					"value": {
						"anyOf": [
							{"type": "string"},
							{"type": "integer"}
						]
					}
				},
				"required": ["value"]
			}`,
		},
		{
			name: "nullable string via type array",
			schema: `{
				"type": "object",
				"properties": {
					"name": {"type": ["string", "null"]}
				},
				"required": ["name"]
			}`,
		},
		{
			name: "$ref with $defs",
			schema: `{
				"type": "object",
				"properties": {
					"person": {"$ref": "#/$defs/Person"}
				},
				"required": ["person"],
				"$defs": {
					"Person": {
						"type": "object",
						"properties": {
							"name": {"type": "string"},
							"age": {"type": "integer"}
						},
						"required": ["name", "age"]
					}
				}
			}`,
		},
		{
			name: "const value",
			schema: `{
				"type": "object",
				"properties": {
					"type": {"const": "user"}
				},
				"required": ["type"]
			}`,
		},
		{
			name: "format date-time",
			schema: `{
				"type": "object",
				"properties": {
					"created": {"type": "string", "format": "date-time"}
				},
				"required": ["created"]
			}`,
		},
		{
			name: "format date",
			schema: `{
				"type": "object",
				"properties": {
					"birthday": {"type": "string", "format": "date"}
				},
				"required": ["birthday"]
			}`,
		},
		{
			name: "format email",
			schema: `{
				"type": "object",
				"properties": {
					"email": {"type": "string", "format": "email"}
				},
				"required": ["email"]
			}`,
		},
		{
			name: "format uuid",
			schema: `{
				"type": "object",
				"properties": {
					"id": {"type": "string", "format": "uuid"}
				},
				"required": ["id"]
			}`,
		},
		{
			name: "array with minItems maxItems",
			schema: `{
				"type": "object",
				"properties": {
					"tags": {
						"type": "array",
						"items": {"type": "string"},
						"minItems": 1,
						"maxItems": 3
					}
				},
				"required": ["tags"]
			}`,
		},
		{
			name: "deeply nested with refs",
			schema: `{
				"type": "object",
				"properties": {
					"company": {
						"type": "object",
						"properties": {
							"name": {"type": "string"},
							"employees": {
								"type": "array",
								"items": {"$ref": "#/$defs/Employee"}
							}
						},
						"required": ["name", "employees"]
					}
				},
				"required": ["company"],
				"$defs": {
					"Employee": {
						"type": "object",
						"properties": {
							"name": {"type": "string"},
							"role": {"enum": ["engineer", "manager", "intern"]}
						},
						"required": ["name", "role"]
					}
				}
			}`,
		},
		{
			name: "multiple refs same def",
			schema: `{
				"type": "object",
				"properties": {
					"from": {"$ref": "#/$defs/Address"},
					"to": {"$ref": "#/$defs/Address"}
				},
				"required": ["from", "to"],
				"$defs": {
					"Address": {
						"type": "object",
						"properties": {
							"city": {"type": "string"},
							"zip": {"type": "string"}
						},
						"required": ["city", "zip"]
					}
				}
			}`,
		},
		{
			name: "oneOf variant",
			schema: `{
				"type": "object",
				"properties": {
					"result": {
						"oneOf": [
							{
								"type": "object",
								"properties": {"success": {"type": "boolean"}},
								"required": ["success"]
							},
							{
								"type": "object",
								"properties": {"error": {"type": "string"}},
								"required": ["error"]
							}
						]
					}
				},
				"required": ["result"]
			}`,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ebnf, err := EBNF(tc.schema)
			if err != nil {
				t.Fatalf("EBNF failed: %v", err)
			}

			grammar, err := gram.ParseEBNF(ebnf, "root")
			if err != nil {
				t.Fatalf("ParseEBNF failed: %v", err)
			}

			if grammar == nil {
				t.Fatal("grammar is nil")
			}
		})
	}
}
105
x/grammar/terminal.go
Normal file
@@ -0,0 +1,105 @@
//go:build mlx

package grammar

import "unicode/utf8"

// terminalType distinguishes different kinds of grammar terminals
type terminalType int

const (
	terminalLiteral terminalType = iota // Exact string: "true", "{"
	terminalRange                       // Character range: [a-z], [0-9]
)

// terminal represents a compiled grammar terminal
type terminal struct {
	ID        int
	Type      terminalType
	Pattern   string // Original pattern from grammar
	Unescaped string // Unescaped literal (for terminalLiteral)
	LowRune   rune   // For unicode ranges: low bound
	HighRune  rune   // For unicode ranges: high bound
}

// terminalMatch represents a terminal that matched at a position
type terminalMatch struct {
	TerminalID int
	Length     int // Number of bytes consumed
}

// trieNode is a node in the literal matching trie
type trieNode struct {
	children   [256]*trieNode // Byte-indexed children
	terminalID int            // -1 if not accepting, else terminal ID
}

// terminalMatcher tests which terminals match at a position in a byte slice
type terminalMatcher struct {
	// Trie for literal matching (fast path)
	literalTrie *trieNode

	// Range terminals (single-rune matches)
	ranges []terminal

	// All terminals for enumeration
	terminals []terminal

	// Pattern-to-terminal-ID map for fast lookup (keyed by raw pattern)
	patternToID map[string]int
}
// addLiteralToTrie adds a literal pattern to the trie
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
	node := m.literalTrie
	for i := 0; i < len(pattern); i++ {
		c := pattern[i]
		if node.children[c] == nil {
			node.children[c] = &trieNode{terminalID: -1}
		}
		node = node.children[c]
	}
	node.terminalID = terminalID
}
// matchesAt returns all terminals that match at pos in data
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
	if pos >= len(data) {
		return nil
	}

	var matches []terminalMatch

	// Check literal matches via trie
	node := m.literalTrie
	for i := pos; i < len(data) && node != nil; i++ {
		c := data[i]
		node = node.children[c]
		if node != nil && node.terminalID >= 0 {
			matches = append(matches, terminalMatch{
				TerminalID: node.terminalID,
				Length:     i - pos + 1,
			})
		}
	}

	// Check range matches (unicode-aware)
	r, runeLen := utf8.DecodeRune(data[pos:])
	if r != utf8.RuneError {
		for _, rng := range m.ranges {
			if r >= rng.LowRune && r <= rng.HighRune {
				matches = append(matches, terminalMatch{
					TerminalID: rng.ID,
					Length:     runeLen,
				})
			}
		}
	}

	return matches
}

// terminalCount returns the number of terminals
func (m *terminalMatcher) terminalCount() int {
	return len(m.terminals)
}
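To make the matcher concrete, a hypothetical trace: with literals "t" (ID 0) and "true" (ID 1) in the trie plus a range terminal a-z (ID 2), matchesAt reports every terminal that matches at the position:

// Hypothetical trace, not part of the diff:
//
//   matchesAt([]byte("true,"), 0) returns:
//     {TerminalID: 0, Length: 1}  // literal "t"
//     {TerminalID: 1, Length: 4}  // literal "true" (longer trie match)
//     {TerminalID: 2, Length: 1}  // 't' falls inside the a-z range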
@@ -8,6 +8,7 @@ import (
	"time"
	"unicode/utf8"

	"github.com/ollama/ollama/x/grammar"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -109,7 +110,11 @@ type input struct {
	Temperature  float32
	TopP         float32
	TopK         int
	WiredLimitGB int      // Metal wired memory limit in GB (default 32)
	JSONMode     bool     // Enable JSON grammar constraint
	GrammarEBNF  string   // Raw EBNF grammar string
	GrammarStart string   // Start rule name for grammar
	Vocab        []string // Vocabulary for constrained decoding
}

type output struct {
@@ -127,9 +132,11 @@ type Decoder struct {
	temp          float32
	topK          int
	topP          float32
	token         *mlx.Array      // Current token (kept across pools)
	oldCacheState []*mlx.Array    // Preallocated slice for old cache state
	image         *mlx.Array      // Optional image for multimodal prefill
	grammar       *grammar.Engine // Optional grammar constraint engine
	grammarVocab  []string        // Vocab for grammar debug
}

func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
@@ -145,6 +152,12 @@ func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
	}
}

// SetGrammar enables constrained decoding with the given grammar engine.
func (d *Decoder) SetGrammar(g *grammar.Engine, vocab []string) {
	d.grammar = g
	d.grammarVocab = vocab
}

// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
	d.image = img
@@ -222,6 +235,16 @@ func (d *Decoder) prefill(inputIDs []int32) int {
		} else {
			logits = d.model.Forward(x, d.caches)
		}

		// Apply grammar constraints if enabled
		if d.grammar != nil {
			shape := logits.Shape()
			lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
			lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
			maskedLogits := d.grammar.ApplyMask(lastLogits)
			logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
		}

		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep cache state (token auto-kept by AsyncEval)
@@ -245,6 +268,15 @@ func (d *Decoder) prefill(inputIDs []int32) int {
func (d *Decoder) step() int32 {
	prevToken := d.token

	// Sync on previous token FIRST to get its value and update grammar state.
	// This must happen before computing the next mask.
	val := prevToken.ItemInt32()

	// Update grammar state with the token we just synced
	if d.grammar != nil {
		d.grammar.Accept(int(val))
	}

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
@@ -253,6 +285,18 @@

	withStream(func() {
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)

		// Apply grammar constraints if enabled
		if d.grammar != nil {
			// Get last position logits: [1, 1, vocab] -> [vocab]
			shape := logits.Shape()
			lastLogits := mlx.Slice(logits, []int32{0, shape[1] - 1, 0}, []int32{1, shape[1], d.vocabSize})
			lastLogits = mlx.Reshape(lastLogits, d.vocabSize)
			maskedLogits := d.grammar.ApplyMask(lastLogits)
			// Reshape back to [1, 1, vocab] for sample()
			logits = mlx.Reshape(maskedLogits, 1, 1, d.vocabSize)
		}

		d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
@@ -262,9 +306,6 @@
	}
	mlx.AsyncEval(d.token)

	// Sync on previous token (GPU already working on next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
@@ -289,6 +330,48 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
	tok := m.Tokenizer()
	dec := NewDecoder(m, temp, in.TopK, in.TopP)

	// Set up grammar constraint if enabled
	var grammarEngine *grammar.Engine
	var grammarVocab []string
	if (in.JSONMode || in.GrammarEBNF != "") && len(in.Vocab) > 0 {
		var compiled *grammar.Grammar
		var err error

		if in.GrammarEBNF != "" {
			// Custom EBNF grammar
			startRule := in.GrammarStart
			if startRule == "" {
				startRule = "root"
			}
			compiled, err = grammar.ParseEBNF(in.GrammarEBNF, startRule)
			if err != nil {
				return fmt.Errorf("failed to parse grammar: %w", err)
			}
			fmt.Printf("[Grammar mode: start=%s]\n", startRule)
		} else {
			// JSON object grammar (only allows objects at top level)
			compiled, err = grammar.JSONObjectGrammar()
			if err != nil {
				return fmt.Errorf("failed to create JSON grammar: %w", err)
			}
			fmt.Println("[JSON object mode enabled]")
		}

		// Pad vocab to match model's vocab size if needed
		grammarVocab = in.Vocab
		modelVocabSize := int(m.VocabSize())
		if len(grammarVocab) < modelVocabSize {
			padded := make([]string, modelVocabSize)
			copy(padded, grammarVocab)
			grammarVocab = padded
		}
		grammarEngine, err = grammar.NewEngine(compiled, grammarVocab)
		if err != nil {
			return fmt.Errorf("failed to create grammar engine: %w", err)
		}
		defer grammarEngine.Close()
	}

	// Apply chat template - use image template if we have an image
	prompt := in.Prompt
	var tokens []int32
@@ -304,6 +387,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
		tokens = tok.Encode(prompt, true)
	}

	if grammarEngine != nil {
		dec.SetGrammar(grammarEngine, grammarVocab)
	}

	prefillStart := time.Now()
	prefillTokens := dec.prefill(tokens)
	// Prefill measurement should include time to first token (like mlx-lm)
@@ -327,6 +414,11 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
	if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
		cb(output{Text: text})
	}
	// Check if grammar is complete after first token
	if dec.grammar != nil && dec.grammar.IsComplete() {
		cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
		return nil
	}

	for n := 1; n < maxTokens; n++ {
		if ctx.Err() != nil {
@@ -341,6 +433,10 @@ func generate(ctx context.Context, m Model, in input, cb func(output)) error {
		if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
			cb(output{Text: text})
		}
		// Check if grammar is complete (valid JSON document finished)
		if dec.grammar != nil && dec.grammar.IsComplete() {
			break
		}

		if n%256 == 0 {
			mlx.ClearCache()
@@ -44,6 +44,9 @@ func main() {
	topP := flag.Float64("top-p", 0.9, "Top-p sampling")
	topK := flag.Int("top-k", 40, "Top-k sampling")
	imagePath := flag.String("image", "", "Image path for multimodal models")
	jsonMode := flag.Bool("json", false, "Enable JSON grammar constraint (output will be valid JSON)")
	grammarFile := flag.String("grammar", "", "Path to EBNF grammar file for constrained decoding")
	grammarStart := flag.String("grammar-start", "root", "Start rule name for grammar (default: root)")

	// Image generation params
	width := flag.Int("width", 1024, "Image width")
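Hypothetical invocations of the resulting CLI (the binary name is assumed; the flags are the ones declared above):

# Constrain output to valid JSON:
./runner -prompt "Describe a user as JSON" -json

# Use a custom EBNF grammar with an explicit start rule:
./runner -prompt "..." -grammar ./query.ebnf -grammar-start root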
@@ -186,6 +189,20 @@ func main() {
		}
	}

	// Get vocab for constrained decoding if needed
	var vocab []string
	var grammarEBNF string
	if *jsonMode || *grammarFile != "" {
		vocab = m.Tokenizer().Vocab()
	}
	if *grammarFile != "" {
		data, err := os.ReadFile(*grammarFile)
		if err != nil {
			log.Fatalf("failed to read grammar file: %v", err)
		}
		grammarEBNF = string(data)
	}

	err = generate(context.Background(), m, input{
		Prompt: *prompt,
		Image:  image,
@@ -194,6 +211,10 @@ func main() {
		TopP:         float32(*topP),
		TopK:         *topK,
		WiredLimitGB: *wiredLimitGB,
		JSONMode:     *jsonMode,
		GrammarEBNF:  grammarEBNF,
		GrammarStart: *grammarStart,
		Vocab:        vocab,
	}, func(out output) {
		if out.Text != "" {
			fmt.Print(out.Text)
@@ -1729,6 +1729,14 @@ func init() {
	// Lock main goroutine to OS thread for CUDA context stability.
	// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
	runtime.LockOSThread()
	// Avoid Metal device init crashes on systems without Metal.
	if runtime.GOOS == "darwin" {
		if MetalIsAvailable() {
			SetDefaultDeviceGPU()
		} else {
			SetDefaultDeviceCPU()
		}
	}
	RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
	Keep(RandomState[0]) // Global state should persist
}
@@ -311,8 +311,8 @@ type Model struct {
}

func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int                  { return len(m.Layers) }
func (m *Model) VocabSize() int32                { return m.Config.VocabSize }

func (m *Model) NewCache(int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
@@ -1082,6 +1082,12 @@ func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
	return id, ok
}

// Vocab returns the vocabulary as a slice of token strings indexed by token ID.
// This is useful for constrained decoding where we need to map tokens to grammar symbols.
func (t *Tokenizer) Vocab() []string {
	return t.vocab.Values
}

// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
func LoadVocabMerges(dir string) (*Tokenizer, error) {
	vocabPath := dir + "/vocab.json"
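A short sketch of how Vocab() feeds a grammar engine, mirroring the generate() wiring earlier in this diff (ebnfText is an assumed input; m is a loaded model as in main() above; error handling uses log for brevity):

vocab := m.Tokenizer().Vocab() // token strings indexed by token ID
compiled, err := grammar.ParseEBNF(ebnfText, "root")
if err != nil {
	log.Fatal(err)
}
engine, err := grammar.NewEngine(compiled, vocab) // maps vocab tokens onto grammar terminals
if err != nil {
	log.Fatal(err)
}
defer engine.Close()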