Change the truncation algorithm to start with all messages and remove from the front until it fits, rather than adding messages one at a time from the back. This reduces tokenization calls from O(n) to O(1) in the common case where all messages fit in context.
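The sketch below illustrates the idea; truncateFront, render, and tokenize are hypothetical stand-ins for the chatPrompt internals (which also handle system messages and image tokens), and it assumes the context and errors packages plus api.Message from the file below.

// truncateFront is an illustrative sketch, not the actual chatPrompt
// implementation. It renders the whole conversation first and only drops
// messages from the front when the rendered prompt exceeds the limit.
func truncateFront(
	ctx context.Context,
	msgs []api.Message,
	limit int,
	render func([]api.Message) (string, error),
	tokenize func(context.Context, string) ([]int, error),
) (string, error) {
	for i := 0; i < len(msgs); i++ {
		// i == 0 renders every message; in the common case this fits and
		// tokenize is called exactly once.
		prompt, err := render(msgs[i:])
		if err != nil {
			return "", err
		}
		tokens, err := tokenize(ctx, prompt)
		if err != nil {
			return "", err
		}
		if len(tokens) <= limit {
			return prompt, nil
		}
		// Too long: drop the oldest remaining message and try again.
	}
	return "", errors.New("context window too small for the last message")
}

In the worst case this still tokenizes once per remaining message, which is the bound TestChatPromptTokenizeCalls asserts below.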
package server

import (
	"bytes"
	"context"
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

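// TestChatPrompt exercises chatPrompt end to end: prompt assembly from chat
// messages, front truncation under a token limit, [img-N] placeholder
// insertion, and handling of system messages.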
func TestChatPrompt(t *testing.T) {
	type expect struct {
		prompt string
		images [][]byte
		error  error
	}

	tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
	if err != nil {
		t.Fatal(err)
	}
	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}

	cases := []struct {
		name     string
		model    Model
		limit    int
		truncate bool
		msgs     []api.Message
		expect
	}{
		{
			name:     "messages",
			model:    visionModel,
			limit:    64,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
			},
		},
		{
			name:     "truncate messages",
			model:    visionModel,
			limit:    1,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "A test. And a thumping good one at that, I'd wager. ",
			},
		},
		{
			name:     "truncate messages with image",
			model:    visionModel,
			limit:    64,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
			},
			expect: expect{
				prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("something"),
				},
			},
		},
		{
			name:     "truncate messages with images",
			model:    visionModel,
			limit:    64,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
			},
			expect: expect{
				prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("somethingelse"),
				},
			},
		},
		{
			name:     "messages with images",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
			},
			expect: expect{
				prompt: "[img-0]You're a test, Harry! I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("something"),
					[]byte("somethingelse"),
				},
			},
		},
		{
			name:     "message with image tag",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
			},
			expect: expect{
				prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("something"),
					[]byte("somethingelse"),
				},
			},
		},
		{
			name:     "messages with interleaved images",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "user", Images: []api.ImageData{[]byte("something")}},
				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("something"),
					[]byte("somethingelse"),
				},
			},
		},
		{
			name:     "truncate message with interleaved images",
			model:    visionModel,
			limit:    1024,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "user", Images: []api.ImageData{[]byte("something")}},
				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
				images: [][]byte{
					[]byte("somethingelse"),
				},
			},
		},
		{
			name:     "message with system prompt",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "system", Content: "You are the Test Who Lived."},
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
			},
		},
		{
			name:     "out of order system",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "system", Content: "You are the Test Who Lived."},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
			},
		},
		{
			name:     "multiple images same prompt",
			model:    visionModel,
			limit:    2048,
			truncate: true,
			msgs: []api.Message{
				{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
			},
			expect: expect{
				prompt: "[img-0][img-1]Compare these two pictures of hotdogs ",
				images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
			},
		},
		{
			name:     "no truncate with limit exceeded",
			model:    visionModel,
			limit:    10,
			truncate: false,
			msgs: []api.Message{
				{Role: "user", Content: "You're a test, Harry!"},
				{Role: "assistant", Content: "I-I'm a what?"},
				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
			},
			expect: expect{
				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
			},
		},
	}

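	// Run every case through chatPrompt with the mock tokenizer and compare
	// the rendered prompt and returned images against the expectations.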
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			model := tt.model
			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
			think := false
			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
			if tt.error == nil && err != nil {
				t.Fatal(err)
			} else if tt.error != nil && err != tt.error {
				t.Fatalf("expected err '%q', got '%q'", tt.error, err)
			}

			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
			}

			if len(images) != len(tt.images) {
				t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
			}

			for i := range images {
				if images[i].ID != i {
					t.Errorf("expected ID %d, got %d", i, images[i].ID)
				}

				if len(model.Config.ModelFamilies) == 0 {
					if !bytes.Equal(images[i].Data, tt.images[i]) {
						t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
					}
				}
			}
		})
	}
}

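// TestChatPromptTokenizeCalls pins down the tokenization cost of the
// truncation algorithm described above: a single tokenize call when all
// messages fit, and at most one call per message when truncating.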
func TestChatPromptTokenizeCalls(t *testing.T) {
	tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
	if err != nil {
		t.Fatal(err)
	}
	model := Model{Template: tmpl}

	cases := []struct {
		name         string
		limit        int
		msgs         []api.Message
		maxTokenizes int
	}{
		{
			name:  "all messages fit",
			limit: 2048,
			msgs: []api.Message{
				{Role: "user", Content: "message 1"},
				{Role: "assistant", Content: "response 1"},
				{Role: "user", Content: "message 2"},
				{Role: "assistant", Content: "response 2"},
				{Role: "user", Content: "message 3"},
			},
			maxTokenizes: 1,
		},
		{
			name:  "truncate to last message",
			limit: 5,
			msgs: []api.Message{
				{Role: "user", Content: "message 1"},
				{Role: "assistant", Content: "response 1"},
				{Role: "user", Content: "message 2"},
				{Role: "assistant", Content: "response 2"},
				{Role: "user", Content: "message 3"},
			},
			maxTokenizes: 5,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			tokenizeCount := 0
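			// Wrap the mock tokenizer so every call chatPrompt makes is counted.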
			countingTokenize := func(ctx context.Context, s string) ([]int, error) {
				tokenizeCount++
				tokens, err := mockRunner{}.Tokenize(ctx, s)
				return tokens, err
			}

			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
			think := false
			_, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
			if err != nil {
				t.Fatal(err)
			}

			if tokenizeCount > tt.maxTokenizes {
				t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
			}
		})
	}
}