mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 21:08:16 -05:00
Compare commits
1 Commits
parth/decr
...
parth/opt-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beaa0e82f3 |
@@ -285,6 +285,7 @@ type Options struct {
|
|||||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||||
Stop []string `json:"stop,omitempty"`
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
ShiftContext bool `json:"shift_context,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runner options which must be set when the model is loaded into memory
|
// Runner options which must be set when the model is loaded into memory
|
||||||
@@ -663,6 +664,7 @@ func DefaultOptions() Options {
|
|||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
ShiftContext: true,
|
||||||
|
|
||||||
Runner: Runner{
|
Runner: Runner{
|
||||||
// options set when the model is loaded
|
// options set when the model is loaded
|
||||||
|
|||||||
@@ -700,6 +700,8 @@ const (
|
|||||||
DoneReasonStop DoneReason = iota
|
DoneReasonStop DoneReason = iota
|
||||||
// DoneReasonLength indicates the completion stopped due to length limits
|
// DoneReasonLength indicates the completion stopped due to length limits
|
||||||
DoneReasonLength
|
DoneReasonLength
|
||||||
|
// DoneReasonContextShift indicates the completion stopped due to context shift
|
||||||
|
DoneReasonContextShift
|
||||||
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
||||||
DoneReasonConnectionClosed
|
DoneReasonConnectionClosed
|
||||||
)
|
)
|
||||||
@@ -710,6 +712,8 @@ func (d DoneReason) String() string {
|
|||||||
return "length"
|
return "length"
|
||||||
case DoneReasonStop:
|
case DoneReasonStop:
|
||||||
return "stop"
|
return "stop"
|
||||||
|
case DoneReasonContextShift:
|
||||||
|
return "context_limit_reached"
|
||||||
default:
|
default:
|
||||||
return "" // closed
|
return "" // closed
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,9 @@ type Sequence struct {
|
|||||||
// true if an embedding are to be returned instead of text generation
|
// true if an embedding are to be returned instead of text generation
|
||||||
embeddingOnly bool
|
embeddingOnly bool
|
||||||
|
|
||||||
|
// true if context shifting should be enabled
|
||||||
|
shiftContext bool
|
||||||
|
|
||||||
doneReason llm.DoneReason
|
doneReason llm.DoneReason
|
||||||
|
|
||||||
// Metrics
|
// Metrics
|
||||||
@@ -95,6 +98,7 @@ type NewSequenceParams struct {
|
|||||||
numKeep int
|
numKeep int
|
||||||
samplingParams *llama.SamplingParams
|
samplingParams *llama.SamplingParams
|
||||||
embedding bool
|
embedding bool
|
||||||
|
enableContextShift bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// Ensure that at least 1 input can be discarded during shift
|
// Ensure that at least 1 input can be discarded during shift
|
||||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||||
|
|
||||||
if len(inputs) > s.cache.numCtx {
|
if len(inputs) > s.cache.numCtx && params.enableContextShift {
|
||||||
discard := len(inputs) - s.cache.numCtx
|
discard := len(inputs) - s.cache.numCtx
|
||||||
newInputs := inputs[:params.numKeep]
|
newInputs := inputs[:params.numKeep]
|
||||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||||
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
embeddingOnly: params.embedding,
|
embeddingOnly: params.embedding,
|
||||||
stop: params.stop,
|
stop: params.stop,
|
||||||
numKeep: params.numKeep,
|
numKeep: params.numKeep,
|
||||||
|
shiftContext: params.enableContextShift,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
|
|||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
if seq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the sequence as being removed to prevent further processing
|
||||||
|
s.seqs[seqIndex] = nil
|
||||||
|
|
||||||
|
if seq.cache != nil {
|
||||||
|
seq.cache.InUse = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(seq.pendingResponses) > 0 {
|
||||||
flushPending(seq)
|
flushPending(seq)
|
||||||
|
}
|
||||||
|
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
|
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
seq.cache.InUse = false
|
|
||||||
s.seqs[seqIndex] = nil
|
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
default:
|
default:
|
||||||
err := s.processBatch(tokenBatch, embedBatch)
|
err := s.processBatch(tokenBatch, embedBatch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
slog.Error("error processing batch", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenBatch.Clear()
|
tokenBatch.Clear()
|
||||||
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
|
|
||||||
for i, input := range seq.inputs {
|
for i, input := range seq.inputs {
|
||||||
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||||
|
if !seq.shiftContext {
|
||||||
|
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||||
|
continue
|
||||||
|
}
|
||||||
if len(seq.pendingInputs) == 0 {
|
if len(seq.pendingInputs) == 0 {
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -578,6 +600,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
numKeep: req.Options.NumKeep,
|
numKeep: req.Options.NumKeep,
|
||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
|
enableContextShift: req.Options.ShiftContext,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
|||||||
152
runner/llamarunner/runner_test.go
Normal file
152
runner/llamarunner/runner_test.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package llamarunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextShiftLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enableContextShift bool
|
||||||
|
contextLength int32
|
||||||
|
cacheInputs int
|
||||||
|
pendingInputs int
|
||||||
|
minBatch int
|
||||||
|
expectedDoneReason llm.DoneReason
|
||||||
|
shouldRemove bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "context shifting enabled - should shift",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - should remove",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonContextShift,
|
||||||
|
shouldRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - within limits",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 50,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pending inputs - should break batch",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 50,
|
||||||
|
pendingInputs: 20,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no pending inputs - should shift",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 0,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 150, // Simulates a long prompt
|
||||||
|
expectedDoneReason: llm.DoneReasonContextShift,
|
||||||
|
shouldRemove: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the core logic from processBatch
|
||||||
|
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||||
|
if tt.pendingInputs != 0 {
|
||||||
|
// Should break batch
|
||||||
|
if tt.shouldRemove {
|
||||||
|
t.Error("should not remove sequence when pending inputs exist")
|
||||||
|
}
|
||||||
|
} else if !tt.enableContextShift {
|
||||||
|
// Should remove with DoneReasonContextShift
|
||||||
|
if !tt.shouldRemove {
|
||||||
|
t.Error("should remove sequence when context shifting disabled")
|
||||||
|
}
|
||||||
|
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||||
|
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Should shift context
|
||||||
|
if tt.shouldRemove {
|
||||||
|
t.Error("should not remove sequence when context shifting enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPredictLimitLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
numPredict int
|
||||||
|
numPredicted int
|
||||||
|
expectRemove bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "predict limit not reached",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 3,
|
||||||
|
expectRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "predict limit reached",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 5,
|
||||||
|
expectRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "predict limit exceeded",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 6,
|
||||||
|
expectRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no predict limit",
|
||||||
|
numPredict: 0,
|
||||||
|
numPredicted: 100,
|
||||||
|
expectRemove: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the core logic from processBatch
|
||||||
|
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||||
|
if shouldRemove != tt.expectRemove {
|
||||||
|
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -85,6 +85,9 @@ type Sequence struct {
|
|||||||
// true if an embedding are to be returned instead of text generation
|
// true if an embedding are to be returned instead of text generation
|
||||||
embeddingOnly bool
|
embeddingOnly bool
|
||||||
|
|
||||||
|
// true if context shifting should be enabled
|
||||||
|
shiftContext bool
|
||||||
|
|
||||||
doneReason llm.DoneReason
|
doneReason llm.DoneReason
|
||||||
|
|
||||||
// Metrics
|
// Metrics
|
||||||
@@ -100,6 +103,7 @@ type NewSequenceParams struct {
|
|||||||
numKeep int32
|
numKeep int32
|
||||||
sampler sample.Sampler
|
sampler sample.Sampler
|
||||||
embedding bool
|
embedding bool
|
||||||
|
enableContextShift bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// Ensure that at least 1 input can be discarded during shift
|
// Ensure that at least 1 input can be discarded during shift
|
||||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||||
|
|
||||||
if int32(len(inputs)) > s.cache.numCtx {
|
if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
|
||||||
discard := int32(len(inputs)) - s.cache.numCtx
|
discard := int32(len(inputs)) - s.cache.numCtx
|
||||||
promptStart := params.numKeep + discard
|
promptStart := params.numKeep + discard
|
||||||
|
|
||||||
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
embeddingOnly: params.embedding,
|
embeddingOnly: params.embedding,
|
||||||
stop: params.stop,
|
stop: params.stop,
|
||||||
numKeep: params.numKeep,
|
numKeep: params.numKeep,
|
||||||
|
shiftContext: params.enableContextShift,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
|
|||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
if seq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the sequence as being removed to prevent further processing
|
||||||
|
s.seqs[seqIndex] = nil
|
||||||
|
|
||||||
|
if seq.cache != nil {
|
||||||
|
seq.cache.InUse = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(seq.pendingResponses) > 0 {
|
||||||
flushPending(seq)
|
flushPending(seq)
|
||||||
|
}
|
||||||
|
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
|
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
seq.cache.InUse = false
|
|
||||||
s.seqs[seqIndex] = nil
|
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !seq.shiftContext {
|
||||||
|
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var reprocess *ErrReprocessInputs
|
var reprocess *ErrReprocessInputs
|
||||||
@@ -634,6 +656,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
numKeep: int32(req.Options.NumKeep),
|
numKeep: int32(req.Options.NumKeep),
|
||||||
sampler: sampler,
|
sampler: sampler,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
|
enableContextShift: req.Options.ShiftContext,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
|||||||
167
runner/ollamarunner/runner_test.go
Normal file
167
runner/ollamarunner/runner_test.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package ollamarunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnableContextShiftLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enableContextShift bool
|
||||||
|
contextLength int32
|
||||||
|
cacheInputs int
|
||||||
|
pendingInputs int
|
||||||
|
minBatch int
|
||||||
|
expectedDoneReason llm.DoneReason
|
||||||
|
shouldRemove bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "context shifting enabled - should shift",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - should remove with DoneReasonContextShift",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonContextShift,
|
||||||
|
shouldRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - within limits",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 50,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - exact limit",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 100,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 1,
|
||||||
|
expectedDoneReason: llm.DoneReasonContextShift,
|
||||||
|
shouldRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pending inputs - should break batch",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 50,
|
||||||
|
pendingInputs: 20,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no pending inputs - should shift",
|
||||||
|
enableContextShift: true,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 80,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 30,
|
||||||
|
expectedDoneReason: llm.DoneReasonStop,
|
||||||
|
shouldRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||||
|
enableContextShift: false,
|
||||||
|
contextLength: 100,
|
||||||
|
cacheInputs: 0,
|
||||||
|
pendingInputs: 0,
|
||||||
|
minBatch: 150, // Simulates a long prompt
|
||||||
|
expectedDoneReason: llm.DoneReasonContextShift,
|
||||||
|
shouldRemove: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the core logic from processBatch - matches actual implementation
|
||||||
|
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||||
|
if tt.pendingInputs != 0 {
|
||||||
|
// Should break batch - don't remove sequence
|
||||||
|
if tt.shouldRemove {
|
||||||
|
t.Error("should not remove sequence when pending inputs exist")
|
||||||
|
}
|
||||||
|
} else if !tt.enableContextShift {
|
||||||
|
// Should remove with DoneReasonContextShift
|
||||||
|
if !tt.shouldRemove {
|
||||||
|
t.Error("should remove sequence when context shifting disabled")
|
||||||
|
}
|
||||||
|
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||||
|
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Should shift context - don't remove sequence
|
||||||
|
if tt.shouldRemove {
|
||||||
|
t.Error("should not remove sequence when context shifting enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Within limits - should not remove
|
||||||
|
if tt.shouldRemove {
|
||||||
|
t.Errorf("should not remove sequence when within context limits")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPredictLimitLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
numPredict int
|
||||||
|
numPredicted int
|
||||||
|
expectRemove bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "predict limit not reached",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 3,
|
||||||
|
expectRemove: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "predict limit reached",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 5,
|
||||||
|
expectRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "predict limit exceeded",
|
||||||
|
numPredict: 5,
|
||||||
|
numPredicted: 6,
|
||||||
|
expectRemove: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no predict limit",
|
||||||
|
numPredict: 0,
|
||||||
|
numPredicted: 100,
|
||||||
|
expectRemove: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Test the core logic from processBatch
|
||||||
|
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||||
|
if shouldRemove != tt.expectRemove {
|
||||||
|
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ctxLen > opts.NumCtx {
|
if ctxLen > opts.NumCtx {
|
||||||
|
if !opts.ShiftContext {
|
||||||
|
return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
|
||||||
|
}
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -56,7 +57,7 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
error: fmt.Errorf("context length of 1 tokens exceeded, context shifting is disabled"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -69,10 +70,7 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||||
images: [][]byte{
|
|
||||||
[]byte("something"),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||||
images: [][]byte{
|
|
||||||
[]byte("somethingelse"),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
},
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"),
|
||||||
images: [][]byte{
|
|
||||||
[]byte("somethingelse"),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -208,13 +200,26 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
model := tt.model
|
model := tt.model
|
||||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
|
|
||||||
|
// For truncation tests, disable context shifting to test the truncation behavior
|
||||||
|
if tt.name == "truncate messages" ||
|
||||||
|
tt.name == "truncate messages with image" ||
|
||||||
|
tt.name == "truncate messages with images" ||
|
||||||
|
tt.name == "truncate message with interleaved images" {
|
||||||
|
opts.ShiftContext = false
|
||||||
|
}
|
||||||
|
|
||||||
think := false
|
think := false
|
||||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||||
if tt.error == nil && err != nil {
|
if tt.error == nil && err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if tt.error != nil && err != tt.error {
|
} else if tt.error != nil && err != nil {
|
||||||
|
if err.Error() != tt.error.Error() {
|
||||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||||
}
|
}
|
||||||
|
} else if tt.error != nil && err == nil {
|
||||||
|
t.Fatalf("expected err '%q', got nil", tt.error)
|
||||||
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@@ -25,6 +26,7 @@ import (
|
|||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -968,3 +970,154 @@ func TestWaitForStream(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnableContextShiftNonStreamingResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enableContextShift bool
|
||||||
|
responses []llm.CompletionResponse
|
||||||
|
expectedDone bool
|
||||||
|
expectedDoneReason string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "context shifting disabled - should have DoneReasonLength",
|
||||||
|
enableContextShift: false,
|
||||||
|
responses: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: " world", Done: false},
|
||||||
|
{Content: "", Done: true, DoneReason: llm.DoneReasonLength},
|
||||||
|
},
|
||||||
|
expectedDone: true,
|
||||||
|
expectedDoneReason: "length",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "context shifting enabled - should have DoneReasonStop",
|
||||||
|
enableContextShift: true,
|
||||||
|
responses: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: " world", Done: false},
|
||||||
|
{Content: "", Done: true, DoneReason: llm.DoneReasonStop},
|
||||||
|
},
|
||||||
|
expectedDone: true,
|
||||||
|
expectedDoneReason: "stop",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no final response with Done=true",
|
||||||
|
enableContextShift: false,
|
||||||
|
responses: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: " world", Done: false},
|
||||||
|
},
|
||||||
|
expectedDone: false,
|
||||||
|
expectedDoneReason: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// The last response in the channel will naturally be the final state
|
||||||
|
lastResponse := tt.responses[len(tt.responses)-1]
|
||||||
|
|
||||||
|
if lastResponse.Done != tt.expectedDone {
|
||||||
|
t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedDoneReason != "" {
|
||||||
|
if lastResponse.DoneReason.String() != tt.expectedDoneReason {
|
||||||
|
t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleScheduleError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
errorMessage string
|
||||||
|
expectedStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "context length exceeded error",
|
||||||
|
errorMessage: "context length of 100 tokens exceeded, context shifting is disabled",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other error",
|
||||||
|
errorMessage: "some other error",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
err := errors.New(tt.errorMessage)
|
||||||
|
|
||||||
|
handleScheduleError(c, "test-model", err)
|
||||||
|
|
||||||
|
if w.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response map[string]any
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage {
|
||||||
|
t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnableContextShiftOptions(t *testing.T) {
|
||||||
|
t.Run("default options have enableContextShift=true", func(t *testing.T) {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
if !opts.ShiftContext {
|
||||||
|
t.Errorf("expected EnableContextShift=true by default, got %v", opts.ShiftContext)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can set enableContextShift to false", func(t *testing.T) {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
opts.ShiftContext = false
|
||||||
|
if opts.ShiftContext {
|
||||||
|
t.Errorf("expected EnableContextShift=false after setting, got %v", opts.ShiftContext)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JSON serialization omits false values", func(t *testing.T) {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
opts.ShiftContext = false
|
||||||
|
|
||||||
|
data, err := json.Marshal(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal options: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that enable_context_shift is not in the JSON when false
|
||||||
|
if bytes.Contains(data, []byte("enable_context_shift")) {
|
||||||
|
t.Errorf("expected enable_context_shift to be omitted from JSON when false, but found it in: %s", string(data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("JSON serialization includes true values", func(t *testing.T) {
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
opts.ShiftContext = true
|
||||||
|
|
||||||
|
data, err := json.Marshal(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal options: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that enable_context_shift is in the JSON when true
|
||||||
|
if !bytes.Contains(data, []byte("enable_context_shift")) {
|
||||||
|
t.Errorf("expected enable_context_shift to be in JSON when true, but not found in: %s", string(data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user