Compare commits


1 Commit

Author SHA1 Message Date
ParthSareen
beaa0e82f3 api: add flag to disable context shifting 2025-06-18 17:58:48 -07:00
9 changed files with 576 additions and 44 deletions

View File

@@ -285,6 +285,7 @@ type Options struct {
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Stop []string `json:"stop,omitempty"`
ShiftContext bool `json:"shift_context,omitempty"`
}
// Runner options which must be set when the model is loaded into memory
@@ -663,6 +664,7 @@ func DefaultOptions() Options {
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Seed: -1,
ShiftContext: true,
Runner: Runner{
// options set when the model is loaded
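
As a usage sketch (not part of this change): a client could disable the new behavior per request through the options map, whose key mirrors the `shift_context` JSON tag added above. The model name below is a placeholder.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.GenerateRequest{
		Model:  "llama3.2", // placeholder model name
		Prompt: "Summarize the following document ...",
		Options: map[string]any{
			// Mirrors the `shift_context` JSON tag; defaults to true.
			"shift_context": false,
		},
	}
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}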

View File

@@ -700,6 +700,8 @@ const (
DoneReasonStop DoneReason = iota
// DoneReasonLength indicates the completion stopped due to length limits
DoneReasonLength
// DoneReasonContextShift indicates the completion stopped because the context limit was reached while context shifting was disabled
DoneReasonContextShift
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
DoneReasonConnectionClosed
)
@@ -710,6 +712,8 @@ func (d DoneReason) String() string {
return "length"
case DoneReasonStop:
return "stop"
case DoneReasonContextShift:
return "context_limit_reached"
default:
return "" // closed
}
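
A sketch of how a consumer might branch on the new done reason; the `llm.CompletionResponse` field names follow the tests later in this compare view.

package main

import "github.com/ollama/ollama/llm"

// describeStop maps a final response's done reason to a human-readable
// explanation, including the new DoneReasonContextShift. Illustrative only.
func describeStop(resp llm.CompletionResponse) string {
	if !resp.Done {
		return "still generating"
	}
	switch resp.DoneReason {
	case llm.DoneReasonStop:
		return "hit a stop sequence"
	case llm.DoneReasonLength:
		return "hit the predict limit"
	case llm.DoneReasonContextShift:
		// Serialized as "context_limit_reached": the context filled up
		// while shifting was disabled.
		return "context limit reached with shifting disabled"
	default:
		return "connection closed"
	}
}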

View File

@@ -80,6 +80,9 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
// true if context shifting should be enabled
shiftContext bool
doneReason llm.DoneReason
// Metrics
@@ -90,11 +93,12 @@ type Sequence struct {
}
type NewSequenceParams struct {
numPredict int
stop []string
numKeep int
samplingParams *llama.SamplingParams
embedding bool
numPredict int
stop []string
numKeep int
samplingParams *llama.SamplingParams
embedding bool
enableContextShift bool
}
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
if len(inputs) > s.cache.numCtx && params.enableContextShift {
discard := len(inputs) - s.cache.numCtx
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shiftContext: params.enableContextShift,
}, nil
}
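
The truncation arithmetic in the hunk above, restated as a standalone sketch (operating on plain ints rather than cached inputs):

package main

import "fmt"

// truncate keeps the first numKeep inputs, drops the overflow, and
// appends the tail, mirroring the discard computation above.
func truncate(inputs []int, numCtx, numKeep int) []int {
	if len(inputs) <= numCtx {
		return inputs
	}
	discard := len(inputs) - numCtx
	out := append([]int{}, inputs[:numKeep]...)
	return append(out, inputs[numKeep+discard:]...)
}

func main() {
	in := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	// numCtx=8, numKeep=2: discard = 11-8 = 3, result is in[0:2] + in[5:].
	fmt.Println(truncate(in, 8, 2)) // [0 1 5 6 7 8 9 10]
}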
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
if seq == nil {
return
}
// Mark the sequence as being removed to prevent further processing
s.seqs[seqIndex] = nil
if seq.cache != nil {
seq.cache.InUse = false
}
if len(seq.pendingResponses) > 0 {
flushPending(seq)
}
flushPending(seq)
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
default:
err := s.processBatch(tokenBatch, embedBatch)
if err != nil {
panic(err)
slog.Error("error processing batch", "error", err)
}
tokenBatch.Clear()
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
if !seq.shiftContext {
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
continue
}
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
@@ -573,11 +595,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
enableContextShift: req.Options.ShiftContext,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
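
Condensed, the branch this file adds to processBatch looks like the following sketch; the tests in the next file exercise the same outcomes.

package main

import "fmt"

// decide condenses the per-input branch added to processBatch above:
// given a full context window, either terminate the sequence, shift the
// cache slot, or break the batch to let pending inputs drain.
func decide(cacheInputs, pendingInputs, numCtx int, shiftContext bool) string {
	if cacheInputs+pendingInputs+1 <= numCtx {
		return "append input to batch"
	}
	if !shiftContext {
		return "remove sequence with DoneReasonContextShift"
	}
	if pendingInputs == 0 {
		return "ShiftCacheSlot, then continue"
	}
	return "break batch; let pending inputs drain first"
}

func main() {
	fmt.Println(decide(80, 0, 100, true))   // append input to batch
	fmt.Println(decide(100, 0, 100, false)) // remove sequence with DoneReasonContextShift
}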

View File

@@ -0,0 +1,152 @@
package llamarunner
import (
"testing"
"github.com/ollama/ollama/llm"
)
func TestContextShiftLogic(t *testing.T) {
tests := []struct {
name string
enableContextShift bool
contextLength int32
cacheInputs int
pendingInputs int
minBatch int
expectedDoneReason llm.DoneReason
shouldRemove bool
}{
{
name: "context shifting enabled - should shift",
enableContextShift: true,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "context shifting disabled - should remove",
enableContextShift: false,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonContextShift,
shouldRemove: true,
},
{
name: "context shifting disabled - within limits",
enableContextShift: false,
contextLength: 100,
cacheInputs: 50,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "pending inputs - should break batch",
enableContextShift: true,
contextLength: 100,
cacheInputs: 50,
pendingInputs: 20,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "no pending inputs - should shift",
enableContextShift: true,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "long prompt with context shifting disabled - will be handled at runtime",
enableContextShift: false,
contextLength: 100,
cacheInputs: 0,
pendingInputs: 0,
minBatch: 150, // Simulates a long prompt
expectedDoneReason: llm.DoneReasonContextShift,
shouldRemove: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the core logic from processBatch
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
if tt.pendingInputs != 0 {
// Should break batch
if tt.shouldRemove {
t.Error("should not remove sequence when pending inputs exist")
}
} else if !tt.enableContextShift {
// Should remove with DoneReasonContextShift
if !tt.shouldRemove {
t.Error("should remove sequence when context shifting disabled")
}
if tt.expectedDoneReason != llm.DoneReasonContextShift {
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
}
} else {
// Should shift context
if tt.shouldRemove {
t.Error("should not remove sequence when context shifting enabled")
}
}
}
})
}
}
func TestPredictLimitLogic(t *testing.T) {
tests := []struct {
name string
numPredict int
numPredicted int
expectRemove bool
}{
{
name: "predict limit not reached",
numPredict: 5,
numPredicted: 3,
expectRemove: false,
},
{
name: "predict limit reached",
numPredict: 5,
numPredicted: 5,
expectRemove: true,
},
{
name: "predict limit exceeded",
numPredict: 5,
numPredicted: 6,
expectRemove: true,
},
{
name: "no predict limit",
numPredict: 0,
numPredicted: 100,
expectRemove: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the core logic from processBatch
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
if shouldRemove != tt.expectRemove {
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
}
})
}
}

View File

@@ -85,6 +85,9 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
// true if context shifting should be enabled
shiftContext bool
doneReason llm.DoneReason
// Metrics
@@ -95,11 +98,12 @@ type Sequence struct {
}
type NewSequenceParams struct {
numPredict int
stop []string
numKeep int32
sampler sample.Sampler
embedding bool
numPredict int
stop []string
numKeep int32
sampler sample.Sampler
embedding bool
enableContextShift bool
}
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if int32(len(inputs)) > s.cache.numCtx {
if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
discard := int32(len(inputs)) - s.cache.numCtx
promptStart := params.numKeep + discard
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shiftContext: params.enableContextShift,
}, nil
}
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
if seq == nil {
return
}
// Mark the sequence as being removed to prevent further processing
s.seqs[seqIndex] = nil
if seq.cache != nil {
seq.cache.InUse = false
}
if len(seq.pendingResponses) > 0 {
flushPending(seq)
}
flushPending(seq)
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
break
}
if !seq.shiftContext {
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
continue
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -629,11 +651,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
enableContextShift: req.Options.ShiftContext,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)

View File

@@ -0,0 +1,167 @@
package ollamarunner
import (
"testing"
"github.com/ollama/ollama/llm"
)
func TestEnableContextShiftLogic(t *testing.T) {
tests := []struct {
name string
enableContextShift bool
contextLength int32
cacheInputs int
pendingInputs int
minBatch int
expectedDoneReason llm.DoneReason
shouldRemove bool
}{
{
name: "context shifting enabled - should shift",
enableContextShift: true,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "context shifting disabled - should remove with DoneReasonContextShift",
enableContextShift: false,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonContextShift,
shouldRemove: true,
},
{
name: "context shifting disabled - within limits",
enableContextShift: false,
contextLength: 100,
cacheInputs: 50,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "context shifting disabled - exact limit",
enableContextShift: false,
contextLength: 100,
cacheInputs: 100,
pendingInputs: 0,
minBatch: 1,
expectedDoneReason: llm.DoneReasonContextShift,
shouldRemove: true,
},
{
name: "pending inputs - should break batch",
enableContextShift: true,
contextLength: 100,
cacheInputs: 50,
pendingInputs: 20,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "no pending inputs - should shift",
enableContextShift: true,
contextLength: 100,
cacheInputs: 80,
pendingInputs: 0,
minBatch: 30,
expectedDoneReason: llm.DoneReasonStop,
shouldRemove: false,
},
{
name: "long prompt with context shifting disabled - will be handled at runtime",
enableContextShift: false,
contextLength: 100,
cacheInputs: 0,
pendingInputs: 0,
minBatch: 150, // Simulates a long prompt
expectedDoneReason: llm.DoneReasonContextShift,
shouldRemove: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the core logic from processBatch - matches actual implementation
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
if tt.pendingInputs != 0 {
// Should break batch - don't remove sequence
if tt.shouldRemove {
t.Error("should not remove sequence when pending inputs exist")
}
} else if !tt.enableContextShift {
// Should remove with DoneReasonContextShift
if !tt.shouldRemove {
t.Error("should remove sequence when context shifting disabled")
}
if tt.expectedDoneReason != llm.DoneReasonContextShift {
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
}
} else {
// Should shift context - don't remove sequence
if tt.shouldRemove {
t.Error("should not remove sequence when context shifting enabled")
}
}
} else {
// Within limits - should not remove
if tt.shouldRemove {
t.Errorf("should not remove sequence when within context limits")
}
}
})
}
}
func TestPredictLimitLogic(t *testing.T) {
tests := []struct {
name string
numPredict int
numPredicted int
expectRemove bool
}{
{
name: "predict limit not reached",
numPredict: 5,
numPredicted: 3,
expectRemove: false,
},
{
name: "predict limit reached",
numPredict: 5,
numPredicted: 5,
expectRemove: true,
},
{
name: "predict limit exceeded",
numPredict: 5,
numPredicted: 6,
expectRemove: true,
},
{
name: "no predict limit",
numPredict: 0,
numPredicted: 100,
expectRemove: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the core logic from processBatch
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
if shouldRemove != tt.expectRemove {
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
}
})
}
}

View File

@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
if ctxLen > opts.NumCtx {
if !opts.ShiftContext {
return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
}
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
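
A caller-side sketch for detecting the new refusal; matching on the message text is an assumption, since the change defines no sentinel error value to test with errors.Is.

package main

import (
	"fmt"
	"log"
	"strings"
)

// isShiftDisabledErr reports whether err is the refusal chatPrompt now
// returns when the prompt exceeds the context window. Message matching
// is illustrative only.
func isShiftDisabledErr(err error) bool {
	return err != nil && strings.Contains(err.Error(), "context shifting is disabled")
}

func main() {
	err := fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled")
	if isShiftDisabledErr(err) {
		log.Print("prompt exceeds the context window; enable shift_context or raise num_ctx")
	}
}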

View File

@@ -2,6 +2,7 @@ package server
import (
"bytes"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -56,7 +57,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
error: fmt.Errorf("context length of 1 tokens exceeded, context shifting is disabled"),
},
},
{
@@ -69,10 +70,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
},
},
{
@@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
},
},
{
@@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"),
},
},
{
@@ -208,12 +200,25 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
// These prompts exceed the context window; disable context shifting so the tests exercise the new refusal path
if tt.name == "truncate messages" ||
tt.name == "truncate messages with image" ||
tt.name == "truncate messages with images" ||
tt.name == "truncate message with interleaved images" {
opts.ShiftContext = false
}
think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
} else if tt.error != nil && err != nil {
if err.Error() != tt.error.Error() {
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
}
} else if tt.error != nil && err == nil {
t.Fatalf("expected err '%q', got nil", tt.error)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
@@ -25,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
@@ -968,3 +970,154 @@ func TestWaitForStream(t *testing.T) {
})
}
}
func TestEnableContextShiftNonStreamingResponse(t *testing.T) {
tests := []struct {
name string
enableContextShift bool
responses []llm.CompletionResponse
expectedDone bool
expectedDoneReason string
}{
{
name: "context shifting disabled - should have DoneReasonLength",
enableContextShift: false,
responses: []llm.CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "", Done: true, DoneReason: llm.DoneReasonLength},
},
expectedDone: true,
expectedDoneReason: "length",
},
{
name: "context shifting enabled - should have DoneReasonStop",
enableContextShift: true,
responses: []llm.CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
{Content: "", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedDone: true,
expectedDoneReason: "stop",
},
{
name: "no final response with Done=true",
enableContextShift: false,
responses: []llm.CompletionResponse{
{Content: "Hello", Done: false},
{Content: " world", Done: false},
},
expectedDone: false,
expectedDoneReason: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// The last response in the channel will naturally be the final state
lastResponse := tt.responses[len(tt.responses)-1]
if lastResponse.Done != tt.expectedDone {
t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done)
}
if tt.expectedDoneReason != "" {
if lastResponse.DoneReason.String() != tt.expectedDoneReason {
t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String())
}
}
})
}
}
func TestHandleScheduleError(t *testing.T) {
tests := []struct {
name string
errorMessage string
expectedStatus int
}{
{
name: "context length exceeded error",
errorMessage: "context length of 100 tokens exceeded, context shifting is disabled",
expectedStatus: http.StatusInternalServerError,
},
{
name: "other error",
errorMessage: "some other error",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
err := errors.New(tt.errorMessage)
handleScheduleError(c, "test-model", err)
if w.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
}
var response map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage {
t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg)
}
})
}
}
func TestEnableContextShiftOptions(t *testing.T) {
t.Run("default options have ShiftContext=true", func(t *testing.T) {
opts := api.DefaultOptions()
if !opts.ShiftContext {
t.Errorf("expected ShiftContext=true by default, got %v", opts.ShiftContext)
}
})
t.Run("can set ShiftContext to false", func(t *testing.T) {
opts := api.DefaultOptions()
opts.ShiftContext = false
if opts.ShiftContext {
t.Errorf("expected ShiftContext=false after setting, got %v", opts.ShiftContext)
}
})
t.Run("JSON serialization omits false values", func(t *testing.T) {
opts := api.DefaultOptions()
opts.ShiftContext = false
data, err := json.Marshal(opts)
if err != nil {
t.Fatalf("failed to marshal options: %v", err)
}
// Check that shift_context is not in the JSON when false
if bytes.Contains(data, []byte("shift_context")) {
t.Errorf("expected shift_context to be omitted from JSON when false, but found it in: %s", string(data))
}
})
t.Run("JSON serialization includes true values", func(t *testing.T) {
opts := api.DefaultOptions()
opts.ShiftContext = true
data, err := json.Marshal(opts)
if err != nil {
t.Fatalf("failed to marshal options: %v", err)
}
// Check that shift_context is in the JSON when true
if !bytes.Contains(data, []byte("shift_context")) {
t.Errorf("expected shift_context to be in JSON when true, but not found in: %s", string(data))
}
})
}