Mirror of https://github.com/ollama/ollama.git, synced 2026-01-15 19:09:25 -05:00
Compare commits: pdevine/x- ... parth/fix-

1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | d65ccbc85c |  |
```diff
@@ -20,7 +20,8 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
+// It also returns numKeep, the number of tokens in system messages + tools that should be protected from truncation.
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, numKeep int, _ error) {
 	var system []api.Message

 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
```
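The comment above states the truncation contract, but the loop that enforces it sits outside this hunk. As a rough illustration only, here is a self-contained sketch of that contract with a word count standing in for the model tokenizer; every name in it (`msg`, `words`, `truncateForContext`) is invented for this sketch and is not ollama code.

```go
// Illustrative sketch of the documented truncation contract: always keep
// system messages and the latest message, dropping older turns first.
// This is NOT ollama's implementation; chatPrompt uses the real tokenizer.
package main

import (
	"fmt"
	"strings"
)

type msg struct {
	Role, Content string
}

// words is a crude stand-in for tokenization.
func words(msgs []msg) int {
	n := 0
	for _, m := range msgs {
		n += len(strings.Fields(m.Content))
	}
	return n
}

// truncateForContext assumes any system messages come first in msgs. It keeps
// them plus the longest suffix of the conversation that fits within limit;
// the latest message is kept even if it alone exceeds the limit.
func truncateForContext(msgs []msg, limit int) []msg {
	var system []msg
	for _, m := range msgs {
		if m.Role == "system" {
			system = append(system, m)
		}
	}

	start := len(msgs) - 1 // always include the latest message
	for i := start - 1; i >= len(system); i-- {
		window := append(append([]msg{}, system...), msgs[i:]...)
		if words(window) > limit {
			break
		}
		start = i
	}
	if start <= len(system) {
		return msgs // the whole conversation fits
	}
	return append(append([]msg{}, system...), msgs[start:]...)
}

func main() {
	conv := []msg{
		{"system", "You are a terse assistant."},
		{"user", "A very long first question that uses up many many words of context."},
		{"assistant", "A short answer."},
		{"user", "Second question?"},
	}
	// With a tight budget the older turns are dropped, but the system message
	// and the latest message always survive.
	fmt.Println(truncateForContext(conv, 7))
}
```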
```diff
@@ -44,12 +45,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.

 		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}

 		s, err := tokenize(ctx, p)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}

 		ctxLen := len(s)
@@ -71,7 +72,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.

 	for cnt, msg := range msgs[currMsgIdx:] {
 		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
-			return "", nil, errors.New("this model only supports one image while more than one image requested")
+			return "", nil, 0, errors.New("this model only supports one image while more than one image requested")
 		}

 		var prefix string
@@ -98,10 +99,40 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// truncate any messages that do not fit into the context window
 	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
 	if err != nil {
-		return "", nil, err
+		return "", nil, 0, err
 	}

-	return p, images, nil
+	// Compute numKeep: tokens for system messages + tools that should be protected from truncation
+	// Re-collect all system messages from the entire conversation
+	allSystemMsgs := make([]api.Message, 0)
+	for _, msg := range msgs {
+		if msg.Role == "system" {
+			allSystemMsgs = append(allSystemMsgs, msg)
+		}
+	}
+	protectedPrompt, err := renderPrompt(m, allSystemMsgs, tools, think)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	protectedTokens, err := tokenize(ctx, protectedPrompt)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	numKeep = len(protectedTokens)
+
+	// Error if system+tools leaves less than 100 tokens for conversation
+	if numKeep > 0 && numKeep > opts.NumCtx-100 {
+		return "", nil, 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, opts.NumCtx)
+	}
+
+	// Cap numKeep to ensure at least 200 tokens can be generated
+	if opts.NumCtx > 200 {
+		numKeep = min(numKeep, opts.NumCtx-200)
+	}
+
+	return p, images, numKeep, nil
 }

 func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
```
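The numKeep logic added here fixes two buffers: the protected prefix must leave at least 100 tokens for the conversation (otherwise the request errors out), and it is capped so that at least 200 tokens remain for generation. Below is a standalone sketch of just that arithmetic, with illustrative numbers; `clampNumKeep` is a hypothetical helper written for this sketch, not part of the change.

```go
// Standalone sketch of the numKeep guards added above; clampNumKeep and the
// sample numbers are illustrative only.
package main

import "fmt"

// clampNumKeep mirrors the two checks in chatPrompt: reject a protected prefix
// that leaves fewer than 100 tokens for the conversation, and cap it so at
// least 200 tokens remain for generation.
func clampNumKeep(numKeep, numCtx int) (int, error) {
	if numKeep > 0 && numKeep > numCtx-100 {
		return 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, numCtx)
	}
	if numCtx > 200 {
		numKeep = min(numKeep, numCtx-200)
	}
	return numKeep, nil
}

func main() {
	// 500 protected tokens in a 4096-token context: kept as-is (500 <= 3996 and 500 <= 3896).
	fmt.Println(clampNumKeep(500, 4096))
	// 3900 protected tokens in a 4096-token context: capped to 3896 so 200 tokens can still be generated.
	fmt.Println(clampNumKeep(3900, 4096))
	// 2000 protected tokens in a 2048-token context: rejected, fewer than 100 tokens would remain.
	fmt.Println(clampNumKeep(2000, 2048))
}
```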
```diff
@@ -235,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
+			prompt, images, _, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
```
```diff
@@ -459,11 +459,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// the real chat handler, but doing this as a stopgap to get renderer
 	// support for generate
 	if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
+		var numKeep int
+		prompt, images, numKeep, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
+		// Set numKeep to protect system prompt + tools from truncation during context shift
+		opts.NumKeep = numKeep
 		// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
 		if req.Context != nil {
 			b.WriteString(prompt)
```
```diff
@@ -2076,13 +2079,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}

 	truncate := req.Truncate == nil || *req.Truncate
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+	prompt, images, numKeep, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}

+	// Set numKeep to protect system prompt + tools from truncation during context shift
+	opts.NumKeep = numKeep
+
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
 		c.JSON(http.StatusOK, api.ChatResponse{
```
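Setting `opts.NumKeep` here is what makes the protection take effect later: when the context window fills up, a shift that honors NumKeep may only evict tokens that come after the protected prefix. The sketch below illustrates that idea; `shiftContext` and its drop-the-oldest-half policy are assumptions for illustration, not ollama's runner implementation.

```go
// Illustrative only: a context shift that never evicts the first numKeep
// tokens. The eviction policy here is assumed, not ollama's actual runner code.
package main

import "fmt"

func shiftContext(cache []int, numKeep int) []int {
	if numKeep >= len(cache) {
		return cache // nothing is evictable
	}
	evictable := cache[numKeep:]
	discard := len(evictable) / 2 // drop the oldest half of the unprotected tokens
	shifted := append([]int{}, cache[:numKeep]...)
	return append(shifted, evictable[discard:]...)
}

func main() {
	// A 10-token cache where the first 4 tokens are the protected system prompt.
	cache := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	fmt.Println(shiftContext(cache, 4)) // [1 2 3 4 8 9 10]
}
```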
```diff
@@ -2289,7 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}

 			msgs = append(msgs, msg)
-			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+			prompt, _, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 			if err != nil {
 				slog.Error("chat prompt error applying structured outputs", "error", err)
 				ch <- gin.H{"error": err.Error()}
```