Compare commits

...

2 Commits

Author SHA1 Message Date
ParthSareen
c0aeb3531b runner: add sync between computeBatch and completion 2025-09-10 19:16:28 -07:00
ParthSareen
1e5fecbbc3 runner/parser: allow on-the-fly grammar constraining 2025-09-10 11:50:12 -07:00
4 changed files with 63 additions and 9 deletions

View File

@@ -47,12 +47,13 @@ func (s harmonyParserState) String() string {
} }
type HarmonyParser struct { type HarmonyParser struct {
state harmonyParserState state harmonyParserState
MessageStartTag string MessageStartTag string
MessageEndTag string MessageEndTag string
HeaderEndTag string HeaderEndTag string
acc strings.Builder constraintsAllowed bool
lifetimeAcc strings.Builder acc strings.Builder
lifetimeAcc strings.Builder
} }
type HarmonyEvent interface { type HarmonyEvent interface {
@@ -89,6 +90,10 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant") s.acc.WriteString("<|start|>assistant")
} }
// ConstraintsAllowed reports whether grammar constraints may be applied at
// the parser's current position. The flag starts false and is flipped to
// true by the message handler once a "final" channel header is seen.
func (s *HarmonyParser) ConstraintsAllowed() bool {
return s.constraintsAllowed
}
func Prefill(lastMessage api.Message) string { func Prefill(lastMessage api.Message) string {
if lastMessage.Role != "assistant" { if lastMessage.Role != "assistant" {
return "" return ""
@@ -341,6 +346,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
} }
case "final": case "final":
h.state = harmonyMessageState_Normal h.state = harmonyMessageState_Normal
h.HarmonyParser.constraintsAllowed = true
} }
case HarmonyEventContentEmitted: case HarmonyEventContentEmitted:
logutil.Trace("harmony event content", "content", event.Content, "state", h.state) logutil.Trace("harmony event content", "content", event.Content, "state", h.state)

View File

@@ -33,6 +33,7 @@ type MessageHandler interface {
type ParserInternals interface { type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string) AddImplicitStartOrPrefill(prefillString string)
ConstraintsAllowed() bool
} }
type ToolParser interface { type ToolParser interface {
@@ -51,6 +52,10 @@ type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
// ConstraintsAllowed always returns true for the default engine: with no
// structured message parsing in play, grammar constraints can be applied
// from the very first token.
func (defaultEngine) ConstraintsAllowed() bool {
return true
}
type defaultToolParser struct{} type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {} func (defaultToolParser) Add(token string) {}
@@ -104,6 +109,10 @@ func (p *TokenParser) repeatLimitReached(token string) bool {
return p.tokenRepeat >= p.repeatLimit return p.tokenRepeat >= p.repeatLimit
} }
// ConstraintsAllowed delegates to the underlying parser engine to decide
// whether grammar constraints may be applied yet (e.g. the Harmony engine
// only allows them once the final message content begins).
func (p *TokenParser) ConstraintsAllowed() bool {
return p.parserEngine.ConstraintsAllowed()
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level // TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall { func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain() toolName, toolContent := p.toolParser.Drain()

View File

@@ -62,6 +62,11 @@ type Sequence struct {
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
// startGate
startGate *sync.Mutex
grammarReady bool
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
@@ -164,6 +169,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
startGate := &sync.Mutex{}
return &Sequence{ return &Sequence{
ctxs: ctxs, ctxs: ctxs,
mmStore: mmStore, mmStore: mmStore,
@@ -179,6 +185,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding, embeddingOnly: params.embedding,
stop: params.stop, stop: params.stop,
numKeep: params.numKeep, numKeep: params.numKeep,
startGate: startGate,
grammarReady: false,
}, nil }, nil
} }
@@ -707,11 +715,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
// sample a token // sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs) vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
if !seq.grammarReady {
seq.startGate.Lock()
}
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return return
} }
if !seq.grammarReady {
seq.startGate.Unlock()
}
nextBatchTokens[i].Token = token nextBatchTokens[i].Token = token
@@ -782,8 +797,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
if req.Options == nil { if req.Options == nil {
opts := api.DefaultOptions() opts := api.DefaultOptions()
req.Options = &opts req.Options = &opts
@@ -816,7 +829,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.TopP, req.Options.TopP,
req.Options.MinP, req.Options.MinP,
req.Options.Seed, req.Options.Seed,
grammar, nil,
) )
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
@@ -831,6 +844,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
// this accounts for the default case and also the case where there is a prefill which moves the state of the parser to allow for constraints
if tokenParser.ConstraintsAllowed() {
seq.grammarReady = true
}
// Ensure there is a place to put the sequence, released when removed from s.seqs // Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
@@ -874,6 +893,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
if !seq.grammarReady {
seq.startGate.Lock()
}
var thinking string var thinking string
var err error var err error
content, thinking, err = tokenParser.AddContent(content) content, thinking, err = tokenParser.AddContent(content)
@@ -883,6 +905,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
// only apply the grammar once
if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
seq.sampler.SetGrammar(grammar, &s.mu)
seq.grammarReady = true
seq.startGate.Unlock()
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content, Content: content,
Thinking: thinking, Thinking: thinking,
@@ -909,6 +938,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
if !seq.grammarReady {
seq.startGate.Unlock()
}
} }
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"math" "math"
"math/rand/v2" "math/rand/v2"
"slices" "slices"
"sync"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
@@ -25,6 +26,12 @@ type Sampler struct {
grammar *GrammarSampler grammar *GrammarSampler
} }
// SetGrammar installs the grammar sampler on s while holding the
// caller-supplied mutex, serializing the write against other holders of
// that lock.
// NOTE(review): this guards only the writer. Sample does not appear to
// take the same mutex when it presumably reads s.grammar — confirm the
// reader side is synchronized (callers pass the server-wide mutex here),
// otherwise this is a data race under -race.
func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
mutex.Lock()
defer mutex.Unlock()
s.grammar = grammar
}
func (s *Sampler) Sample(logits []float32) (int32, error) { func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 { if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample") return -1, errors.New("sample: no logits provided to sample")