Compare commits

...

1 Commits

Author SHA1 Message Date
Jesse Gross
c61023f554 ollamarunner: Fix off by one error with numPredict
When numPredict is set, the user will receive one less token
than the requested limit. In addition, the stats will incorrectly
show the number of tokens returned as the limit. In cases where
numPredict is not set, the number of tokens is reported correctly.

This occurs because numPredict is checked when setting up the next
batch but hitting the limit will terminate the current batch as well.
Instead, is is better to check the limit as we actually predict them.
2026-02-04 17:14:24 -08:00
2 changed files with 53 additions and 8 deletions

View File

@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
// TestNumPredict verifies that when num_predict is set, the model generates
// exactly that many tokens. It uses logprobs to count the actual tokens output.
func TestNumPredict(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
t.Fatalf("failed to pull model: %v", err)
}
req := api.GenerateRequest{
Model: "qwen3:0.6b",
Prompt: "Write a long story.",
Stream: &stream,
Logprobs: true,
Options: map[string]any{
"num_predict": 10,
"temperature": 0,
"seed": 123,
},
}
logprobCount := 0
var finalResponse api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
logprobCount += len(resp.Logprobs)
if resp.Done {
finalResponse = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %v", err)
}
if logprobCount != 10 {
t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
}
}

View File

@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
seq.lastUpdatedAt = t
seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
seq.pendingResponses = append(seq.pendingResponses, piece)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, llm.DoneReasonLength)
continue
}
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {