Compare commits


1 Commit

Author: ParthSareen
SHA1: d65ccbc85c

server: compute numKeep to protect system prompts during context shift

Previously, NumKeep defaulted to 4, causing system prompts to be
truncated when the context window filled up and a shift operation
occurred. This resulted in models losing their persona/instructions
during long conversations.

Changes:
- chatPrompt() now returns numKeep (token count of system messages + tools)
- ChatHandler and GenerateHandler set opts.NumKeep from the computed value
- Error if system+tools exceeds NumCtx-100 (too little room for conversation)
- Cap numKeep at NumCtx-200 to ensure at least 200 tokens for generation
Date: 2026-01-07 14:38:55 -08:00
3 changed files with 47 additions and 10 deletions
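
Before reading the diff, the two numeric bounds described in the commit message (the 100-token conversation buffer and the 200-token generation reserve) can be restated as a small standalone sketch. The function name and the demo values below are illustrative only and are not part of the change; the real logic is the numKeep block added to chatPrompt in the first file shown below.

```go
package main

import "fmt"

// clampNumKeep restates the two bounds described in the commit message:
// reject a protected region that would leave fewer than 100 tokens of
// conversation room, and cap it so at least 200 tokens remain for
// generation. Illustrative only; see chatPrompt in the diff for the
// actual implementation.
func clampNumKeep(numKeep, numCtx int) (int, error) {
	if numKeep > 0 && numKeep > numCtx-100 {
		return 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, numCtx)
	}
	if numCtx > 200 {
		numKeep = min(numKeep, numCtx-200)
	}
	return numKeep, nil
}

func main() {
	fmt.Println(clampNumKeep(350, 4096))  // 350 <nil>: fits comfortably, kept as-is
	fmt.Println(clampNumKeep(3950, 4096)) // 3896 <nil>: capped so 200 tokens stay free for generation
	fmt.Println(clampNumKeep(2000, 2048)) // error: fewer than 100 tokens would remain for the conversation
}
```

With a 4096-token context, a 350-token system prompt is protected in full; a 3950-token one is capped at 3896 so generation still has room; and in a 2048-token context a 2000-token system prompt is rejected outright.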

View File

@@ -20,7 +20,8 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
+// It also returns numKeep, the number of tokens in system messages + tools that should be protected from truncation.
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, numKeep int, _ error) {
 	var system []api.Message

 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -44,12 +45,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}

 		s, err := tokenize(ctx, p)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}

 		ctxLen := len(s)
@@ -71,7 +72,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	for cnt, msg := range msgs[currMsgIdx:] {
 		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
-			return "", nil, errors.New("this model only supports one image while more than one image requested")
+			return "", nil, 0, errors.New("this model only supports one image while more than one image requested")
 		}

 		var prefix string
@@ -98,10 +99,40 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// truncate any messages that do not fit into the context window
 	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
 	if err != nil {
-		return "", nil, err
+		return "", nil, 0, err
 	}

-	return p, images, nil
+	// Compute numKeep: tokens for system messages + tools that should be protected from truncation
+	// Re-collect all system messages from the entire conversation
+	allSystemMsgs := make([]api.Message, 0)
+	for _, msg := range msgs {
+		if msg.Role == "system" {
+			allSystemMsgs = append(allSystemMsgs, msg)
+		}
+	}
+
+	protectedPrompt, err := renderPrompt(m, allSystemMsgs, tools, think)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	protectedTokens, err := tokenize(ctx, protectedPrompt)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	numKeep = len(protectedTokens)
+
+	// Error if system+tools leaves less than 100 tokens for conversation
+	if numKeep > 0 && numKeep > opts.NumCtx-100 {
+		return "", nil, 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, opts.NumCtx)
+	}
+
+	// Cap numKeep to ensure at least 200 tokens can be generated
+	if opts.NumCtx > 200 {
+		numKeep = min(numKeep, opts.NumCtx-200)
+	}
+
+	return p, images, numKeep, nil
 }

 func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {

View File

@@ -235,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
+			prompt, images, _, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {

View File

@@ -459,11 +459,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		// the real chat handler, but doing this as a stopgap to get renderer
 		// support for generate
 		if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-			prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
+			var numKeep int
+			prompt, images, numKeep, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
+			// Set numKeep to protect system prompt + tools from truncation during context shift
+			opts.NumKeep = numKeep

 			// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
 			if req.Context != nil {
 				b.WriteString(prompt)
@@ -2076,13 +2079,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}

 	truncate := req.Truncate == nil || *req.Truncate

-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+	prompt, images, numKeep, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}

+	// Set numKeep to protect system prompt + tools from truncation during context shift
+	opts.NumKeep = numKeep
+
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
 		c.JSON(http.StatusOK, api.ChatResponse{
@@ -2289,7 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 			msgs = append(msgs, msg)
-			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+			prompt, _, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 			if err != nil {
 				slog.Error("chat prompt error applying structured outputs", "error", err)
 				ch <- gin.H{"error": err.Error()}
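
For context on why setting opts.NumKeep in the handlers above matters: when the context window fills, the runner performs a shift that discards older tokens, and tokens before NumKeep are exempt from that discard. The sketch below is a simplified illustration of that behavior, not Ollama's actual runner code; per the commit message, the old default of NumKeep = 4 meant only the first four tokens survived a shift, which is how system prompts were being lost.

```go
package main

import "fmt"

// shiftContext is an illustrative sketch of a context shift that honors
// numKeep. It is not Ollama's runner implementation; it only shows why
// protecting the first numKeep tokens keeps the rendered system prompt
// and tool definitions in the window after a shift.
func shiftContext(cache []int, numKeep int) []int {
	// Discard roughly half of the unprotected region, keeping the
	// protected prefix and the most recent tokens.
	numDiscard := (len(cache) - numKeep) / 2
	shifted := make([]int, 0, len(cache)-numDiscard)
	shifted = append(shifted, cache[:numKeep]...)            // protected prefix
	shifted = append(shifted, cache[numKeep+numDiscard:]...) // most recent tokens
	return shifted
}

func main() {
	// A full 2048-token window whose first 300 tokens are the system prompt.
	cache := make([]int, 2048)
	for i := range cache {
		cache[i] = i
	}
	out := shiftContext(cache, 300)
	fmt.Println(len(out), out[0], out[299], out[300]) // 1174 0 299 1174
}
```

Here the 300 protected tokens survive the shift intact while roughly half of the unprotected region is dropped to make room for new tokens; with the computed numKeep from chatPrompt, that protected prefix covers the system messages and tool definitions.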