Compare commits

...

1 Commits

Author SHA1 Message Date
Bruce MacDonald
d5eae8248d runner: enable returning more info from runner processing
Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.
2025-06-13 16:26:57 -07:00
6 changed files with 378 additions and 230 deletions

View File

@@ -22,6 +22,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/sync/semaphore"
@@ -725,10 +726,68 @@ type CompletionResponse struct {
EvalDuration time.Duration `json:"eval_duration"`
}
// unicodeBufferHandler wraps a completion response callback to handle partial UTF-8 sequences.
// This function creates a stateful closure that is NOT safe for concurrent use.
// Each completion request should create its own handler instance.
func unicodeBufferHandler(fn func(CompletionResponse)) func(CompletionResponse) {
var pendingUTF8 string
return func(resp CompletionResponse) {
if resp.Content == "" && !resp.Done {
// No content to process, just pass through
fn(resp)
return
}
// Combine any pending UTF-8 with current content
combinedContent := pendingUTF8 + resp.Content
pendingUTF8 = ""
// Check if combined content is valid UTF-8
if utf8.ValidString(combinedContent) {
// Valid UTF-8, send it
resp.Content = combinedContent
fn(resp)
} else {
// Invalid UTF-8
if resp.Done {
// This is the final response, trim incomplete UTF-8
trimmedContent := combinedContent
for !utf8.ValidString(trimmedContent) && len(trimmedContent) > 0 {
trimmedContent = trimmedContent[:len(trimmedContent)-1]
}
resp.Content = trimmedContent
fn(resp)
} else {
// Not final response, split valid and invalid parts
validPrefix := combinedContent
for !utf8.ValidString(validPrefix) && len(validPrefix) > 0 {
validPrefix = validPrefix[:len(validPrefix)-1]
}
if len(validPrefix) > 0 {
// Send valid prefix
resp.Content = validPrefix
fn(resp)
// Buffer the remainder
pendingUTF8 = combinedContent[len(validPrefix):]
} else {
// No valid prefix, buffer everything
pendingUTF8 = combinedContent
// Don't send this response
}
}
}
}
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt)
// Wrap the callback with unicode buffer handling
unicodeFn := unicodeBufferHandler(fn)
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -854,13 +913,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
if c.Content != "" {
fn(CompletionResponse{
unicodeFn(CompletionResponse{
Content: c.Content,
})
}
if c.Done {
fn(c)
unicodeFn(c)
return nil
}
}

View File

@@ -70,3 +70,152 @@ func TestLLMServerCompletionFormat(t *testing.T) {
}, nil)
checkValid(err)
}
func TestUnicodeBufferHandler(t *testing.T) {
tests := []struct {
name string
inputResponses []CompletionResponse
expectedResponses []CompletionResponse
description string
}{
{
name: "complete_unicode",
inputResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "!", Done: true},
},
expectedResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "!", Done: true},
},
description: "All responses with valid unicode should pass through unchanged",
},
{
name: "incomplete_unicode_at_end_with_done",
inputResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: string([]byte{0xF0, 0x9F}), Done: true}, // Incomplete emoji with Done=true
},
expectedResponses: []CompletionResponse{
{Content: "Hello", Done: false},
{Content: "", Done: true}, // Content is trimmed but response is still sent with Done=true
},
description: "When Done=true, incomplete Unicode at the end should be trimmed",
},
{
name: "split_unicode_across_responses",
inputResponses: []CompletionResponse{
{Content: "Hello " + string([]byte{0xF0, 0x9F}), Done: false}, // First part of 😀
{Content: string([]byte{0x98, 0x80}) + " world!", Done: true}, // Second part of 😀 and more text
},
expectedResponses: []CompletionResponse{
{Content: "Hello ", Done: false}, // Incomplete Unicode trimmed
{Content: "😀 world!", Done: true}, // Complete emoji in second response
},
description: "Unicode split across responses should be handled correctly",
},
{
name: "incomplete_unicode_buffered",
inputResponses: []CompletionResponse{
{Content: "Test " + string([]byte{0xF0, 0x9F}), Done: false}, // Incomplete emoji
{Content: string([]byte{0x98, 0x80}), Done: false}, // Complete the emoji
{Content: " done", Done: true},
},
expectedResponses: []CompletionResponse{
{Content: "Test ", Done: false}, // First part without incomplete unicode
{Content: "😀", Done: false}, // Complete emoji
{Content: " done", Done: true},
},
description: "Incomplete unicode should be buffered and combined with next response",
},
{
name: "empty_response_with_done",
inputResponses: []CompletionResponse{
{Content: "Complete response", Done: false},
{Content: "", Done: true}, // Empty response with Done=true
},
expectedResponses: []CompletionResponse{
{Content: "Complete response", Done: false},
{Content: "", Done: true}, // Should still be sent because Done=true
},
description: "Empty final response with Done=true should still be sent",
},
{
name: "done_reason_preserved",
inputResponses: []CompletionResponse{
{Content: "Response", Done: false},
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
},
expectedResponses: []CompletionResponse{
{Content: "Response", Done: false},
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
},
description: "DoneReason should be preserved in the final response",
},
{
name: "only_incomplete_unicode_not_done",
inputResponses: []CompletionResponse{
{Content: string([]byte{0xF0, 0x9F}), Done: false}, // Only incomplete unicode
},
expectedResponses: []CompletionResponse{
// No response expected - should be buffered
},
description: "Response with only incomplete unicode should be buffered if not done",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var actualResponses []CompletionResponse
// Create a callback that collects responses
callback := func(resp CompletionResponse) {
actualResponses = append(actualResponses, resp)
}
// Create the unicode buffer handler
handler := unicodeBufferHandler(callback)
// Send all input responses through the handler
for _, resp := range tt.inputResponses {
handler(resp)
}
// Verify the number of responses
if len(actualResponses) != len(tt.expectedResponses) {
t.Fatalf("%s: got %d responses, want %d responses",
tt.description, len(actualResponses), len(tt.expectedResponses))
}
// Verify each response matches the expected one
for i, expected := range tt.expectedResponses {
if i >= len(actualResponses) {
t.Fatalf("%s: missing response at index %d", tt.description, i)
continue
}
actual := actualResponses[i]
// Verify content
if actual.Content != expected.Content {
t.Errorf("%s: response[%d].Content = %q, want %q",
tt.description, i, actual.Content, expected.Content)
}
// Verify Done flag
if actual.Done != expected.Done {
t.Errorf("%s: response[%d].Done = %v, want %v",
tt.description, i, actual.Done, expected.Done)
}
// Verify DoneReason if specified
if actual.DoneReason != expected.DoneReason {
t.Errorf("%s: response[%d].DoneReason = %v, want %v",
tt.description, i, actual.DoneReason, expected.DoneReason)
}
}
})
}
}

View File

@@ -2,6 +2,8 @@ package common
import (
"strings"
"github.com/ollama/ollama/llm"
)
func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,68 +31,41 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
var sequence string
for _, resp := range resps {
sequence += resp.Content
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
idx := strings.Index(sequence, stop)
if idx < 0 {
return resps, false
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
truncated := sequence[:idx]
if len(truncated) == 0 {
return nil, true
}
result := make([]llm.CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if len(chunk) < len(resp.Content) {
truncationHappened = true
}
result = append(result, joined[start:end])
start = end
if len(chunk) > 0 {
result = append(result, llm.CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
}
return result, tokenTruncated
}
func IncompleteUnicode(token string) bool {
incomplete := false
// check if there is incomplete UTF-8 character at the end
for i := 1; i < 5 && i <= len(token); i++ {
c := token[len(token)-i]
if (c & 0xc0) == 0x80 {
// continuation byte: 10xxxxxx
continue
}
if (c & 0xe0) == 0xc0 {
// 2-byte character: 110xxxxx ...
incomplete = i < 2
} else if (c & 0xf0) == 0xe0 {
// 3-byte character: 1110xxxx ...
incomplete = i < 3
} else if (c & 0xf8) == 0xf0 {
// 4-byte character: 11110xxx ...
incomplete = i < 4
}
// else 1-byte character or invalid byte
break
}
return incomplete
return result, truncationHappened
}

View File

@@ -1,51 +1,84 @@
package common
import (
"fmt"
"reflect"
"testing"
"github.com/ollama/ollama/llm"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
pieces []llm.CompletionResponse
stop string
expected []string
expected []llm.CompletionResponse
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
name: "Single word",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: "world"},
},
stop: "world",
expected: []llm.CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
name: "Partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wor"},
},
stop: "or",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
name: "Suffix",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
name: "Suffix partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
name: "Middle",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wo"},
},
stop: "llo w",
expected: []llm.CompletionResponse{
{Content: "He"},
},
expectedTrunc: true,
},
}
@@ -54,76 +87,23 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
}
})
}
}
func TestIncompleteUnicode(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "Basic",
input: "hi",
expected: false,
},
{
name: "Two byte",
input: "hi" + string([]byte{0xc2, 0xa3}),
expected: false,
},
{
name: "Two byte - missing last",
input: "hi" + string([]byte{0xc2}),
expected: true,
},
{
name: "Three byte",
input: "hi" + string([]byte{0xe0, 0xA0, 0x80}),
expected: false,
},
{
name: "Three byte - missing last",
input: "hi" + string([]byte{0xe0, 0xA0}),
expected: true,
},
{
name: "Three byte - missing last 2",
input: "hi" + string([]byte{0xe0}),
expected: true,
},
{
name: "Four byte",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}),
expected: false,
},
{
name: "Four byte - missing last",
input: "hi" + string([]byte{0xf0, 0x92, 0x8a}),
expected: true,
},
{
name: "Four byte - missing last 2",
input: "hi" + string([]byte{0xf0, 0x92}),
expected: true,
},
{
name: "Four byte - missing last 3",
input: "hi" + string([]byte{0xf0}),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IncompleteUnicode(tt.input)
if result != tt.expected {
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
}
})
func formatContentDiff(result, expected []llm.CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}

View File

@@ -17,7 +17,6 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/sync/semaphore"
@@ -52,13 +51,13 @@ type Sequence struct {
pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -89,6 +88,19 @@ type Sequence struct {
numPromptInputs int
}
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
if len(resp.Content) > 0 || resp.Done {
select {
case seq.responses <- resp:
// Successfully sent
return true
case <-seq.quit:
return false
}
}
return true
}
type NewSequenceParams struct {
numPredict int
stop []string
@@ -147,8 +159,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
@@ -272,36 +284,15 @@ func (s *Server) allNil() bool {
return true
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
// Send any remaining pending responses
for _, resp := range seq.pendingResponses {
seq.send(resp)
}
seq.pendingResponses = []llm.CompletionResponse{}
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
@@ -490,8 +481,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -523,13 +517,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
if common.IncompleteUnicode(sequence) {
continue
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
for _, resp := range seq.pendingResponses {
if !seq.send(resp) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
break
}
}
seq.pendingResponses = []llm.CompletionResponse{}
}
return nil
@@ -627,9 +621,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return

View File

@@ -20,7 +20,6 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"golang.org/x/image/bmp"
"golang.org/x/sync/semaphore"
@@ -56,13 +55,13 @@ type Sequence struct {
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -94,6 +93,19 @@ type Sequence struct {
numPromptInputs int
}
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
if len(resp.Content) > 0 || resp.Done {
select {
case seq.responses <- resp:
// Successfully sent
return true
case <-seq.quit:
return false
}
}
return true
}
type NewSequenceParams struct {
numPredict int
stop []string
@@ -167,8 +179,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
@@ -313,36 +325,15 @@ func (s *Server) allNil() bool {
return true
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
flushPending(seq)
// Send any remaining pending responses
for _, resp := range seq.pendingResponses {
seq.send(resp)
}
seq.pendingResponses = []llm.CompletionResponse{}
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
@@ -541,8 +532,11 @@ func (s *Server) processBatch() error {
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -574,13 +568,14 @@ func (s *Server) processBatch() error {
continue
}
if common.IncompleteUnicode(sequence) {
continue
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
// Send all pending responses directly without unicode checking
for _, resp := range seq.pendingResponses {
if !seq.send(resp) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
break
}
}
seq.pendingResponses = []llm.CompletionResponse{}
}
return nil
@@ -683,9 +678,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return