mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
chore: refactor endpoints to use same inferencing path, add automatic retrial mechanism in case of errors (#9029)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
3d9ccd1ddc
commit
ee96e5e08d
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -197,54 +198,24 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
xlog.Debug("Anthropic MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput))
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
openAIReq.Metadata = input.Metadata
|
||||
|
||||
toolsJSON := ""
|
||||
if len(funcs) > 0 {
|
||||
openAITools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
openAITools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
if toolsBytes, err := json.Marshal(openAITools); err == nil {
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic model inference failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic prediction failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
}
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing
|
||||
var toolCalls []functions.FuncCallResults
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
toolCalls = deltaToolCalls
|
||||
} else {
|
||||
@@ -350,8 +321,8 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
StopReason: &stopReason,
|
||||
Content: contentBlocks,
|
||||
Usage: schema.AnthropicUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -397,12 +368,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
xlog.Debug("Anthropic MCP stream re-templating", "iteration", mcpIteration)
|
||||
}
|
||||
|
||||
openAIMessages := openAIReq.Messages
|
||||
images := []string{}
|
||||
for _, m := range openAIMessages {
|
||||
images = append(images, m.StringImages...)
|
||||
}
|
||||
|
||||
// Track accumulated content for tool call detection
|
||||
accumulatedContent := ""
|
||||
currentBlockIndex := 0
|
||||
@@ -481,38 +446,19 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
return true
|
||||
}
|
||||
|
||||
toolsJSON := ""
|
||||
if len(funcs) > 0 {
|
||||
openAITools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
openAITools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
if toolsBytes, err := json.Marshal(openAITools); err == nil {
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
}
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertFuncsToOpenAITools(funcs)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
openAIReq.Metadata = input.Metadata
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata)
|
||||
_, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream model inference failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err))
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Anthropic stream prediction failed", "error", err)
|
||||
return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err))
|
||||
}
|
||||
|
||||
// Also check chat deltas for tool calls
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 {
|
||||
collectedToolCalls = deltaToolCalls
|
||||
}
|
||||
|
||||
@@ -595,7 +541,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
StopReason: &stopReason,
|
||||
},
|
||||
Usage: &schema.AnthropicUsage{
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -613,6 +559,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
tools := make([]functions.Tool, len(funcs))
|
||||
for i, f := range funcs {
|
||||
tools[i] = functions.Tool{Type: "function", Function: f}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
|
||||
@@ -82,51 +82,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
template = s
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
|
||||
// Track accumulated content for reasoning extraction
|
||||
accumulatedContent := ""
|
||||
lastEmittedReasoning := ""
|
||||
lastEmittedCleanedContent := ""
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedContent += s
|
||||
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
// Calculate new reasoning delta (what we haven't emitted yet)
|
||||
var reasoningDelta *string
|
||||
if currentReasoning != lastEmittedReasoning {
|
||||
// Extract only the new part
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
newReasoning := currentReasoning[len(lastEmittedReasoning):]
|
||||
reasoningDelta = &newReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != "" {
|
||||
// If reasoning changed in a non-append way, emit the full current reasoning
|
||||
reasoningDelta = ¤tReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate content delta from cleaned content
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
// If cleaned content changed but not in a simple append, extract delta from cleaned content
|
||||
// This handles cases where thinking tags are removed mid-stream
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
// Content changed in non-append way, use the new cleaned content
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
// Only emit content if there's actual content (not just thinking tags)
|
||||
// If deltaContent is empty, we still emit the response but with empty content
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -139,12 +98,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
delta := &schema.Message{}
|
||||
// Only include content if there's actual content (not just thinking tags)
|
||||
if deltaContent != "" {
|
||||
delta.Content = &deltaContent
|
||||
if contentDelta != "" {
|
||||
delta.Content = &contentDelta
|
||||
}
|
||||
if reasoningDelta != nil && *reasoningDelta != "" {
|
||||
delta.Reasoning = reasoningDelta
|
||||
if reasoningDelta != "" {
|
||||
delta.Reasoning = &reasoningDelta
|
||||
}
|
||||
|
||||
resp := schema.OpenAIResponse{
|
||||
@@ -171,43 +129,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
template = prompt
|
||||
}
|
||||
thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig)
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
|
||||
// Track accumulated content for incremental reasoning and content extraction (mirrors process())
|
||||
accumulatedContent := ""
|
||||
lastEmittedReasoning := ""
|
||||
lastEmittedCleanedContent := ""
|
||||
sentInitialRole := false
|
||||
|
||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
accumulatedContent += s
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(s)
|
||||
|
||||
// Incremental reasoning extraction — emit reasoning deltas in their own SSE chunks
|
||||
// before any tool-call chunks (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
var reasoningDelta *string
|
||||
if currentReasoning != lastEmittedReasoning {
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
newReasoning := currentReasoning[len(lastEmittedReasoning):]
|
||||
reasoningDelta = &newReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != "" {
|
||||
reasoningDelta = ¤tReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
}
|
||||
|
||||
if reasoningDelta != nil && *reasoningDelta != "" {
|
||||
// Emit reasoning deltas in their own SSE chunks before any tool-call chunks
|
||||
// (OpenAI spec: reasoning and tool_calls never share a delta)
|
||||
if reasoningDelta != "" {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Reasoning: reasoningDelta},
|
||||
Delta: &schema.Message{Reasoning: &reasoningDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
@@ -217,32 +157,22 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// Stream content deltas (cleaned of reasoning tags) while no tool calls
|
||||
// have been detected. Once the incremental parser finds tool calls,
|
||||
// content stops — per OpenAI spec, content and tool_calls don't mix.
|
||||
if lastEmittedCount == 0 && cleanedContent != "" {
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
if deltaContent != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
if lastEmittedCount == 0 && contentDelta != "" {
|
||||
if !sentInitialRole {
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &deltaContent},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
sentInitialRole = true
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
ID: id, Created: created, Model: req.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: &contentDelta},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,7 +279,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
},
|
||||
func(attempt int) bool {
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry)
|
||||
hasToolCalls := lastEmittedCount > 0
|
||||
if cleaned == "" && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -366,10 +314,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
|
||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
reasoning = extractor.Reasoning()
|
||||
cleanedResult := extractor.CleanedContent()
|
||||
textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn)
|
||||
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
|
||||
@@ -389,7 +338,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
if sentInitialRole {
|
||||
// Content was already streamed during the callback — just emit usage.
|
||||
delta := &schema.Message{}
|
||||
if reasoning != "" && lastEmittedReasoning == "" {
|
||||
if reasoning != "" && extractor.Reasoning() == "" {
|
||||
delta.Reasoning = &reasoning
|
||||
}
|
||||
responses <- schema.OpenAIResponse{
|
||||
@@ -406,7 +355,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
result, err := handleQuestion(config, functionResults, result, prompt)
|
||||
result, err := handleQuestion(config, functionResults, extractor.CleanedContent(), prompt)
|
||||
if err != nil {
|
||||
xlog.Error("error handling question", "error", err)
|
||||
return err
|
||||
@@ -981,7 +930,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// is deferred to after ComputeChoices so we can check chat deltas first
|
||||
// and avoid redundant Go-side parsing.
|
||||
var cbRawResult, cbReasoning string
|
||||
var emptyRetryNeeded bool
|
||||
|
||||
tokenCallback := func(s string, c *[]schema.Choice) {
|
||||
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
|
||||
@@ -1001,146 +949,133 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
cbReasoning = reasoning
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var result []schema.Choice
|
||||
var tokenUsage backend.TokenUsage
|
||||
var err error
|
||||
|
||||
var chatDeltas []*pb.ChatDelta
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
emptyRetryNeeded = false
|
||||
result, tokenUsage, chatDeltas, err = ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
config,
|
||||
cl,
|
||||
startupOptions,
|
||||
ml,
|
||||
tokenCallback,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Tool parsing is deferred here (only when shouldUseFn)
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
|
||||
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
result, tokenUsage, chatDeltas, err = ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
config,
|
||||
cl,
|
||||
startupOptions,
|
||||
ml,
|
||||
tokenCallback,
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
if !shouldUseFn {
|
||||
return false
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
if cbRawResult == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &qResult}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message,
|
||||
})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
toolChoice.Message.Reasoning = &cbReasoning
|
||||
}
|
||||
|
||||
for _, ss := range funcResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Deprecated function_call format
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
result = append(result, toolChoice)
|
||||
}
|
||||
// Retry when backend produced only reasoning and no content/tool calls.
|
||||
// Full tool parsing is deferred until after ComputeChoices returns
|
||||
// (when chat deltas are available), but we can detect the empty case here.
|
||||
if cbRawResult == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend produced reasoning without actionable content, retrying",
|
||||
"reasoning_len", len(cbReasoning), "attempt", attempt+1)
|
||||
cbRawResult = ""
|
||||
cbReasoning = ""
|
||||
textContentToReturn = ""
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if !emptyRetryNeeded {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
}
|
||||
return false
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if emptyRetryNeeded {
|
||||
xlog.Warn("All retries exhausted, backend still returning empty content")
|
||||
stopReason := FinishReasonStop
|
||||
empty := ""
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Index: 0,
|
||||
Message: &schema.Message{Role: "assistant", Content: &empty},
|
||||
})
|
||||
// Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
|
||||
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &qResult}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message,
|
||||
})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
toolChoice.Message.Reasoning = &cbReasoning
|
||||
}
|
||||
|
||||
for _, ss := range funcResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Deprecated function_call format
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
result = append(result, toolChoice)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MCP server-side tool execution loop:
|
||||
@@ -1277,5 +1212,5 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall
|
||||
|
||||
xlog.Debug("No action received from LLM, without a message, computing a reply")
|
||||
|
||||
return "", fmt.Errorf("no action received from LLM, without a message, computing a reply")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
157
core/http/endpoints/openai/chat_test.go
Normal file
157
core/http/endpoints/openai/chat_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
)
|
||||
|
||||
var _ = Describe("handleQuestion", func() {
|
||||
var cfg *config.ModelConfig
|
||||
|
||||
BeforeEach(func() {
|
||||
cfg = &config.ModelConfig{}
|
||||
})
|
||||
|
||||
Context("with no function results but non-empty result", func() {
|
||||
It("should return the result directly", func() {
|
||||
result, err := handleQuestion(cfg, nil, "Hello world", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Hello world"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with no function results and empty result", func() {
|
||||
It("should return empty string", func() {
|
||||
result, err := handleQuestion(cfg, nil, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing a message argument", func() {
|
||||
It("should extract the message from function arguments", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: `{"message": "This is the answer"}`,
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("This is the answer"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing empty message", func() {
|
||||
It("should return empty string when message is empty", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: `{"message": ""}`,
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with function result containing invalid JSON arguments", func() {
|
||||
It("should return empty string gracefully", func() {
|
||||
funcResults := []functions.FuncCallResults{
|
||||
{
|
||||
Name: "answer",
|
||||
Arguments: "not json",
|
||||
},
|
||||
}
|
||||
result, err := handleQuestion(cfg, funcResults, "", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with cleaned content (no think tags)", func() {
|
||||
It("should return content without think tags", func() {
|
||||
// This tests the bug fix: handleQuestion should receive cleaned content,
|
||||
// not raw text with <think> tags
|
||||
result, err := handleQuestion(cfg, nil, "Just the answer", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("Just the answer"))
|
||||
Expect(result).ToNot(ContainSubstring("<think>"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with raw think tags passed as result", func() {
|
||||
It("would return content with think tags", func() {
|
||||
result, err := handleQuestion(cfg, nil, "<think>reasoning</think>answer", "prompt")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).To(Equal("<think>reasoning</think>answer"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("mergeToolCallDeltas", func() {
|
||||
Context("with new tool calls", func() {
|
||||
It("should append new tool calls", func() {
|
||||
existing := []schema.ToolCall{}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].ID).To(Equal("tc1"))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with argument appending", func() {
|
||||
It("should append arguments to existing tool call", func() {
|
||||
existing := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search", Arguments: `{"q":`}},
|
||||
}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, FunctionCall: schema.FunctionCall{Arguments: `"hello"}`}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].FunctionCall.Arguments).To(Equal(`{"q":"hello"}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with multiple tool calls", func() {
|
||||
It("should track multiple tool calls by index", func() {
|
||||
existing := []schema.ToolCall{}
|
||||
deltas1 := []schema.ToolCall{
|
||||
{Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas1)
|
||||
|
||||
deltas2 := []schema.ToolCall{
|
||||
{Index: 1, ID: "tc2", Type: "function", FunctionCall: schema.FunctionCall{Name: "browse"}},
|
||||
}
|
||||
result = mergeToolCallDeltas(result, deltas2)
|
||||
Expect(result).To(HaveLen(2))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
Expect(result[1].FunctionCall.Name).To(Equal("browse"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with ID update on existing tool call", func() {
|
||||
It("should update ID when provided in delta", func() {
|
||||
existing := []schema.ToolCall{
|
||||
{Index: 0, FunctionCall: schema.FunctionCall{Name: "search"}},
|
||||
}
|
||||
deltas := []schema.ToolCall{
|
||||
{Index: 0, ID: "new-id"},
|
||||
}
|
||||
result := mergeToolCallDeltas(existing, deltas)
|
||||
Expect(result).To(HaveLen(1))
|
||||
Expect(result[0].ID).To(Equal("new-id"))
|
||||
Expect(result[0].FunctionCall.Name).To(Equal("search"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func ComputeChoices(
|
||||
@@ -19,7 +21,9 @@ func ComputeChoices(
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
shouldRetry ...func(int) bool,
|
||||
) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
@@ -27,6 +31,12 @@ func ComputeChoices(
|
||||
n = 1
|
||||
}
|
||||
|
||||
// Extract the optional shouldRetry callback
|
||||
var shouldRetryFn func(int) bool
|
||||
if len(shouldRetry) > 0 {
|
||||
shouldRetryFn = shouldRetry[0]
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
for _, m := range req.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
@@ -82,7 +92,7 @@ func ComputeChoices(
|
||||
}
|
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(
|
||||
predFunc, err := backend.ModelInferenceFunc(
|
||||
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
@@ -91,32 +101,49 @@ func ComputeChoices(
|
||||
tokenUsage := backend.TokenUsage{}
|
||||
var allChatDeltas []*pb.ChatDelta
|
||||
|
||||
const maxRetries = 5
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
var prediction backend.LLMResponse
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
p, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
}
|
||||
prediction = p
|
||||
|
||||
// Built-in: retry on truly empty response (no tokens at all)
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
}
|
||||
|
||||
tokenUsage.Prompt = prediction.Usage.Prompt
|
||||
tokenUsage.Completion = prediction.Usage.Completion
|
||||
tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing
|
||||
tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration
|
||||
|
||||
allChatDeltas = prediction.ChatDeltas
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.)
|
||||
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller has already reset its state inside shouldRetry
|
||||
result = result[:0]
|
||||
allChatDeltas = nil
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
tokenUsage.Completion += prediction.Usage.Completion
|
||||
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
|
||||
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
|
||||
|
||||
// Collect chat deltas from C++ autoparser
|
||||
if len(prediction.ChatDeltas) > 0 {
|
||||
allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...)
|
||||
}
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Add logprobs to the last choice if present
|
||||
if prediction.Logprobs != nil && len(result) > 0 {
|
||||
result[len(result)-1].Logprobs = prediction.Logprobs
|
||||
}
|
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, allChatDeltas, err
|
||||
}
|
||||
|
||||
402
core/http/endpoints/openai/inference_test.go
Normal file
402
core/http/endpoints/openai/inference_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type modelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error)
|
||||
|
||||
var _ = Describe("ComputeChoices", func() {
|
||||
var (
|
||||
origInference modelInferenceFunc
|
||||
cfg *config.ModelConfig
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
// mockInference installs a stub that yields the given responses sequentially.
|
||||
// After all responses are consumed, the last one is repeated.
|
||||
mockInference := func(responses []backend.LLMResponse) {
|
||||
idx := 0
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
resp := responses[idx]
|
||||
if idx < len(responses)-1 {
|
||||
idx++
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
origInference = backend.ModelInferenceFunc
|
||||
cfg = &config.ModelConfig{}
|
||||
appCfg = config.NewApplicationConfig()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
backend.ModelInferenceFunc = origInference
|
||||
})
|
||||
|
||||
makeReq := func() *schema.OpenAIRequest {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_ = cancel
|
||||
return &schema.OpenAIRequest{
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
Context("normal response (no retry needed)", func() {
|
||||
It("should return choices on first attempt", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "Hello world", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
|
||||
})
|
||||
|
||||
var captured string
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test prompt", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
captured = s
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(captured).To(Equal("Hello world"))
|
||||
Expect(usage.Prompt).To(Equal(10))
|
||||
Expect(usage.Completion).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("empty response triggers built-in retry", func() {
|
||||
It("should retry and eventually return non-empty response", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: ""}, // attempt 0: empty
|
||||
{Response: " "}, // attempt 1: whitespace-only
|
||||
{Response: "Got it", Usage: backend.TokenUsage{Prompt: 8, Completion: 3}}, // attempt 2: success
|
||||
})
|
||||
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("Got it"))
|
||||
Expect(usage.Prompt).To(Equal(8))
|
||||
Expect(usage.Completion).To(Equal(3))
|
||||
})
|
||||
})
|
||||
|
||||
Context("all retries exhausted on empty response", func() {
|
||||
It("should return the empty response after max retries", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: ""}, // always empty
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// After maxRetries, it proceeds with the empty response
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry callback", func() {
|
||||
It("should call shouldRetry and retry when it returns true", func() {
|
||||
callCount := 0
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "reasoning-only", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}},
|
||||
{Response: "actual-answer", Usage: backend.TokenUsage{Prompt: 5, Completion: 4}},
|
||||
})
|
||||
|
||||
retryAttempts := []int{}
|
||||
choices, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
callCount++
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
retryAttempts = append(retryAttempts, attempt)
|
||||
// Retry on first attempt only
|
||||
return attempt == 0
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("actual-answer"))
|
||||
// shouldRetry was called twice: once returning true (retry), once returning false (proceed)
|
||||
Expect(retryAttempts).To(Equal([]int{0, 1}))
|
||||
// cb was called twice (once per attempt)
|
||||
Expect(callCount).To(Equal(2))
|
||||
// Token usage should be from the LATEST attempt
|
||||
Expect(usage.Prompt).To(Equal(5))
|
||||
Expect(usage.Completion).To(Equal(4))
|
||||
})
|
||||
|
||||
It("should not retry when shouldRetry returns false", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "first-response"},
|
||||
})
|
||||
|
||||
shouldRetryCalled := false
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
shouldRetryCalled = true
|
||||
return false
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("first-response"))
|
||||
Expect(shouldRetryCalled).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry not provided (variadic omitted)", func() {
|
||||
It("should work without shouldRetry parameter", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "works"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("works"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("token usage from latest attempt", func() {
|
||||
It("should use token usage from the last attempt, not accumulated", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "retry-me", Usage: backend.TokenUsage{Prompt: 100, Completion: 50}},
|
||||
{Response: "final", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}},
|
||||
})
|
||||
|
||||
_, usage, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should be the LATEST attempt's usage, not accumulated
|
||||
Expect(usage.Prompt).To(Equal(10))
|
||||
Expect(usage.Completion).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("chat deltas from latest attempt", func() {
|
||||
It("should return chat deltas from the last attempt only", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{
|
||||
Response: "retry-me",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "old"}},
|
||||
},
|
||||
{
|
||||
Response: "final",
|
||||
ChatDeltas: []*pb.ChatDelta{{Content: "new"}},
|
||||
},
|
||||
})
|
||||
|
||||
_, _, deltas, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(deltas).To(HaveLen(1))
|
||||
Expect(deltas[0].Content).To(Equal("new"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("result choices cleared on retry", func() {
|
||||
It("should only contain choices from the final attempt", func() {
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "bad-choice"},
|
||||
{Response: "good-choice"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool { return attempt == 0 },
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(choices[0].Text).To(Equal("good-choice"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("shouldRetry with max retries cap", func() {
|
||||
It("should stop retrying after maxRetries even if shouldRetry returns true", func() {
|
||||
attempts := 0
|
||||
mockInference([]backend.LLMResponse{
|
||||
{Response: "always-retry"},
|
||||
})
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
func(attempt int) bool {
|
||||
attempts++
|
||||
return true // always want to retry
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
// maxRetries is 5, so shouldRetry is called for attempts 0..4,
|
||||
// but attempt 5 is the final one where shouldRetry can't trigger continue
|
||||
Expect(attempts).To(BeNumerically("<=", 6))
|
||||
})
|
||||
})
|
||||
|
||||
Context("N > 1 completions", func() {
|
||||
It("should produce N separate completions", func() {
|
||||
callIdx := 0
|
||||
responses := []string{"first", "second", "third"}
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
resp := backend.LLMResponse{Response: responses[callIdx]}
|
||||
if callIdx < len(responses)-1 {
|
||||
callIdx++
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
req := makeReq()
|
||||
req.N = 3
|
||||
choices, _, _, err := ComputeChoices(
|
||||
req, "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(3))
|
||||
Expect(choices[0].Text).To(Equal("first"))
|
||||
Expect(choices[1].Text).To(Equal("second"))
|
||||
Expect(choices[2].Text).To(Equal("third"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with streaming token callback", func() {
|
||||
It("should call tokenCallback for streaming responses", func() {
|
||||
var streamedTokens []string
|
||||
backend.ModelInferenceFunc = func(
|
||||
ctx context.Context, s string, messages schema.Messages,
|
||||
images, videos, audios []string,
|
||||
loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
tokenCallback func(string, backend.TokenUsage) bool,
|
||||
tools, toolChoice string,
|
||||
logprobs, topLogprobs *int,
|
||||
logitBias map[string]float64,
|
||||
metadata map[string]string,
|
||||
) (func() (backend.LLMResponse, error), error) {
|
||||
predFunc := func() (backend.LLMResponse, error) {
|
||||
if tokenCallback != nil {
|
||||
tokenCallback("Hello", backend.TokenUsage{Prompt: 5})
|
||||
tokenCallback(" world", backend.TokenUsage{Prompt: 5, Completion: 2})
|
||||
}
|
||||
return backend.LLMResponse{
|
||||
Response: "Hello world",
|
||||
Usage: backend.TokenUsage{Prompt: 5, Completion: 2},
|
||||
}, nil
|
||||
}
|
||||
return predFunc, nil
|
||||
}
|
||||
|
||||
choices, _, _, err := ComputeChoices(
|
||||
makeReq(), "test", cfg, nil, appCfg, nil,
|
||||
func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
},
|
||||
func(s string, usage backend.TokenUsage) bool {
|
||||
streamedTokens = append(streamedTokens, s)
|
||||
return true
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
Expect(streamedTokens).To(Equal([]string{"Hello", " world"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -879,34 +880,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
xlog.Debug("Background MCP re-templating", "iteration", mcpIteration)
|
||||
}
|
||||
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
var logprobs *int
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
logprobs = input.TopLogprobs
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -914,24 +895,19 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
default:
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prediction failed: %w", err)
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
xlog.Warn("Open Responses background: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract logprobs from choices if available
|
||||
var resultLogprobs *schema.Logprobs
|
||||
if len(choices) > 0 {
|
||||
resultLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
// Parse tool calls
|
||||
@@ -939,9 +915,9 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
var textContent string
|
||||
|
||||
if shouldUseFn {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
@@ -1021,7 +997,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
for _, tc := range toolCalls {
|
||||
@@ -1034,22 +1010,22 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
allOutputItems = append(allOutputItems, schema.ORItemField{
|
||||
Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed", Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
return buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, allOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
}, true), nil
|
||||
} // end MCP iteration loop
|
||||
|
||||
@@ -1058,23 +1034,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
|
||||
// handleBackgroundStream handles background streaming responses with event buffering
|
||||
func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) {
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
sequenceNumber := 0
|
||||
|
||||
@@ -1105,20 +1072,13 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
}
|
||||
hasMCPTools := len(mcpToolInfos) > 0
|
||||
|
||||
var prediction backend.LLMResponse
|
||||
var lastTokenUsage backend.TokenUsage
|
||||
var lastLogprobs *schema.Logprobs
|
||||
|
||||
for mcpIter := 0; mcpIter <= mcpBgStreamMaxIterations; mcpIter++ {
|
||||
if mcpIter > 0 {
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
|
||||
xlog.Debug("Background stream MCP re-templating", "iteration", mcpIter)
|
||||
images = images[:0]
|
||||
videos = videos[:0]
|
||||
audios = audios[:0]
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
}
|
||||
|
||||
accumulatedText = ""
|
||||
@@ -1177,28 +1137,23 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
return true
|
||||
}
|
||||
|
||||
var streamLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
streamLogprobs = input.TopLogprobs
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, tokenCallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model inference failed: %w", err)
|
||||
}
|
||||
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prediction failed: %w", err)
|
||||
lastTokenUsage = tokenUsage
|
||||
if len(choices) > 0 {
|
||||
lastLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
result := backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Check for MCP tool calls in the streamed result
|
||||
if shouldUseFn && hasMCPTools {
|
||||
var funcCallResults []functions.FuncCallResults
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
funcCallResults = deltaToolCalls
|
||||
} else {
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
@@ -1315,7 +1270,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
}
|
||||
|
||||
// No MCP tools — close the message and break
|
||||
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
streamEventLogprobs := convertLogprobsForStreaming(lastLogprobs)
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.output_text.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -1327,7 +1282,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
})
|
||||
sequenceNumber++
|
||||
|
||||
textPart := makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -1343,7 +1298,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)},
|
||||
}
|
||||
bufferEvent(store, responseID, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
@@ -1360,9 +1315,9 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI
|
||||
// Build final response
|
||||
now := time.Now().Unix()
|
||||
response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: lastTokenUsage.Prompt,
|
||||
OutputTokens: lastTokenUsage.Completion,
|
||||
TotalTokens: lastTokenUsage.Prompt + lastTokenUsage.Completion,
|
||||
}, true)
|
||||
|
||||
// Emit response.completed
|
||||
@@ -1391,52 +1346,27 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
if mcpIteration > mcpMaxIterations {
|
||||
return sendOpenResponsesError(c, 500, "server_error", "MCP iteration limit reached", "")
|
||||
}
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
// Convert and serialize tools to OpenAI format for the backend
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var logprobs *int
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
logprobs = input.TopLogprobs
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
var result string
|
||||
cb := func(s string, c *[]schema.Choice) {
|
||||
result = s
|
||||
}
|
||||
choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses model inference failed", "error", err)
|
||||
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "")
|
||||
}
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var prediction backend.LLMResponse
|
||||
var result string
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses prediction failed", "error", err)
|
||||
return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("prediction failed: %v", err), "")
|
||||
}
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
if result != "" || !shouldUseFn {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Open Responses: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
var resultLogprobs *schema.Logprobs
|
||||
if len(choices) > 0 {
|
||||
resultLogprobs = choices[0].Logprobs
|
||||
}
|
||||
xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn)
|
||||
|
||||
@@ -1473,10 +1403,10 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
var textContent string
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
// Clean up the result (already extracted reasoning above)
|
||||
@@ -1574,7 +1504,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1605,7 +1535,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
@@ -1615,7 +1545,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
ID: fmt.Sprintf("msg_%s", uuid.New().String()),
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)},
|
||||
}
|
||||
outputItems = append(outputItems, messageItem)
|
||||
}
|
||||
@@ -1633,9 +1563,9 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
// Build response with all required fields
|
||||
now := time.Now().Unix()
|
||||
response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: tokenUsage.Prompt,
|
||||
OutputTokens: tokenUsage.Completion,
|
||||
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -1675,24 +1605,14 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
})
|
||||
sequenceNumber++
|
||||
|
||||
images := []string{}
|
||||
videos := []string{}
|
||||
audios := []string{}
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
// Convert and serialize tools to OpenAI format for the backend
|
||||
toolsJSON := serializeToolsForBackend(input.Tools)
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
// Populate openAIReq fields for ComputeChoices
|
||||
openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
|
||||
openAIReq.ToolsChoice = input.ToolChoice
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
openAIReq.TopLogprobs = input.TopLogprobs
|
||||
openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
|
||||
}
|
||||
openAIReq.LogitBias = input.LogitBias
|
||||
|
||||
// Detect if thinking token is already in prompt or template
|
||||
var template string
|
||||
@@ -1714,10 +1634,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Track reasoning state for streaming
|
||||
var currentReasoningID string
|
||||
var currentReasoningContentIndex int
|
||||
var accumulatedContent string
|
||||
var lastEmittedReasoning string
|
||||
var lastEmittedCleanedContent string
|
||||
var reasoningTokens int
|
||||
extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig)
|
||||
|
||||
// Collect all output items for storage
|
||||
var collectedOutputItems []schema.ORItemField
|
||||
@@ -1729,34 +1647,25 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
hasMCPToolsStream := len(mcpToolInfos) > 0
|
||||
|
||||
var prediction backend.LLMResponse
|
||||
var result, finalReasoning, finalCleanedResult string
|
||||
var textContent string
|
||||
var parsedToolCalls []functions.FuncCallResults
|
||||
var toolCalls []functions.FuncCallResults
|
||||
var lastStreamTokenUsage backend.TokenUsage
|
||||
var lastStreamLogprobs *schema.Logprobs
|
||||
|
||||
for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ {
|
||||
if mcpStreamIter > 0 {
|
||||
// Reset reasoning and tool-call state for re-inference so reasoning
|
||||
// extraction runs again on subsequent iterations
|
||||
inToolCallMode = false
|
||||
accumulatedContent = ""
|
||||
lastEmittedReasoning = ""
|
||||
lastEmittedCleanedContent = ""
|
||||
extractor.Reset()
|
||||
currentMessageID = ""
|
||||
lastEmittedToolCallCount = 0
|
||||
currentReasoningID = ""
|
||||
|
||||
predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
|
||||
xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter)
|
||||
images = images[:0]
|
||||
videos = videos[:0]
|
||||
audios = audios[:0]
|
||||
for _, m := range openAIReq.Messages {
|
||||
images = append(images, m.StringImages...)
|
||||
videos = append(videos, m.StringVideos...)
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
}
|
||||
|
||||
// For tool calls, we need to track accumulated result and parse incrementally
|
||||
@@ -1911,11 +1820,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// If no tool calls detected yet, handle reasoning and text
|
||||
if !inToolCallMode {
|
||||
accumulatedContent += token
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
|
||||
// Handle reasoning item
|
||||
if currentReasoning != "" {
|
||||
if extractor.Reasoning() != "" {
|
||||
// Check if we need to create reasoning item
|
||||
if currentReasoningID == "" {
|
||||
outputIndex++
|
||||
@@ -1947,16 +1855,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
}
|
||||
|
||||
// Calculate reasoning delta
|
||||
var reasoningDelta string
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != lastEmittedReasoning {
|
||||
reasoningDelta = currentReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
|
||||
// Emit reasoning delta if there's new content
|
||||
if reasoningDelta != "" {
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -1973,23 +1871,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message content (cleaned content without reasoning tags)
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit message content if there's actual content (not just reasoning)
|
||||
if deltaContent != "" {
|
||||
if contentDelta != "" {
|
||||
if currentMessageID == "" {
|
||||
// Emit output_item.added for message
|
||||
outputIndex++
|
||||
@@ -2030,7 +1913,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ItemID: currentMessageID,
|
||||
OutputIndex: &outputIndex,
|
||||
ContentIndex: ¤tContentIndex,
|
||||
Delta: strPtr(deltaContent),
|
||||
Delta: strPtr(contentDelta),
|
||||
Logprobs: emptyLogprobs(),
|
||||
})
|
||||
sequenceNumber++
|
||||
@@ -2040,14 +1923,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var streamLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
streamLogprobs = input.TopLogprobs
|
||||
var ccResult string
|
||||
ccCb := func(s string, c *[]schema.Choice) {
|
||||
ccResult = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream model inference failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2071,36 +1951,27 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
prediction, err = predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream prediction failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "error",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Error: &schema.ORErrorPayload{
|
||||
Type: "model_error",
|
||||
Message: fmt.Sprintf("prediction failed: %v", err),
|
||||
},
|
||||
})
|
||||
sequenceNumber++
|
||||
responseFailed := responseCreated
|
||||
responseFailed.Status = "failed"
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.failed",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Response: responseFailed,
|
||||
})
|
||||
// Send [DONE] even on error
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
result = ccResult
|
||||
lastStreamTokenUsage = ccTokenUsage
|
||||
if len(choices) > 0 {
|
||||
lastStreamLogprobs = choices[0].Logprobs
|
||||
}
|
||||
|
||||
result = backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Extract reasoning from final result
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
|
||||
// streaming state, (3) final extraction from the finetuned result.
|
||||
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" {
|
||||
finalReasoning = chatDeltaReasoning
|
||||
finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas)
|
||||
if finalCleanedResult == "" {
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
} else {
|
||||
finalReasoning = extractor.Reasoning()
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
if currentReasoningID != "" && finalReasoning != "" {
|
||||
@@ -2157,10 +2028,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
textContent = ""
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
parsedToolCalls = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
textContent = functions.ContentFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
|
||||
@@ -2279,8 +2150,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
|
||||
|
||||
// Convert prediction logprobs for streaming events
|
||||
streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
// Convert logprobs for streaming events
|
||||
streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs)
|
||||
|
||||
// If we have no output but the model did produce something, use the cleaned result (without reasoning tags)
|
||||
if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" {
|
||||
@@ -2303,7 +2174,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
|
||||
// Emit content_part.done (with actual logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)
|
||||
textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2320,7 +2191,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
|
||||
}
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
@@ -2389,7 +2260,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ID: currentMessageID,
|
||||
Status: "completed",
|
||||
Role: "assistant",
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)},
|
||||
Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)},
|
||||
})
|
||||
}
|
||||
// Add tool call items
|
||||
@@ -2408,9 +2279,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Emit response.completed
|
||||
now := time.Now().Unix()
|
||||
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: lastStreamTokenUsage.Prompt,
|
||||
OutputTokens: lastStreamTokenUsage.Completion,
|
||||
TotalTokens: lastStreamTokenUsage.Prompt + lastStreamTokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -2469,12 +2340,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
// Stream text deltas with reasoning extraction
|
||||
tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedText += token
|
||||
accumulatedContent += token
|
||||
// Prepend thinking token if needed, then extract reasoning
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig)
|
||||
reasoningDelta, contentDelta := extractor.ProcessToken(token)
|
||||
|
||||
// Handle reasoning item
|
||||
if currentReasoning != "" {
|
||||
if extractor.Reasoning() != "" {
|
||||
// Check if we need to create reasoning item
|
||||
if currentReasoningID == "" {
|
||||
outputIndex++
|
||||
@@ -2506,16 +2375,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
}
|
||||
|
||||
// Calculate reasoning delta
|
||||
var reasoningDelta string
|
||||
if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) {
|
||||
reasoningDelta = currentReasoning[len(lastEmittedReasoning):]
|
||||
lastEmittedReasoning = currentReasoning
|
||||
} else if currentReasoning != lastEmittedReasoning {
|
||||
reasoningDelta = currentReasoning
|
||||
lastEmittedReasoning = currentReasoning
|
||||
}
|
||||
|
||||
// Emit reasoning delta if there's new content
|
||||
if reasoningDelta != "" {
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2532,23 +2391,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
}
|
||||
|
||||
// Handle message content (cleaned content without reasoning tags)
|
||||
var deltaContent string
|
||||
if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) {
|
||||
deltaContent = cleanedContent[len(lastEmittedCleanedContent):]
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else if cleanedContent != lastEmittedCleanedContent {
|
||||
if lastEmittedCleanedContent == "" {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
} else {
|
||||
deltaContent = cleanedContent
|
||||
lastEmittedCleanedContent = cleanedContent
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit message content if there's actual content (not just reasoning)
|
||||
if deltaContent != "" {
|
||||
if contentDelta != "" {
|
||||
// Emit text delta
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
@@ -2556,7 +2400,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
ItemID: currentMessageID,
|
||||
OutputIndex: &outputIndex,
|
||||
ContentIndex: ¤tContentIndex,
|
||||
Delta: strPtr(deltaContent),
|
||||
Delta: strPtr(contentDelta),
|
||||
Logprobs: emptyLogprobs(),
|
||||
})
|
||||
sequenceNumber++
|
||||
@@ -2565,14 +2409,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
return true
|
||||
}
|
||||
|
||||
// Pass logprobs and logit_bias parameters if requested
|
||||
var mcpLogprobs *int
|
||||
if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
|
||||
mcpLogprobs = input.TopLogprobs
|
||||
var noToolResult string
|
||||
noToolCb := func(s string, c *[]schema.Choice) {
|
||||
noToolResult = s
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(
|
||||
input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, mcpLogprobs, input.TopLogprobs, input.LogitBias, nil)
|
||||
noToolChoices, noToolTokenUsage, noToolChatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, noToolCb, tokenCallback)
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream model inference failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2596,36 +2437,28 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
xlog.Error("Open Responses stream prediction failed", "error", err)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "error",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Error: &schema.ORErrorPayload{
|
||||
Type: "model_error",
|
||||
Message: fmt.Sprintf("prediction failed: %v", err),
|
||||
},
|
||||
})
|
||||
sequenceNumber++
|
||||
responseFailed := responseCreated
|
||||
responseFailed.Status = "failed"
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.failed",
|
||||
SequenceNumber: sequenceNumber,
|
||||
Response: responseFailed,
|
||||
})
|
||||
// Send [DONE] even on error
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
result := noToolResult
|
||||
var noToolLogprobs *schema.Logprobs
|
||||
if len(noToolChoices) > 0 {
|
||||
noToolLogprobs = noToolChoices[0].Logprobs
|
||||
}
|
||||
|
||||
result := backend.Finetune(*cfg, predInput, prediction.Response)
|
||||
|
||||
// Extract reasoning from final result for non-tool-call path
|
||||
finalReasoning, finalCleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
// Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's
|
||||
// streaming state, (3) final extraction from the finetuned result.
|
||||
var finalReasoning, finalCleanedResult string
|
||||
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(noToolChatDeltas); chatDeltaReasoning != "" {
|
||||
finalReasoning = chatDeltaReasoning
|
||||
finalCleanedResult = functions.ContentFromChatDeltas(noToolChatDeltas)
|
||||
if finalCleanedResult == "" {
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
} else {
|
||||
finalReasoning = extractor.Reasoning()
|
||||
finalCleanedResult = extractor.CleanedContent()
|
||||
}
|
||||
if finalReasoning == "" && finalCleanedResult == "" {
|
||||
finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig)
|
||||
}
|
||||
|
||||
// Close reasoning item if it exists and wasn't closed yet
|
||||
if currentReasoningID != "" && finalReasoning != "" {
|
||||
@@ -2680,8 +2513,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
result = finalCleanedResult
|
||||
|
||||
// Convert prediction logprobs for streaming events
|
||||
mcpStreamLogprobs := convertLogprobsForStreaming(prediction.Logprobs)
|
||||
// Convert logprobs for streaming events
|
||||
mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs)
|
||||
|
||||
// Emit output_text.done
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
@@ -2696,7 +2529,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
sequenceNumber++
|
||||
|
||||
// Emit content_part.done (with actual logprobs)
|
||||
resultPart := makeOutputTextPartWithLogprobs(result, prediction.Logprobs)
|
||||
resultPart := makeOutputTextPartWithLogprobs(result, noToolLogprobs)
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.content_part.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2709,7 +2542,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
|
||||
// Emit output_item.done (with actual logprobs)
|
||||
messageItem.Status = "completed"
|
||||
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}
|
||||
messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, noToolLogprobs)}
|
||||
sendSSEEvent(c, &schema.ORStreamEvent{
|
||||
Type: "response.output_item.done",
|
||||
SequenceNumber: sequenceNumber,
|
||||
@@ -2744,9 +2577,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
finalOutputItems = append(finalOutputItems, *messageItem)
|
||||
}
|
||||
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{
|
||||
InputTokens: prediction.Usage.Prompt,
|
||||
OutputTokens: prediction.Usage.Completion,
|
||||
TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion,
|
||||
InputTokens: noToolTokenUsage.Prompt,
|
||||
OutputTokens: noToolTokenUsage.Completion,
|
||||
TotalTokens: noToolTokenUsage.Prompt + noToolTokenUsage.Completion,
|
||||
OutputTokensDetails: &schema.OROutputTokensDetails{
|
||||
ReasoningTokens: reasoningTokens,
|
||||
},
|
||||
@@ -3035,19 +2868,6 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T
|
||||
return result
|
||||
}
|
||||
|
||||
// serializeToolsForBackend converts and serializes Open Responses tools to JSON for the backend
|
||||
func serializeToolsForBackend(orTools []schema.ORFunctionTool) string {
|
||||
if len(orTools) == 0 {
|
||||
return ""
|
||||
}
|
||||
openAITools := convertORToolsToOpenAIFormat(orTools)
|
||||
toolsBytes, err := json.Marshal(openAITools)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(toolsBytes)
|
||||
}
|
||||
|
||||
// GetResponseEndpoint returns a handler for GET /responses/:id
|
||||
// This endpoint is used for polling background responses or resuming streaming
|
||||
// @Summary Get a response by ID
|
||||
|
||||
Reference in New Issue
Block a user