Compare commits

..

4 Commits

Author SHA1 Message Date
ParthSareen
e18540fecc sample: wip structured outputs work 2025-03-27 11:26:49 -07:00
ParthSareen
040e65abce model: expose model vocab for structured outputs 2025-03-27 11:26:18 -07:00
Blake Mizerany
564b9b48af grammar: introduce new grammar package
This package provides a way to convert JSON schemas to equivalent EBNF.
It is intended to be a replacement to llama.cpp's schema_to_grammar.

This is still an early version and does not yet support all JSON schema
features. The to-do list includes:

- minumum/maximum constraints on integer types
- minLength/maxLength constraints on string types
- defs and refs
2025-03-25 16:17:18 -07:00
copeland3300
5e0b904e88 docs: add flags to example linux log output command (#9852) 2025-03-25 09:52:23 -07:00
14 changed files with 390 additions and 1658 deletions

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager
journalctl -u ollama --no-pager --follow --pager-end
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:

View File

@@ -32,6 +32,7 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocab() *Vocabulary
}

View File

@@ -49,14 +49,14 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Vocab() *Vocabulary {
return spm.vocab
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {

View File

@@ -468,20 +468,6 @@ func (s *Server) processBatch() error {
return fmt.Errorf("failed to sample token: %w", err)
}
if seq.sampler.JSONSampler != nil {
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
if seq.sampler.PythonSampler != nil {
err = seq.sampler.PythonSampler.UpdateState(token)
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
@@ -576,21 +562,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
// if err != nil {
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
// return
// }
// jsonSampler = nil
pythonSampler := &sample.PythonSampler{}
functions := []sample.PythonFunction{
{
Name: "add_two_strings",
Args: []string{"s1", "s2"},
Types: []string{"string", "string"},
},
}
pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
@@ -598,9 +569,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.MinP,
req.Options.Seed,
grammar,
nil,
pythonSampler,
// nil,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

View File

@@ -1,53 +0,0 @@
package sample
var DefaultGrammar = map[string]string{
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
"null": `"null"`,
"object": `"{" (kv ("," kv)*)? "}"`,
"array": `"[" (value ("," value)*)? "]"`,
"kv": `string ":" value`,
"integer": `"0" | [1-9] [0-9]*`,
"number": `"-"? integer frac? exp?`,
"frac": `"." [0-9]+`,
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
"string": `"\"" char* "\""`,
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
"char": `[^"\\] | escape`,
"space": `(" " | "\t" | "\n" | "\r")*`,
"hex": `[0-9] | [a-f] | [A-F]`,
"boolean": `"true" | "false"`,
"value": `object | array | string | number | boolean | "null"`,
}
const jsonString = `object | array`
type StateMachine struct {
states map[rune]State
}
type State struct {
NextStates []string
// bitmask?
Mask []bool
IsTerminal bool
}
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
states := make(map[rune]State)
var cumu string
flag := false
for _, r := range startRule {
if r == '"' {
flag = !flag
}
if flag {
cumu += string(r)
}
}
sm := &StateMachine{
states: states,
}
return sm
}

View File

@@ -1,138 +0,0 @@
package sample
import (
"testing"
)
func TestGrammarParsing(t *testing.T) {
tests := []struct {
name string
grammar map[string]string
startRule string
input string
want bool
}{
{
name: "simple object",
grammar: map[string]string{
"object": `"{" "}"`,
},
startRule: "object",
input: "{}",
want: true,
},
{
name: "simple array",
grammar: map[string]string{
"array": `"[" "]"`,
},
startRule: "array",
input: "[]",
want: true,
},
{
name: "character class",
grammar: map[string]string{
"digit": `[0-9]`,
},
startRule: "digit",
input: "5",
want: true,
},
{
name: "alternation",
grammar: map[string]string{
"bool": `"true" | "false"`,
},
startRule: "bool",
input: "true",
want: true,
},
{
name: "repetition",
grammar: map[string]string{
"digits": `[0-9]+`,
},
startRule: "digits",
input: "123",
want: true,
},
{
name: "nested rules",
grammar: map[string]string{
"value": `object | array`,
"object": `"{" "}"`,
"array": `"[" "]"`,
},
startRule: "value",
input: "{}",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.grammar)
machine, err := parser.Parse(tt.startRule)
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
got, err := matcher.Match(tt.input)
if err != nil {
t.Fatalf("Match() error = %v", err)
}
if got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestJSONGrammar(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{"empty object", "{}", true},
{"empty array", "[]", true},
{"simple string", `"hello"`, true},
{"simple number", "123", true},
{"simple boolean", "true", true},
{"simple null", "null", true},
{"object with string", `{"key": "value"}`, true},
{"array with numbers", "[1, 2, 3]", true},
{"nested object", `{"obj": {"key": "value"}}`, true},
{"nested array", `[1, [2, 3], 4]`, true},
{"invalid object", "{", false},
{"invalid array", "[1, 2", false},
{"invalid string", `"hello`, false},
}
parser := NewParser(DefaultGrammar)
machine, err := parser.Parse("value")
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := matcher.Match(tt.input)
if tt.want {
if err != nil {
t.Errorf("Match() error = %v", err)
}
if !got {
t.Errorf("Match() = false, want true")
}
} else {
if err == nil && got {
t.Errorf("Match() = true, want false")
}
}
})
}
}

View File

@@ -1,160 +0,0 @@
package sample
import (
"fmt"
)
type JSONState int
const (
StateStart JSONState = iota
StateInObject
StateInObjectKey
StateInStructuredKey
StateInStructuredValue
StateNewline
StateTab
StateSpace
StateInString
StateInInt
StateInFloat
StateInBool
StateInNull
StateInColon
StateInComma
StateInTab
StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace
StateInList
StateInListComma
StateInValue
StateInValueEnd
StateInListEnd
StateInListObjectEnd
StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
StateInObjectEnd
StateTransitioningToTerminate
StateInListStartJSON
)
var JSONStates = []JSONState{
StateStart,
StateInObject,
StateInObjectKey,
StateInStructuredKey,
StateInStructuredValue,
StateNewline,
StateTab,
StateSpace,
StateInString,
StateInInt,
StateInFloat,
StateInBool,
StateInNull,
StateInColon,
StateInComma,
StateInTab,
StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace,
StateInListStartJSON,
StateInList,
StateInListComma,
StateInValue,
StateInValueEnd,
StateInListEnd,
StateInListObjectEnd,
StateInNewline,
StateInNumber,
StateInNumberEnd,
StateInStringEnd,
StateInObjectKeyEnd,
StateTerminate,
StateInObjectEnd,
StateTransitioningToTerminate,
}
func (s JSONState) String() string {
switch s {
case StateStart:
return "StateStart"
case StateInObject:
return "StateInObject"
case StateInObjectKey:
return "StateInObjectKey"
case StateInStructuredKey:
return "StateInStructuredKey"
case StateInStructuredValue:
return "StateInStructuredValue"
case StateNewline:
return "StateNewline"
case StateTab:
return "StateTab"
case StateSpace:
return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt:
return "StateInInt"
case StateInFloat:
return "StateInFloat"
case StateInBool:
return "StateInBool"
case StateInNull:
return "StateInNull"
case StateInColon:
return "StateInColon"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInSpaceToValue:
return "StateInSpaceToValue"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInObjSpace:
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateInValue:
return "StateInValue"
case StateInValueEnd:
return "StateInValueEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd:
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateTransitioningToTerminate:
return "StateTransitioningToTerminate"
case StateInListStartJSON:
return "StateInListStartJSON"
default:
return fmt.Sprintf("Unknown state: %d", s)
}
}

View File

@@ -1,327 +0,0 @@
package sample
import (
"fmt"
"slices"
"github.com/ollama/ollama/model"
)
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should handle .5 style decimals
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var (
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
)
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDA struct {
State JSONState
TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDA
}
func NewPDANode(state JSONState) *PDA {
return &PDA{
State: state,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
}
type PDAGraphBuilder struct {
proc model.TextProcessor
decodedToks []string
stateToNodeMap map[JSONState]*PDA
tokenToStatesMap map[int32][]JSONState
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state)
}
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
// TODO: update naming here - and revisit values
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
// new line end value
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
// Leads to a value
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
// Values
// string node
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// early end
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node
for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
}
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma
// should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// todo: review this space transition
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
// space end value
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
}
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
// TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
func (b *PDAGraphBuilder) preComputeValidStates() error {
for _, node := range b.stateToNodeMap {
// if node.State == StateInObjectKey {
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
// fmt.Println("copying string mask to object key mask")
// }
// }
if err := b.CreateMask(node); err != nil {
return err
}
}
return nil
}
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
// TODO: make can be somewhere else too
b.tokenToStatesMap = make(map[int32][]JSONState)
for i, t := range b.decodedToks {
for _, r := range t {
if r == '"' {
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
}
}
}
return nil
}
// TODO: the mask for obj key and string should be the same?
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range b.decodedToks {
token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}

View File

@@ -1,264 +0,0 @@
package sample
import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
// TODO: safety in case of invalid json
// TODO: partial JSON matching?
// TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
// TODO: minimize number of fwd passes if there is only one match
// TODO: greedy sample initially and then backtrack if no match
type PushdownSampler struct {
PDAGraphBuilder
curNode *PDA
braceStack []rune
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now()
fmt.Println("--------------------------------")
fmt.Println("PDA sampler")
fmt.Println("--------------------------------")
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
gb := &PDAGraphBuilder{
proc: proc,
decodedToks: decodedToks,
}
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{
curNode: gb.stateToNodeMap[StateStart],
PDAGraphBuilder: *gb,
braceStack: []rune{},
stateCounter: 0,
}, nil
}
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
}
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return nil, err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
// Special handling for EOS token in terminate state
if s.curNode.State == StateTerminate {
for _, tokenID := range tokenSlice {
if s.proc.Is(tokenID, model.SpecialEOS) {
return tokenSlice, nil
}
}
}
// flag := -1
// endBraceRunes := []rune{'}', ']'}
for _, r := range mappedString {
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
// // flag = i
// break
// }
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('{') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
if r == rune(']') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('[') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
}
// if flag != -1 {
// tokenSlice = tokenSlice[:flag]
// }
// fmt.Println("flag!", flag)
for _, tokenID := range tokenSlice {
// transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return nil, fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = nextNode
fmt.Println("updated curNode state", s.curNode.State)
}
return tokenSlice, nil
}
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
maxLogit := float32(math.Inf(-1))
maxIndex := -1
// Find the maximum logit value among valid tokens
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
maxLogit = logits[tokenID]
maxIndex = int(tokenID)
}
}
if maxIndex == -1 {
return nil, fmt.Errorf("no valid tokens found in mask")
}
logits[0] = float32(maxIndex)
return logits, nil
// return maxIndex, nil
}

View File

@@ -17,14 +17,12 @@ type token struct {
}
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
JSONSampler *JSONSampler
PythonSampler *PythonSampler
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@@ -32,19 +30,6 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
return -1, errors.New("sample: no logits provided to sample")
}
var err error
if s.JSONSampler != nil {
logits, err = s.JSONSampler.Apply(logits)
if err != nil {
return -1, err
}
}
if s.PythonSampler != nil {
logits, err = s.PythonSampler.ApplyMask(logits)
if err != nil {
return -1, err
}
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -142,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -170,14 +155,12 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
JSONSampler: jsonSampler,
PythonSampler: pythonSampler,
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
}
}

176
sample/state_machine.go Normal file
View File

@@ -0,0 +1,176 @@
package sample
import (
"bytes"
"strings"
"github.com/ollama/ollama/model"
)
type Node struct {
TransitionEdges map[rune]*Node
}
type Graph struct {
proc model.TextProcessor
decodedToks []string
curNode *Node
grammar []byte
rules map[string]string
}
// baseRules is the set of rules that are used to parse the grammar
// JSON grammar from RFC 7159
var baseRules = map[string]string{
"object": "\"{\" (kv (\",\" kv)*)? \"}\"",
"array": "\"[\" (value (\",\" value)*)? \"]\"",
"string": "\"\\\"\" char* \"\\\"\"",
"number": "\"-\"? integer frac? exp?",
"kv": "string \":\" value",
"integer": "\"0\" | [1-9] [0-9]*",
"frac": "\".\" [0-9]+",
"exp": "(\"e\" | \"E\") (\"+\" | \"-\") [0-9]+",
"escape": "[\"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | unicode]",
"char": "[^\"\\\\] | escape",
"space": "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
"hex": "[0-9] | [a-f] | [A-F]",
"boolean": "\"true\" | \"false\"",
"value": "object | array | string | number | boolean | \"null\"",
"null": "\"null\"",
}
func (g *Graph) BuildGraph(node *Node) error {
vocab := g.proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := g.proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
g.decodedToks = decodedToks
g.rules = baseRules
g.rootPrefixes()
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
g.parseRule(g.rules["root"], rootNode)
return nil
}
// rootPrefixes extracts all root prefixes from the grammar
// and parses the grammar string to extract root prefixes
func (g *Graph) rootPrefixes() {
lines := bytes.Split(g.grammar, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
continue
}
parts := bytes.SplitN(line, []byte("::="), 2)
if len(parts) != 2 {
continue
}
ruleName := string(bytes.TrimSpace(parts[0]))
if strings.HasPrefix(ruleName, "root") {
g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
}
}
}
// parseRule parses a grammar rule and returns a Node
func (g *Graph) parseRule(rule string, curNode *Node) *Node {
/*
Here are the special characters in BNF grammar and their functions:
::= - Definition operator, means "is defined as"
| - Alternation, means "or"
* - Zero or more repetitions of preceding element
+ - One or more repetitions
? - Optional (zero or one occurrence)
[] - Character class, matches any single character within brackets
[^] - Negated character class, matches any character NOT listed
() - Grouping of elements
- - Range operator in character classes (e.g., [a-z])
"" - Literal string match
*/
// Split rule into tokens by whitespace
tokens := strings.Fields(rule)
if len(tokens) == 0 {
return &Node{
TransitionEdges: make(map[rune]*Node),
}
}
// Handle integer rule
if strings.Contains(rule, "[0-9]+") {
// Create node for first digit 1-9
firstDigitNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '1'; r <= '9'; r++ {
curNode.TransitionEdges[r] = firstDigitNode
}
// Create node for subsequent digits 0-9
zeroToNineNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '0'; r <= '9'; r++ {
// Loop back to same node for * operator
zeroToNineNode.TransitionEdges[r] = zeroToNineNode
}
// Connect first digit to subsequent digits
firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges
// Also handle the "0" case
if strings.Contains(rule, "\"0\"") {
zeroNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges['0'] = zeroNode
}
return curNode
}
// recursive case
// grammar options
// TODO: handle left recursion
if strings.Contains(rule, "|") {
parts := strings.Split(rule, "|")
savedNode := curNode
for _, part := range parts {
// TODO: add correct transitions
g.parseRule(part, savedNode)
}
}
for _, token := range tokens {
if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
token = strings.Trim(token, "\"")
for _, r := range token {
newNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges[r] = newNode
curNode = newNode
}
// strNode := &Node{
// TransitionEdges: make(map[rune]*Node),
// }
// TODO: length constraint
// to self
}
}
return curNode
}

View File

@@ -1,299 +1,3 @@
package sample
import (
"fmt"
"log/slog"
"runtime"
"time"
"github.com/ollama/ollama/grammar/jsonschema"
"github.com/ollama/ollama/model"
)
type JSONSampler struct {
schema *jsonschema.Schema
propIdx int
propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler
decodedToks []string
}
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
slog.Info("NewJSONSampler", "schema", schema)
if proc == nil {
return nil, fmt.Errorf("TextProcessor cannot be nil")
}
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
}
if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
// fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema,
propIdx: -1,
propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler,
}
so.schemaToGraph()
// Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start = time.Now()
runtime.ReadMemStats(&m)
before = m.Alloc
// TODO: still messed up
// TODO: recursion use case
// key masks
for _, prop := range so.schema.Properties {
node := so.propToNodeMap[prop.Name]
// propName -> node
curState := node.State
fromNode := node
so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey {
// there is only one edge
for r, toNode := range fromNode.TransitionEdges {
fmt.Println("rune", r, "edge", toNode.State)
so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r)
curState = toNode.State
fmt.Println("next state", curState)
// TODO: theres an extra gen for " right now
fromNode = toNode
}
}
if curState != StateInColon {
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
}
// so.pdaSampler.CreateMask(fromNode)
fromNode = fromNode.TransitionEdges[' ']
so.pdaSampler.CreateMask(fromNode)
curState = fromNode.State
for _, toNode := range fromNode.TransitionEdges {
fmt.Println("toNode", toNode.State)
}
}
// runtime.ReadMemStats(&m)
// after = m.Alloc
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
// fmt.Println("--------------------------------")
return so, nil
}
func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType()
switch schemaType {
case "object":
// TODO: see if we need to connect these to the JSON graph
// each prop is a key
for _, prop := range s.schema.Properties {
// name of key
name := prop.Name
keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode := keyNode
for _, r := range name {
runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
// fmt.Println("runeNode created", runeNode.State)
// fmt.Printf("runeNode created %c\n", r)
// since alloc on heap connections wil still map
prevNode.TransitionEdges[r] = runeNode
prevNode = runeNode
}
// point to end of object key node after all chars are done
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// link to value node
// Create a node for the end of the key (after the closing quote)
stringEndNode := &PDA{
State: StateInStructuredKey,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges['"'] = stringEndNode
prevNode = stringEndNode
// Add transition for colon after key
colonNode := &PDA{
State: StateInColon,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[':'] = colonNode
prevNode = colonNode
// Add transition for space after colon
spaceNode := &PDA{
State: StateInSpaceToValue,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[' '] = spaceNode
prevNode = spaceNode
value := prop.Type
switch value {
case "object":
fmt.Println("object under key: ", name)
case "array":
fmt.Println("array under key: ", name)
case "string":
fmt.Println("string under key: ", name)
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
case "number":
fmt.Println("number under key: ", name)
for _, r := range validNumberRunes {
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
}
case "boolean":
fmt.Println("boolean under key: ", name)
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
}
// points to start of the key
s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State)
}
}
// TODO: do values + recursion
}
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State {
// TODO: doesnt account for multi rune case
case StateInObjectKey:
if s.propIdx > len(s.schema.Properties)-1 {
return nil, fmt.Errorf("propIdx out of bounds")
}
// fmt.Println("in object key - structured outputs")
// TODO: this tracking should probably be coming from a stack to track nested objects
// simple case
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
// Will only happen for the last prop - can also be precomputed.
if s.propIdx == len(s.schema.Properties)-1 {
// todo: if i incremenet propidx then i know im in last value as well
switch s.pdaSampler.curNode.State {
case StateInObjectEnd:
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++
}
}
return s.pdaSampler.Apply(logits)
}
}
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil {
return nil, err
}
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return tokenSlice, nil
}
switch s.pdaSampler.curNode.State {
case StateInObjectKey:
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
// TODO: this does not work - mike
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
// if err != nil {
// return nil, err
// }
// fmt.Println("str", str)
return tokenSlice, nil
default:
return tokenSlice, nil
}
}
type StructuredOutput struct{}

View File

@@ -0,0 +1,194 @@
package sample
import (
"testing"
"github.com/ollama/ollama/model"
)
func TestBuildGraph(t *testing.T) {
tests := []struct {
name string
grammar []byte
wantErr bool
}{
{
name: "empty grammar",
grammar: []byte{},
wantErr: false,
},
{
name: "valid grammar",
grammar: []byte(`root ::= value
value ::= string | number`),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
proc: &mockProcessor{},
grammar: tt.grammar,
rules: make(map[string]string),
}
node := &Node{
TransitionEdges: make(map[rune]*Node),
}
err := g.BuildGraph(node)
if (err != nil) != tt.wantErr {
t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
if len(g.decodedToks) == 0 {
t.Error("Expected decoded tokens, got none")
}
if len(g.rules) == 0 {
t.Error("Expected rules to be populated")
}
}
})
}
}
func TestRootPrefixes(t *testing.T) {
tests := []struct {
name string
grammar []byte
expected map[string]string
}{
{
name: "empty grammar",
grammar: []byte{},
expected: map[string]string{},
},
{
name: "grammar with root prefix",
grammar: []byte(`root ::= value
root_string ::= string`),
expected: map[string]string{
"root": "value",
"root_string": "string",
},
},
{
name: "grammar with comments and empty lines",
grammar: []byte(`# comment
root ::= value
# another comment
root_number ::= number`),
expected: map[string]string{
"root": "value",
"root_number": "number",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
grammar: tt.grammar,
rules: make(map[string]string),
}
g.rootPrefixes()
for k, v := range tt.expected {
if actual, ok := g.rules[k]; !ok || actual != v {
t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
}
}
})
}
}
func TestParseRule(t *testing.T) {
tests := []struct {
name string
rule string
expected string
}{
{
name: "empty rule",
rule: "",
expected: "",
},
{
name: "simple string",
rule: "root ::= \"test_string\"",
expected: "test_string",
},
{
name: "simple string",
rule: "root ::= \"test_string\" | \"test_string2\"",
expected: "test_stringtest_string2",
},
{
name: "integer",
rule: "root ::= [0-9]+",
// TODO: this is infinite acutally
expected: "0123456789",
},
// TODO: handle left recursion
// {
// name: "left recursion",
// rule: "root ::= root \"test_string\"",
// expected: "test_string",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
rules: make(map[string]string),
}
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode := rootNode
g.parseRule(tt.rule, curNode)
sb := ""
for {
if len(curNode.TransitionEdges) == 0 {
break
}
for r, n := range curNode.TransitionEdges {
sb += string(r)
curNode = n
}
t.Logf("sb: %s", sb)
}
if sb != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, sb)
}
})
}
}
// mockProcessor implements the TextProcessor interface for testing
type mockProcessor struct{}
func (m *mockProcessor) Decode(tokens []int32) (string, error) {
return "test", nil
}
func (m *mockProcessor) Vocab() *model.Vocabulary {
return &model.Vocabulary{
Values: []string{"test1", "test2"},
}
}
func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
return []int32{0, 1}, nil
}
func (m *mockProcessor) Is(token int32, special model.Special) bool {
return false
}

View File

@@ -1,352 +0,0 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
type PythonState int
const (
PythonStateStart PythonState = iota
StateInFunction
StateInFunctionArgs
StateInFunctionArgsType
StateInFunctionEnd
PStateInString
PStateInStringEnd
PStateInNumber
PStateInList
PStateInListEnd
PStateInDict
PStateInDictEnd
PStateInTuple
PStateInTupleEnd
PStateTerminate
)
func (s PythonState) String() string {
switch s {
case PythonStateStart:
return "PythonStateStart"
case StateInFunction:
return "StateInFunction"
case StateInFunctionArgs:
return "StateInFunctionArgs"
case StateInFunctionArgsType:
return "StateInFunctionArgsType"
case StateInFunctionEnd:
return "StateInFunctionEnd"
case PStateInString:
return "PStateInString"
case PStateInStringEnd:
return "PStateInStringEnd"
case PStateInNumber:
return "PStateInNumber"
case PStateInList:
return "PStateInList"
case PStateInListEnd:
return "PStateInListEnd"
case PStateInDict:
return "PStateInDict"
case PStateInDictEnd:
return "PStateInDictEnd"
case PStateInTuple:
return "PStateInTuple"
case PStateInTupleEnd:
return "PStateInTupleEnd"
case PStateTerminate:
return "PStateTerminate"
default:
return fmt.Sprintf("PythonState(%d)", s)
}
}
var PythonStates = []PythonState{
PythonStateStart,
StateInFunction,
StateInFunctionArgs,
StateInFunctionArgsType,
StateInFunctionEnd,
PStateInString,
PStateInStringEnd,
PStateInNumber,
PStateInList,
PStateInListEnd,
PStateInDict,
PStateInDictEnd,
PStateInTuple,
PStateInTupleEnd,
PStateTerminate,
}
type Node struct {
State PythonState
TransitionEdges map[rune]*Node
MaskTokenIDToNode map[int32]*Node
}
func NewNode(state PythonState) *Node {
return &Node{
State: state,
TransitionEdges: make(map[rune]*Node),
MaskTokenIDToNode: make(map[int32]*Node),
}
}
type PythonFunction struct {
Name string
Args []string
Types []string
}
type PythonSampler struct {
stateToNodes map[PythonState]*Node
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
s.decodedToks = decodedToks
s.BuildGraph()
for _, function := range functions {
prevNode := s.stateToNodes[PythonStateStart]
for _, r := range function.Name {
nextNode := NewNode(StateInFunction)
prevNode.TransitionEdges[r] = nextNode
if err := s.CreateMask(nextNode); err != nil {
return err
}
fmt.Println("prevNode", prevNode.State)
fmt.Printf("transition edge: %q\n", r)
fmt.Println("nextNode", nextNode.State)
prevNode = nextNode
}
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
s.CreateMask(prevNode)
prevNode = s.stateToNodes[StateInFunctionArgs]
for i, arg := range function.Args {
for _, r := range arg {
nextNode := NewNode(StateInFunctionArgs)
prevNode.TransitionEdges[r] = nextNode
s.CreateMask(prevNode)
prevNode = nextNode
}
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// prevNode = s.stateToNodes[StateInFunctionArgs]
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
s.CreateMask(prevNode)
prevNode = prevNode.TransitionEdges['=']
switch function.Types[i] {
case "string":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
s.CreateMask(prevNode.TransitionEdges['"'])
case "number":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
s.CreateMask(prevNode.TransitionEdges['"'])
}
}
}
s.curNode = s.stateToNodes[PythonStateStart]
fmt.Println("curNode", s.curNode.State)
fmt.Println("transition edges", s.curNode.TransitionEdges)
if err := s.CreateMask(s.curNode); err != nil {
return err
}
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
for tokenID, node := range s.curNode.MaskTokenIDToNode {
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
}
return nil
}
func (s *PythonSampler) BuildGraph() error {
s.stateToNodes = make(map[PythonState]*Node)
for _, state := range PythonStates {
s.stateToNodes[state] = NewNode(state)
}
for _, state := range s.stateToNodes {
if err := s.CreateMask(state); err != nil {
return err
}
}
// String
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
}
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
return nil
}
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
if s.curNode.State == PStateTerminate {
logits, err := finish(s, logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
func (s *PythonSampler) UpdateState(token int32) error {
mappedString, err := s.proc.Decode([]int32{token})
if err != nil {
return err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
if s.curNode.State == PStateTerminate {
if s.proc.Is(token, model.SpecialEOS) {
return nil
}
}
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
}
if err := s.CreateMask(s.curNode); err != nil {
return err
}
return nil
}
func (s *PythonSampler) CreateMask(node *Node) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range s.decodedToks {
token := s.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
if curNode.State == StateInFunction {
// fmt.Println("cm curNode", curNode.State)
// fmt.Println("cm token", s.decodedToks[i])
}
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}