mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
1 Commits
implement-
...
parth/opt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beaa0e82f3 |
@@ -285,6 +285,7 @@ type Options struct {
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ShiftContext bool `json:"shift_context,omitempty"`
|
||||
}
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
@@ -663,6 +664,7 @@ func DefaultOptions() Options {
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
ShiftContext: true,
|
||||
|
||||
Runner: Runner{
|
||||
// options set when the model is loaded
|
||||
|
||||
@@ -700,6 +700,8 @@ const (
|
||||
DoneReasonStop DoneReason = iota
|
||||
// DoneReasonLength indicates the completion stopped due to length limits
|
||||
DoneReasonLength
|
||||
// DoneReasonContextShift indicates the completion stopped due to context shift
|
||||
DoneReasonContextShift
|
||||
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
||||
DoneReasonConnectionClosed
|
||||
)
|
||||
@@ -710,6 +712,8 @@ func (d DoneReason) String() string {
|
||||
return "length"
|
||||
case DoneReasonStop:
|
||||
return "stop"
|
||||
case DoneReasonContextShift:
|
||||
return "context_limit_reached"
|
||||
default:
|
||||
return "" // closed
|
||||
}
|
||||
|
||||
@@ -80,6 +80,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -90,11 +93,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
if len(inputs) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := len(inputs) - s.cache.numCtx
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
default:
|
||||
err := s.processBatch(tokenBatch, embedBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
slog.Error("error processing batch", "error", err)
|
||||
}
|
||||
|
||||
tokenBatch.Clear()
|
||||
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
@@ -573,11 +595,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
152
runner/llamarunner/runner_test.go
Normal file
152
runner/llamarunner/runner_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -95,11 +98,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
|
||||
break
|
||||
}
|
||||
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
@@ -629,11 +651,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
167
runner/ollamarunner/runner_test.go
Normal file
167
runner/ollamarunner/runner_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestEnableContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove with DoneReasonContextShift",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - exact limit",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 100,
|
||||
pendingInputs: 0,
|
||||
minBatch: 1,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch - matches actual implementation
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Within limits - should not remove
|
||||
if tt.shouldRemove {
|
||||
t.Errorf("should not remove sequence when within context limits")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
|
||||
if ctxLen > opts.NumCtx {
|
||||
if !opts.ShiftContext {
|
||||
return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
|
||||
}
|
||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||
break
|
||||
} else {
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -56,7 +57,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||
error: fmt.Errorf("context length of 1 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -69,10 +70,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -208,12 +200,25 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
|
||||
// For truncation tests, disable context shifting to test the truncation behavior
|
||||
if tt.name == "truncate messages" ||
|
||||
tt.name == "truncate messages with image" ||
|
||||
tt.name == "truncate messages with images" ||
|
||||
tt.name == "truncate message with interleaved images" {
|
||||
opts.ShiftContext = false
|
||||
}
|
||||
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
} else if tt.error != nil && err != nil {
|
||||
if err.Error() != tt.error.Error() {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
}
|
||||
} else if tt.error != nil && err == nil {
|
||||
t.Fatalf("expected err '%q', got nil", tt.error)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -968,3 +970,154 @@ func TestWaitForStream(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftNonStreamingResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
responses []llm.CompletionResponse
|
||||
expectedDone bool
|
||||
expectedDoneReason string
|
||||
}{
|
||||
{
|
||||
name: "context shifting disabled - should have DoneReasonLength",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonLength},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "length",
|
||||
},
|
||||
{
|
||||
name: "context shifting enabled - should have DoneReasonStop",
|
||||
enableContextShift: true,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "stop",
|
||||
},
|
||||
{
|
||||
name: "no final response with Done=true",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
},
|
||||
expectedDone: false,
|
||||
expectedDoneReason: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// The last response in the channel will naturally be the final state
|
||||
lastResponse := tt.responses[len(tt.responses)-1]
|
||||
|
||||
if lastResponse.Done != tt.expectedDone {
|
||||
t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done)
|
||||
}
|
||||
|
||||
if tt.expectedDoneReason != "" {
|
||||
if lastResponse.DoneReason.String() != tt.expectedDoneReason {
|
||||
t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleScheduleError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMessage string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "context length exceeded error",
|
||||
errorMessage: "context length of 100 tokens exceeded, context shifting is disabled",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
errorMessage: "some other error",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
err := errors.New(tt.errorMessage)
|
||||
|
||||
handleScheduleError(c, "test-model", err)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage {
|
||||
t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftOptions(t *testing.T) {
|
||||
t.Run("default options have enableContextShift=true", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
if !opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=true by default, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("can set enableContextShift to false", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
if opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=false after setting, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization omits false values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is not in the JSON when false
|
||||
if bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be omitted from JSON when false, but found it in: %s", string(data))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization includes true values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = true
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is in the JSON when true
|
||||
if !bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be in JSON when true, but not found in: %s", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user