Mirror of https://github.com/ollama/ollama.git, synced 2026-01-19 04:51:17 -05:00

Compare commits: parth/decr ... parth/fix- (1 commit)

Commit d65ccbc85c
@@ -20,7 +20,8 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
+// It also returns numKeep, the number of tokens in system messages + tools that should be protected from truncation.
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, numKeep int, _ error) {
 	var system []api.Message
 
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -44,12 +45,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}
 
 		s, err := tokenize(ctx, p)
 		if err != nil {
-			return "", nil, err
+			return "", nil, 0, err
 		}
 
 		ctxLen := len(s)
@@ -71,7 +72,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	for cnt, msg := range msgs[currMsgIdx:] {
 		if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
-			return "", nil, errors.New("this model only supports one image while more than one image requested")
+			return "", nil, 0, errors.New("this model only supports one image while more than one image requested")
 		}
 
 		var prefix string
@@ -98,10 +99,40 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	// truncate any messages that do not fit into the context window
 	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
 	if err != nil {
-		return "", nil, err
+		return "", nil, 0, err
 	}
 
-	return p, images, nil
+	// Compute numKeep: tokens for system messages + tools that should be protected from truncation
+	// Re-collect all system messages from the entire conversation
+	allSystemMsgs := make([]api.Message, 0)
+	for _, msg := range msgs {
+		if msg.Role == "system" {
+			allSystemMsgs = append(allSystemMsgs, msg)
+		}
+	}
+	protectedPrompt, err := renderPrompt(m, allSystemMsgs, tools, think)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	protectedTokens, err := tokenize(ctx, protectedPrompt)
+	if err != nil {
+		return "", nil, 0, err
+	}
+
+	numKeep = len(protectedTokens)
+
+	// Error if system+tools leaves less than 100 tokens for conversation
+	if numKeep > 0 && numKeep > opts.NumCtx-100 {
+		return "", nil, 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, opts.NumCtx)
+	}
+
+	// Cap numKeep to ensure at least 200 tokens can be generated
+	if opts.NumCtx > 200 {
+		numKeep = min(numKeep, opts.NumCtx-200)
+	}
+
+	return p, images, numKeep, nil
 }
 
 func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
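The two guard rails in the new block, a 100-token conversation buffer and a 200-token generation reserve, are easiest to see with concrete numbers. The following is a minimal, self-contained sketch: clampNumKeep is a hypothetical helper that mirrors the checks above, and the token counts are made up for illustration.

package main

import "fmt"

// clampNumKeep is an illustration only; it mirrors the numKeep checks in chatPrompt.
func clampNumKeep(numKeep, numCtx int) (int, error) {
	// System prompt + tools must leave at least a 100-token buffer for the conversation.
	if numKeep > 0 && numKeep > numCtx-100 {
		return 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, numCtx)
	}
	// Cap numKeep so at least 200 tokens remain available for generation.
	if numCtx > 200 {
		numKeep = min(numKeep, numCtx-200)
	}
	return numKeep, nil
}

func main() {
	fmt.Println(clampNumKeep(512, 4096))  // 512 <nil>: fits comfortably, left untouched
	fmt.Println(clampNumKeep(3900, 4096)) // 3896 <nil>: capped so 200 tokens stay free for generation
	fmt.Println(clampNumKeep(4050, 4096)) // 0 plus an error: less than the 100-token conversation buffer remains
}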
@@ -235,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
+			prompt, images, _, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
@@ -459,11 +459,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		// the real chat handler, but doing this as a stopgap to get renderer
 		// support for generate
 		if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-			prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
+			var numKeep int
+			prompt, images, numKeep, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
+			// Set numKeep to protect system prompt + tools from truncation during context shift
+			opts.NumKeep = numKeep
 			// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
 			if req.Context != nil {
 				b.WriteString(prompt)
@@ -2076,13 +2079,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	truncate := req.Truncate == nil || *req.Truncate
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+	prompt, images, numKeep, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
+	// Set numKeep to protect system prompt + tools from truncation during context shift
+	opts.NumKeep = numKeep
+
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
 		c.JSON(http.StatusOK, api.ChatResponse{
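Both handlers now copy the returned value into opts.NumKeep before the request reaches the runner. Per the comments in the diff, the intent is that a llama.cpp-style context shift keeps the first NumKeep tokens (system prompt + tools) and evicts older conversation tokens that follow them. The sketch below only illustrates that idea; it is not the runner's actual implementation, and shiftContext with its drop-half policy is an assumption.

// shiftContext is a simplified illustration, not runner code: keep the first
// numKeep tokens, then drop roughly half of what follows so generation can
// continue without evicting the protected system prompt and tools.
func shiftContext(tokens []int, numKeep int) []int {
	if len(tokens) <= numKeep {
		return tokens
	}
	discard := (len(tokens) - numKeep) / 2
	kept := make([]int, 0, len(tokens)-discard)
	kept = append(kept, tokens[:numKeep]...)
	kept = append(kept, tokens[numKeep+discard:]...)
	return kept
}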
@@ -2289,7 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			msgs = append(msgs, msg)
-			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
+			prompt, _, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 			if err != nil {
 				slog.Error("chat prompt error applying structured outputs", "error", err)
 				ch <- gin.H{"error": err.Error()}