Mirror of https://github.com/ollama/ollama.git
Compare commits: parth/decr ... parth/enab (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c0aeb3531b | |
| | 1e5fecbbc3 | |
```diff
@@ -47,12 +47,13 @@ func (s harmonyParserState) String() string {
 }
 
 type HarmonyParser struct {
 	state           harmonyParserState
 	MessageStartTag string
 	MessageEndTag   string
 	HeaderEndTag    string
+	constraintsAllowed bool
 	acc         strings.Builder
 	lifetimeAcc strings.Builder
 }
 
 type HarmonyEvent interface {
```
```diff
@@ -89,6 +90,10 @@ func (s *HarmonyParser) AddImplicitStart() {
 	s.acc.WriteString("<|start|>assistant")
 }
 
+func (s *HarmonyParser) ConstraintsAllowed() bool {
+	return s.constraintsAllowed
+}
+
 func Prefill(lastMessage api.Message) string {
 	if lastMessage.Role != "assistant" {
 		return ""
```
```diff
@@ -341,6 +346,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
 			}
 		case "final":
 			h.state = harmonyMessageState_Normal
+			h.HarmonyParser.constraintsAllowed = true
 		}
 	case HarmonyEventContentEmitted:
 		logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
```
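The three hunks above are the harmony-parser side of the change: `HarmonyParser` gains a `constraintsAllowed` flag, exposes it through `ConstraintsAllowed()`, and the message handler flips it when it sees the `"final"` channel header, i.e. once the model has left its analysis output and constrained (grammar-guided) output can safely begin. A minimal, self-contained sketch of that gating idea, using hypothetical simplified types rather than ollama's actual parser:

```go
package main

import (
	"fmt"
	"strings"
)

// channelGate is a hypothetical, stripped-down stand-in for the harmony
// parser: it only tracks which output channel a header selects and flips
// constraintsAllowed once the "final" channel begins.
type channelGate struct {
	constraintsAllowed bool
}

// ConstraintsAllowed reports whether constrained (grammar-guided) sampling
// may be enabled yet.
func (g *channelGate) ConstraintsAllowed() bool { return g.constraintsAllowed }

// OnHeader is called with the channel name parsed from a message header,
// e.g. "analysis", "commentary", or "final".
func (g *channelGate) OnHeader(channel string) {
	if strings.TrimSpace(channel) == "final" {
		// Reasoning output is done; structured output starts here.
		g.constraintsAllowed = true
	}
}

func main() {
	g := &channelGate{}
	for _, ch := range []string{"analysis", "final"} {
		g.OnHeader(ch)
		fmt.Printf("after %q channel: constraints allowed = %v\n", ch, g.ConstraintsAllowed())
	}
}
```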
```diff
@@ -33,6 +33,7 @@ type MessageHandler interface {
 
 type ParserInternals interface {
 	AddImplicitStartOrPrefill(prefillString string)
+	ConstraintsAllowed() bool
 }
 
 type ToolParser interface {
```
```diff
@@ -51,6 +52,10 @@ type defaultEngine struct{}
 
 func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
 
+func (defaultEngine) ConstraintsAllowed() bool {
+	return true
+}
+
 type defaultToolParser struct{}
 
 func (defaultToolParser) Add(token string) {}
```
```diff
@@ -104,6 +109,10 @@ func (p *TokenParser) repeatLimitReached(token string) bool {
 	return p.tokenRepeat >= p.repeatLimit
 }
 
+func (p *TokenParser) ConstraintsAllowed() bool {
+	return p.parserEngine.ConstraintsAllowed()
+}
+
 // TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
 func (p *TokenParser) Drain() []api.ToolCall {
 	toolName, toolContent := p.toolParser.Drain()
```
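The three hunks above thread the same signal through the generic parser layer: `ParserInternals` grows a `ConstraintsAllowed()` method, the built-in `defaultEngine` answers `true` unconditionally (a non-harmony parser has no thinking phase to wait out), and `TokenParser` simply delegates to whichever engine it wraps. A sketch of that delegation with hypothetical stand-in types:

```go
package main

import "fmt"

// parserEngine is a hypothetical stand-in for the ParserInternals interface.
type parserEngine interface {
	ConstraintsAllowed() bool
}

// defaultEngine mirrors the permissive default: constraints are allowed
// immediately because there is no thinking phase to skip.
type defaultEngine struct{}

func (defaultEngine) ConstraintsAllowed() bool { return true }

// harmonyEngine mirrors the deferred case: constraints stay disabled until
// the parser reaches the final channel.
type harmonyEngine struct{ reachedFinal bool }

func (h harmonyEngine) ConstraintsAllowed() bool { return h.reachedFinal }

// tokenParser delegates the decision to its engine, like TokenParser does.
type tokenParser struct{ engine parserEngine }

func (p tokenParser) ConstraintsAllowed() bool { return p.engine.ConstraintsAllowed() }

func main() {
	fmt.Println(tokenParser{defaultEngine{}}.ConstraintsAllowed())                    // true
	fmt.Println(tokenParser{harmonyEngine{reachedFinal: false}}.ConstraintsAllowed()) // false
}
```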
```diff
@@ -62,6 +62,11 @@ type Sequence struct {
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
+	// startGate
+	startGate *sync.Mutex
+
+	grammarReady bool
+
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
```
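In the runner, each `Sequence` picks up two new fields: `startGate`, a mutex used as a hand-off gate between the batch/sampling loop and the completion handler, and `grammarReady`, which records whether the grammar has been attached to the sampler yet. An annotated sketch of just those fields (names taken from the hunk above; the enclosing type is hypothetical):

```go
package main

import "sync"

// sequenceGate is a hypothetical slice of the Sequence struct, showing only
// the two new fields and what they coordinate (the real struct has many more).
type sequenceGate struct {
	// startGate is held by whichever side currently "owns" the sampler
	// while the grammar is not yet attached: the sampling loop locks it
	// around each Sample call, and the completion handler locks it while
	// it parses streamed tokens and decides whether to attach the grammar.
	startGate *sync.Mutex

	// grammarReady flips to true once the grammar has been attached; after
	// that, neither side touches startGate again.
	grammarReady bool
}

func main() {
	_ = sequenceGate{startGate: &sync.Mutex{}}
}
```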
```diff
@@ -164,6 +169,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 
 	// TODO(jessegross): Ingest cached history for grammar
 
+	startGate := &sync.Mutex{}
 	return &Sequence{
 		ctxs:    ctxs,
 		mmStore: mmStore,
```
```diff
@@ -179,6 +185,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		startGate:     startGate,
+		grammarReady:  false,
 	}, nil
 }
 
```
```diff
@@ -707,11 +715,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		// sample a token
 		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
 		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
 
+		if !seq.grammarReady {
+			seq.startGate.Lock()
+		}
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
 			return
 		}
+		if !seq.grammarReady {
+			seq.startGate.Unlock()
+		}
+
 		nextBatchTokens[i].Token = token
 
```
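The hunks above wire the gate into the sampling loop: `NewSequence` allocates the mutex, and while `grammarReady` is false, `computeBatch` holds `startGate` around each `Sample` call; the completion handler (in the hunks further below) holds it while it runs the streamed text through the token parser and, at most once, attaches the grammar. The effect is that the grammar can only be swapped in between token samples. A self-contained sketch of that handshake under assumed, simplified types; the sketch uses an `atomic.Bool` so it is race-free on its own, whereas the diff uses a plain field on the Sequence:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var (
		gate         sync.Mutex
		grammarReady atomic.Bool
		grammar      string // stand-in for the grammar that gets attached late
		tokens       = make(chan int)
		done         = make(chan struct{})
	)

	// Sampling-loop side (computeBatch): while the grammar is not ready,
	// hold the gate around each "sample" so the grammar cannot change
	// mid-sample.
	go func() {
		defer close(tokens)
		for i := 0; i < 5; i++ {
			locked := false
			if !grammarReady.Load() {
				gate.Lock()
				locked = true
			}
			tokens <- i // sample a token (would consult the grammar here)
			if locked {
				gate.Unlock()
			}
		}
	}()

	// Response-handler side (completion): once the token parser says
	// constraints are allowed, attach the grammar exactly once and stop
	// gating from then on.
	go func() {
		defer close(done)
		for tok := range tokens {
			if !grammarReady.Load() {
				gate.Lock()
			}
			constraintsAllowed := tok >= 2 // stand-in for tokenParser.ConstraintsAllowed()
			switch {
			case constraintsAllowed && !grammarReady.Load():
				grammar = "json-grammar" // stand-in for seq.sampler.SetGrammar(grammar, &s.mu)
				grammarReady.Store(true)
				gate.Unlock()
			case !grammarReady.Load():
				gate.Unlock()
			}
			fmt.Printf("token %d, grammar attached: %q\n", tok, grammar)
		}
	}()

	<-done
}
```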
```diff
@@ -782,8 +797,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
-
 	if req.Options == nil {
 		opts := api.DefaultOptions()
 		req.Options = &opts
```
```diff
@@ -816,7 +829,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		req.Options.TopP,
 		req.Options.MinP,
 		req.Options.Seed,
-		grammar,
+		nil,
 	)
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
```
```diff
@@ -831,6 +844,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
+	// this accounts for the default case and also the case where there is a prefill which moves the state of the parser to allow for constraints
+	if tokenParser.ConstraintsAllowed() {
+		seq.grammarReady = true
+	}
+
 	// Ensure there is a place to put the sequence, released when removed from s.seqs
 	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
 		if errors.Is(err, context.Canceled) {
```
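Note the reordering in the handler above: the token parser is now created after the sequence, and `grammarReady` is seeded from the parser's initial `ConstraintsAllowed()`, which (per the in-diff comment) covers the default engine as well as a harmony prefill that already puts the parser past its thinking phase; in those cases the sampling loop skips the gate entirely. A small sketch of that seeding decision, with the prefill handling simplified to a hypothetical rule:

```go
package main

import "fmt"

// constraintsAllowed is a stand-in for tokenParser.ConstraintsAllowed() at
// creation time: true for the default engine, and true for a harmony parser
// whose prefill already placed it in the final channel.
func constraintsAllowed(parserType, prefill string) bool {
	if parserType != "harmony" {
		return true
	}
	// Hypothetical simplification: treat any prefill as one that already
	// reaches the final channel.
	return prefill != ""
}

func main() {
	for _, tc := range []struct{ parserType, prefill string }{
		{"default", ""},
		{"harmony", ""},
		{"harmony", "<|start|>assistant<|channel|>final<|message|>"},
	} {
		// Seed the readiness flag up front, as the handler does right after
		// creating the sequence and the token parser.
		grammarReady := constraintsAllowed(tc.parserType, tc.prefill)
		fmt.Printf("parser=%s prefill=%q -> grammarReady=%v\n", tc.parserType, tc.prefill, grammarReady)
	}
}
```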
```diff
@@ -874,6 +893,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
+				if !seq.grammarReady {
+					seq.startGate.Lock()
+				}
 				var thinking string
 				var err error
 				content, thinking, err = tokenParser.AddContent(content)
```
```diff
@@ -883,6 +905,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 					return
 				}
 
+				// only apply the grammar once
+				if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
+					seq.sampler.SetGrammar(grammar, &s.mu)
+					seq.grammarReady = true
+					seq.startGate.Unlock()
+				}
+
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content:  content,
 					Thinking: thinking,
```
```diff
@@ -909,6 +938,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 					return
 				}
+				if !seq.grammarReady {
+					seq.startGate.Unlock()
+				}
 			}
 		}
 	}
```
```diff
@@ -5,6 +5,7 @@ import (
 	"math"
 	"math/rand/v2"
 	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/model"
```
```diff
@@ -25,6 +26,12 @@ type Sampler struct {
 	grammar *GrammarSampler
 }
 
+func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
+	mutex.Lock()
+	defer mutex.Unlock()
+	s.grammar = grammar
+}
+
 func (s *Sampler) Sample(logits []float32) (int32, error) {
 	if len(logits) == 0 {
 		return -1, errors.New("sample: no logits provided to sample")
```
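The last two hunks give the sampler a deferred setter: `"sync"` is imported and `SetGrammar` swaps the grammar pointer under a mutex supplied by the caller (the handler passes the server's `&s.mu`), so the late attachment is serialized with whatever else that lock guards. A sketch of the guarded-swap pattern with a hypothetical grammar type rather than ollama's `GrammarSampler`:

```go
package main

import (
	"fmt"
	"sync"
)

// grammarSampler is a hypothetical stand-in for a compiled grammar.
type grammarSampler struct{ name string }

// sampler mirrors the shape of the change: the grammar starts out nil and
// can be attached later.
type sampler struct {
	grammar *grammarSampler
}

// SetGrammar swaps the grammar in under a mutex supplied by the caller,
// matching the signature added in the diff.
func (s *sampler) SetGrammar(g *grammarSampler, mu *sync.Mutex) {
	mu.Lock()
	defer mu.Unlock()
	s.grammar = g
}

func main() {
	var serverMu sync.Mutex // stand-in for the server-level mutex
	s := &sampler{}         // constructed with no grammar, as in the handler

	fmt.Println("grammar before:", s.grammar)
	s.SetGrammar(&grammarSampler{name: "json"}, &serverMu)
	fmt.Println("grammar after:", s.grammar.name)
}
```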