From ee96e5e08d044f255995015ecd114fe9cab9be64 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 16 Mar 2026 21:31:02 +0100 Subject: [PATCH] chore: refactor endpoints to use same inferencing path, add automatic retrial mechanism in case of errors (#9029) Signed-off-by: Ettore Di Giacinto --- core/backend/llm.go | 4 + core/http/endpoints/anthropic/messages.go | 100 +--- core/http/endpoints/openai/chat.go | 401 ++++++-------- core/http/endpoints/openai/chat_test.go | 157 ++++++ core/http/endpoints/openai/inference.go | 69 ++- core/http/endpoints/openai/inference_test.go | 402 ++++++++++++++ .../http/endpoints/openresponses/responses.go | 490 ++++++------------ pkg/reasoning/extractor.go | 104 ++++ pkg/reasoning/extractor_test.go | 198 +++++++ 9 files changed, 1263 insertions(+), 662 deletions(-) create mode 100644 core/http/endpoints/openai/chat_test.go create mode 100644 core/http/endpoints/openai/inference_test.go create mode 100644 pkg/reasoning/extractor.go create mode 100644 pkg/reasoning/extractor_test.go diff --git a/core/backend/llm.go b/core/backend/llm.go index db407e5a1..1316f72e4 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -38,6 +38,10 @@ type TokenUsage struct { TimingTokenGeneration float64 } +// ModelInferenceFunc is a test-friendly indirection to call model inference logic. +// Tests can override this variable to provide a stub implementation. +var ModelInferenceFunc = ModelInference + func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64, metadata map[string]string) (func() (LLMResponse, error), error) { modelFile := c.Model diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go index 5119f2df5..a08230510 100644 --- a/core/http/endpoints/anthropic/messages.go +++ b/core/http/endpoints/anthropic/messages.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" + openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" @@ -197,54 +198,24 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic xlog.Debug("Anthropic MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput)) } - images := []string{} - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - } + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertFuncsToOpenAITools(funcs) + openAIReq.ToolsChoice = input.ToolChoice + openAIReq.Metadata = input.Metadata - toolsJSON := "" - if len(funcs) > 0 { - openAITools := make([]functions.Tool, len(funcs)) - for i, f := range funcs { - openAITools[i] = functions.Tool{Type: "function", Function: f} - } - if toolsBytes, err := json.Marshal(openAITools); err == nil { - toolsJSON = string(toolsBytes) - } + var result string + cb := func(s string, c *[]schema.Choice) { + result = s } - toolChoiceJSON := "" - if input.ToolChoice != nil { - if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } - } - - predFunc, err := backend.ModelInference( - input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata) + _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil) if err != nil { xlog.Error("Anthropic model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) } - const maxEmptyRetries = 5 - var prediction backend.LLMResponse - var result string - for attempt := 0; attempt <= maxEmptyRetries; attempt++ { - prediction, err = predFunc() - if err != nil { - xlog.Error("Anthropic prediction failed", "error", err) - return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err)) - } - result = backend.Finetune(*cfg, predInput, prediction.Response) - if result != "" || !shouldUseFn { - break - } - xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries) - } - // Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing var toolCalls []functions.FuncCallResults - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls)) toolCalls = deltaToolCalls } else { @@ -350,8 +321,8 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic StopReason: &stopReason, Content: contentBlocks, Usage: schema.AnthropicUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, + InputTokens: tokenUsage.Prompt, + OutputTokens: tokenUsage.Completion, }, } @@ -397,12 +368,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq xlog.Debug("Anthropic MCP stream re-templating", "iteration", mcpIteration) } - openAIMessages := openAIReq.Messages - images := []string{} - for _, m := range openAIMessages { - images = append(images, m.StringImages...) - } - // Track accumulated content for tool call detection accumulatedContent := "" currentBlockIndex := 0 @@ -481,38 +446,19 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq return true } - toolsJSON := "" - if len(funcs) > 0 { - openAITools := make([]functions.Tool, len(funcs)) - for i, f := range funcs { - openAITools[i] = functions.Tool{Type: "function", Function: f} - } - if toolsBytes, err := json.Marshal(openAITools); err == nil { - toolsJSON = string(toolsBytes) - } - } - toolChoiceJSON := "" - if input.ToolChoice != nil { - if toolChoiceBytes, err := json.Marshal(input.ToolChoice); err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } - } + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertFuncsToOpenAITools(funcs) + openAIReq.ToolsChoice = input.ToolChoice + openAIReq.Metadata = input.Metadata - predFunc, err := backend.ModelInference( - input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, nil, nil, nil, input.Metadata) + _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback) if err != nil { xlog.Error("Anthropic stream model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) } - prediction, err := predFunc() - if err != nil { - xlog.Error("Anthropic stream prediction failed", "error", err) - return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err)) - } - // Also check chat deltas for tool calls - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 { collectedToolCalls = deltaToolCalls } @@ -595,7 +541,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq StopReason: &stopReason, }, Usage: &schema.AnthropicUsage{ - OutputTokens: prediction.Usage.Completion, + OutputTokens: tokenUsage.Completion, }, }) @@ -613,6 +559,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq return nil } +func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool { + tools := make([]functions.Tool, len(funcs)) + for i, f := range funcs { + tools[i] = functions.Tool{Type: "function", Function: f} + } + return tools +} + func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) { data, err := json.Marshal(event) if err != nil { diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 04bd28b36..db7e2613b 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -82,51 +82,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator template = s } thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) - - // Track accumulated content for reasoning extraction - accumulatedContent := "" - lastEmittedReasoning := "" - lastEmittedCleanedContent := "" + extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig) _, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool { - accumulatedContent += s - - currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig) - - // Calculate new reasoning delta (what we haven't emitted yet) - var reasoningDelta *string - if currentReasoning != lastEmittedReasoning { - // Extract only the new part - if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) { - newReasoning := currentReasoning[len(lastEmittedReasoning):] - reasoningDelta = &newReasoning - lastEmittedReasoning = currentReasoning - } else if currentReasoning != "" { - // If reasoning changed in a non-append way, emit the full current reasoning - reasoningDelta = ¤tReasoning - lastEmittedReasoning = currentReasoning - } - } - - // Calculate content delta from cleaned content - var deltaContent string - if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) { - deltaContent = cleanedContent[len(lastEmittedCleanedContent):] - lastEmittedCleanedContent = cleanedContent - } else if cleanedContent != lastEmittedCleanedContent { - // If cleaned content changed but not in a simple append, extract delta from cleaned content - // This handles cases where thinking tags are removed mid-stream - if lastEmittedCleanedContent == "" { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } else { - // Content changed in non-append way, use the new cleaned content - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } - } - // Only emit content if there's actual content (not just thinking tags) - // If deltaContent is empty, we still emit the response but with empty content + reasoningDelta, contentDelta := extractor.ProcessToken(s) usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, @@ -139,12 +98,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } delta := &schema.Message{} - // Only include content if there's actual content (not just thinking tags) - if deltaContent != "" { - delta.Content = &deltaContent + if contentDelta != "" { + delta.Content = &contentDelta } - if reasoningDelta != nil && *reasoningDelta != "" { - delta.Reasoning = reasoningDelta + if reasoningDelta != "" { + delta.Reasoning = &reasoningDelta } resp := schema.OpenAIResponse{ @@ -171,43 +129,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator template = prompt } thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) + extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig) result := "" lastEmittedCount := 0 - - // Track accumulated content for incremental reasoning and content extraction (mirrors process()) - accumulatedContent := "" - lastEmittedReasoning := "" - lastEmittedCleanedContent := "" sentInitialRole := false _, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { result += s - accumulatedContent += s + reasoningDelta, contentDelta := extractor.ProcessToken(s) - // Incremental reasoning extraction — emit reasoning deltas in their own SSE chunks - // before any tool-call chunks (OpenAI spec: reasoning and tool_calls never share a delta) - currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig) - - var reasoningDelta *string - if currentReasoning != lastEmittedReasoning { - if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) { - newReasoning := currentReasoning[len(lastEmittedReasoning):] - reasoningDelta = &newReasoning - lastEmittedReasoning = currentReasoning - } else if currentReasoning != "" { - reasoningDelta = ¤tReasoning - lastEmittedReasoning = currentReasoning - } - } - - if reasoningDelta != nil && *reasoningDelta != "" { + // Emit reasoning deltas in their own SSE chunks before any tool-call chunks + // (OpenAI spec: reasoning and tool_calls never share a delta) + if reasoningDelta != "" { responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{ - Delta: &schema.Message{Reasoning: reasoningDelta}, + Delta: &schema.Message{Reasoning: &reasoningDelta}, Index: 0, }}, Object: "chat.completion.chunk", @@ -217,32 +157,22 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // Stream content deltas (cleaned of reasoning tags) while no tool calls // have been detected. Once the incremental parser finds tool calls, // content stops — per OpenAI spec, content and tool_calls don't mix. - if lastEmittedCount == 0 && cleanedContent != "" { - var deltaContent string - if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) { - deltaContent = cleanedContent[len(lastEmittedCleanedContent):] - lastEmittedCleanedContent = cleanedContent - } else if cleanedContent != lastEmittedCleanedContent { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } - if deltaContent != "" { - if !sentInitialRole { - responses <- schema.OpenAIResponse{ - ID: id, Created: created, Model: req.Model, - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}}, - Object: "chat.completion.chunk", - } - sentInitialRole = true - } + if lastEmittedCount == 0 && contentDelta != "" { + if !sentInitialRole { responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, - Choices: []schema.Choice{{ - Delta: &schema.Message{Content: &deltaContent}, - Index: 0, - }}, - Object: "chat.completion.chunk", + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}}, + Object: "chat.completion.chunk", } + sentInitialRole = true + } + responses <- schema.OpenAIResponse{ + ID: id, Created: created, Model: req.Model, + Choices: []schema.Choice{{ + Delta: &schema.Message{Content: &contentDelta}, + Index: 0, + }}, + Object: "chat.completion.chunk", } } @@ -349,7 +279,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } } return true - }) + }, + func(attempt int) bool { + // After streaming completes: check if we got actionable content + cleaned := extractor.CleanedContent() + // Check for tool calls from chat deltas (will be re-checked after ComputeChoices, + // but we need to know here whether to retry) + hasToolCalls := lastEmittedCount > 0 + if cleaned == "" && !hasToolCalls { + xlog.Warn("Streaming: backend produced only reasoning, retrying", + "reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1) + extractor.ResetAndSuppressReasoning() + result = "" + lastEmittedCount = 0 + sentInitialRole = false + return true + } + return false + }, + ) if err != nil { return err } @@ -366,10 +314,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } else { // Fallback: parse tool calls from raw text (no chat deltas from backend) xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing") - reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig) - textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig) - result = functions.CleanupLLMResult(result, config.FunctionsConfig) - functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig) + reasoning = extractor.Reasoning() + cleanedResult := extractor.CleanedContent() + textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig) + cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig) + functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig) } xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn) noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0 @@ -389,7 +338,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator if sentInitialRole { // Content was already streamed during the callback — just emit usage. delta := &schema.Message{} - if reasoning != "" && lastEmittedReasoning == "" { + if reasoning != "" && extractor.Reasoning() == "" { delta.Reasoning = &reasoning } responses <- schema.OpenAIResponse{ @@ -406,7 +355,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator Object: "chat.completion.chunk", } - result, err := handleQuestion(config, functionResults, result, prompt) + result, err := handleQuestion(config, functionResults, extractor.CleanedContent(), prompt) if err != nil { xlog.Error("error handling question", "error", err) return err @@ -981,7 +930,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // is deferred to after ComputeChoices so we can check chat deltas first // and avoid redundant Go-side parsing. var cbRawResult, cbReasoning string - var emptyRetryNeeded bool tokenCallback := func(s string, c *[]schema.Choice) { reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig) @@ -1001,146 +949,133 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator cbReasoning = reasoning } - const maxEmptyRetries = 5 var result []schema.Choice var tokenUsage backend.TokenUsage var err error var chatDeltas []*pb.ChatDelta - for attempt := 0; attempt <= maxEmptyRetries; attempt++ { - emptyRetryNeeded = false - result, tokenUsage, chatDeltas, err = ComputeChoices( - input, - predInput, - config, - cl, - startupOptions, - ml, - tokenCallback, - nil, - ) - if err != nil { - break - } - - // Tool parsing is deferred here (only when shouldUseFn) - if shouldUseFn { - var funcResults []functions.FuncCallResults - - // Try pre-parsed tool calls from C++ autoparser first - if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { - xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls)) - funcResults = deltaToolCalls - textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) - cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas) - } else { - // Fallback: parse tool calls from raw text - xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing") - textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig) - cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig) - funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + result, tokenUsage, chatDeltas, err = ComputeChoices( + input, + predInput, + config, + cl, + startupOptions, + ml, + tokenCallback, + nil, + func(attempt int) bool { + if !shouldUseFn { + return false } - - noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 - - switch { - case noActionsToRun: - if cbRawResult == "" && textContentToReturn == "" { - xlog.Warn("Backend returned empty content in tool-calling context, will retry") - emptyRetryNeeded = true - continue - } - qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput) - if qErr != nil { - xlog.Error("error handling question", "error", qErr) - emptyRetryNeeded = true - continue - } - - stopReason := FinishReasonStop - message := &schema.Message{Role: "assistant", Content: &qResult} - if cbReasoning != "" { - message.Reasoning = &cbReasoning - } - result = append(result, schema.Choice{ - FinishReason: &stopReason, - Message: message, - }) - default: - toolCallsReason := FinishReasonToolCalls - toolChoice := schema.Choice{ - FinishReason: &toolCallsReason, - Message: &schema.Message{ - Role: "assistant", - }, - } - if cbReasoning != "" { - toolChoice.Message.Reasoning = &cbReasoning - } - - for _, ss := range funcResults { - name, args := ss.Name, ss.Arguments - toolCallID := ss.ID - if toolCallID == "" { - toolCallID = id - } - if len(input.Tools) > 0 { - toolChoice.Message.Content = textContentToReturn - toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, - schema.ToolCall{ - ID: toolCallID, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, - }, - ) - } else { - // Deprecated function_call format - functionCallReason := FinishReasonFunctionCall - message := &schema.Message{ - Role: "assistant", - Content: &textContentToReturn, - FunctionCall: map[string]interface{}{ - "name": name, - "arguments": args, - }, - } - if cbReasoning != "" { - message.Reasoning = &cbReasoning - } - result = append(result, schema.Choice{ - FinishReason: &functionCallReason, - Message: message, - }) - } - } - - if len(input.Tools) > 0 { - result = append(result, toolChoice) - } + // Retry when backend produced only reasoning and no content/tool calls. + // Full tool parsing is deferred until after ComputeChoices returns + // (when chat deltas are available), but we can detect the empty case here. + if cbRawResult == "" && textContentToReturn == "" { + xlog.Warn("Backend produced reasoning without actionable content, retrying", + "reasoning_len", len(cbReasoning), "attempt", attempt+1) + cbRawResult = "" + cbReasoning = "" + textContentToReturn = "" + return true } - } - - if !emptyRetryNeeded { - break - } - xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries) - } + return false + }, + ) if err != nil { return err } - if emptyRetryNeeded { - xlog.Warn("All retries exhausted, backend still returning empty content") - stopReason := FinishReasonStop - empty := "" - result = append(result, schema.Choice{ - FinishReason: &stopReason, - Index: 0, - Message: &schema.Message{Role: "assistant", Content: &empty}, - }) + // Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available + if shouldUseFn { + var funcResults []functions.FuncCallResults + + // Try pre-parsed tool calls from C++ autoparser first + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { + xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls)) + funcResults = deltaToolCalls + textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) + cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas) + } else { + // Fallback: parse tool calls from raw text + xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing") + textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig) + cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig) + funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) + } + + noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 + + switch { + case noActionsToRun: + qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput) + if qErr != nil { + xlog.Error("error handling question", "error", qErr) + } + + stopReason := FinishReasonStop + message := &schema.Message{Role: "assistant", Content: &qResult} + if cbReasoning != "" { + message.Reasoning = &cbReasoning + } + result = append(result, schema.Choice{ + FinishReason: &stopReason, + Message: message, + }) + default: + toolCallsReason := FinishReasonToolCalls + toolChoice := schema.Choice{ + FinishReason: &toolCallsReason, + Message: &schema.Message{ + Role: "assistant", + }, + } + if cbReasoning != "" { + toolChoice.Message.Reasoning = &cbReasoning + } + + for _, ss := range funcResults { + name, args := ss.Name, ss.Arguments + toolCallID := ss.ID + if toolCallID == "" { + toolCallID = id + } + if len(input.Tools) > 0 { + toolChoice.Message.Content = textContentToReturn + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: toolCallID, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // Deprecated function_call format + functionCallReason := FinishReasonFunctionCall + message := &schema.Message{ + Role: "assistant", + Content: &textContentToReturn, + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, + } + if cbReasoning != "" { + message.Reasoning = &cbReasoning + } + result = append(result, schema.Choice{ + FinishReason: &functionCallReason, + Message: message, + }) + } + } + + if len(input.Tools) > 0 { + result = append(result, toolChoice) + } + } } // MCP server-side tool execution loop: @@ -1277,5 +1212,5 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall xlog.Debug("No action received from LLM, without a message, computing a reply") - return "", fmt.Errorf("no action received from LLM, without a message, computing a reply") + return "", nil } diff --git a/core/http/endpoints/openai/chat_test.go b/core/http/endpoints/openai/chat_test.go new file mode 100644 index 000000000..51984bfb0 --- /dev/null +++ b/core/http/endpoints/openai/chat_test.go @@ -0,0 +1,157 @@ +package openai + +import ( + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/pkg/functions" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/schema" +) + +var _ = Describe("handleQuestion", func() { + var cfg *config.ModelConfig + + BeforeEach(func() { + cfg = &config.ModelConfig{} + }) + + Context("with no function results but non-empty result", func() { + It("should return the result directly", func() { + result, err := handleQuestion(cfg, nil, "Hello world", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("Hello world")) + }) + }) + + Context("with no function results and empty result", func() { + It("should return empty string", func() { + result, err := handleQuestion(cfg, nil, "", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(BeEmpty()) + }) + }) + + Context("with function result containing a message argument", func() { + It("should extract the message from function arguments", func() { + funcResults := []functions.FuncCallResults{ + { + Name: "answer", + Arguments: `{"message": "This is the answer"}`, + }, + } + result, err := handleQuestion(cfg, funcResults, "", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("This is the answer")) + }) + }) + + Context("with function result containing empty message", func() { + It("should return empty string when message is empty", func() { + funcResults := []functions.FuncCallResults{ + { + Name: "answer", + Arguments: `{"message": ""}`, + }, + } + result, err := handleQuestion(cfg, funcResults, "", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(BeEmpty()) + }) + }) + + Context("with function result containing invalid JSON arguments", func() { + It("should return empty string gracefully", func() { + funcResults := []functions.FuncCallResults{ + { + Name: "answer", + Arguments: "not json", + }, + } + result, err := handleQuestion(cfg, funcResults, "", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(BeEmpty()) + }) + }) + + Context("with cleaned content (no think tags)", func() { + It("should return content without think tags", func() { + // This tests the bug fix: handleQuestion should receive cleaned content, + // not raw text with tags + result, err := handleQuestion(cfg, nil, "Just the answer", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("Just the answer")) + Expect(result).ToNot(ContainSubstring("")) + }) + }) + + Context("with raw think tags passed as result", func() { + It("would return content with think tags", func() { + result, err := handleQuestion(cfg, nil, "reasoninganswer", "prompt") + Expect(err).ToNot(HaveOccurred()) + Expect(result).To(Equal("reasoninganswer")) + }) + }) +}) + +var _ = Describe("mergeToolCallDeltas", func() { + Context("with new tool calls", func() { + It("should append new tool calls", func() { + existing := []schema.ToolCall{} + deltas := []schema.ToolCall{ + {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}}, + } + result := mergeToolCallDeltas(existing, deltas) + Expect(result).To(HaveLen(1)) + Expect(result[0].ID).To(Equal("tc1")) + Expect(result[0].FunctionCall.Name).To(Equal("search")) + }) + }) + + Context("with argument appending", func() { + It("should append arguments to existing tool call", func() { + existing := []schema.ToolCall{ + {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search", Arguments: `{"q":`}}, + } + deltas := []schema.ToolCall{ + {Index: 0, FunctionCall: schema.FunctionCall{Arguments: `"hello"}`}}, + } + result := mergeToolCallDeltas(existing, deltas) + Expect(result).To(HaveLen(1)) + Expect(result[0].FunctionCall.Arguments).To(Equal(`{"q":"hello"}`)) + }) + }) + + Context("with multiple tool calls", func() { + It("should track multiple tool calls by index", func() { + existing := []schema.ToolCall{} + deltas1 := []schema.ToolCall{ + {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}}, + } + result := mergeToolCallDeltas(existing, deltas1) + + deltas2 := []schema.ToolCall{ + {Index: 1, ID: "tc2", Type: "function", FunctionCall: schema.FunctionCall{Name: "browse"}}, + } + result = mergeToolCallDeltas(result, deltas2) + Expect(result).To(HaveLen(2)) + Expect(result[0].FunctionCall.Name).To(Equal("search")) + Expect(result[1].FunctionCall.Name).To(Equal("browse")) + }) + }) + + Context("with ID update on existing tool call", func() { + It("should update ID when provided in delta", func() { + existing := []schema.ToolCall{ + {Index: 0, FunctionCall: schema.FunctionCall{Name: "search"}}, + } + deltas := []schema.ToolCall{ + {Index: 0, ID: "new-id"}, + } + result := mergeToolCallDeltas(existing, deltas) + Expect(result).To(HaveLen(1)) + Expect(result[0].ID).To(Equal("new-id")) + Expect(result[0].FunctionCall.Name).To(Equal("search")) + }) + }) +}) diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index 46fd41445..9189870a0 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "strings" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" @@ -9,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/xlog" ) func ComputeChoices( @@ -19,7 +21,9 @@ func ComputeChoices( o *config.ApplicationConfig, loader *model.ModelLoader, cb func(string, *[]schema.Choice), - tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) { + tokenCallback func(string, backend.TokenUsage) bool, + shouldRetry ...func(int) bool, +) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) { n := req.N // number of completions to return result := []schema.Choice{} @@ -27,6 +31,12 @@ func ComputeChoices( n = 1 } + // Extract the optional shouldRetry callback + var shouldRetryFn func(int) bool + if len(shouldRetry) > 0 { + shouldRetryFn = shouldRetry[0] + } + images := []string{} for _, m := range req.Messages { images = append(images, m.StringImages...) @@ -82,7 +92,7 @@ func ComputeChoices( } // get the model function to call for the result - predFunc, err := backend.ModelInference( + predFunc, err := backend.ModelInferenceFunc( req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata) if err != nil { return result, backend.TokenUsage{}, nil, err @@ -91,32 +101,49 @@ func ComputeChoices( tokenUsage := backend.TokenUsage{} var allChatDeltas []*pb.ChatDelta + const maxRetries = 5 + for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, backend.TokenUsage{}, nil, err + var prediction backend.LLMResponse + + for attempt := 0; attempt <= maxRetries; attempt++ { + p, err := predFunc() + if err != nil { + return result, backend.TokenUsage{}, nil, err + } + prediction = p + + // Built-in: retry on truly empty response (no tokens at all) + if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries { + xlog.Warn("Backend returned empty response, retrying", + "attempt", attempt+1, "maxRetries", maxRetries) + continue + } + + tokenUsage.Prompt = prediction.Usage.Prompt + tokenUsage.Completion = prediction.Usage.Completion + tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing + tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration + + allChatDeltas = prediction.ChatDeltas + + finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) + + // Caller-driven retry (tool parsing, reasoning-only, etc.) + if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries { + // Caller has already reset its state inside shouldRetry + result = result[:0] + allChatDeltas = nil + continue + } + break } - tokenUsage.Prompt += prediction.Usage.Prompt - tokenUsage.Completion += prediction.Usage.Completion - tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing - tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration - - // Collect chat deltas from C++ autoparser - if len(prediction.ChatDeltas) > 0 { - allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...) - } - - finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) - cb(finetunedResponse, &result) - // Add logprobs to the last choice if present if prediction.Logprobs != nil && len(result) > 0 { result[len(result)-1].Logprobs = prediction.Logprobs } - - //result = append(result, Choice{Text: prediction}) - } return result, tokenUsage, allChatDeltas, err } diff --git a/core/http/endpoints/openai/inference_test.go b/core/http/endpoints/openai/inference_test.go new file mode 100644 index 000000000..7b5ab39dc --- /dev/null +++ b/core/http/endpoints/openai/inference_test.go @@ -0,0 +1,402 @@ +package openai + +import ( + "context" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + model "github.com/mudler/LocalAI/pkg/model" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +type modelInferenceFunc = func( + ctx context.Context, s string, messages schema.Messages, + images, videos, audios []string, + loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, + o *config.ApplicationConfig, + tokenCallback func(string, backend.TokenUsage) bool, + tools, toolChoice string, + logprobs, topLogprobs *int, + logitBias map[string]float64, + metadata map[string]string, +) (func() (backend.LLMResponse, error), error) + +var _ = Describe("ComputeChoices", func() { + var ( + origInference modelInferenceFunc + cfg *config.ModelConfig + appCfg *config.ApplicationConfig + ) + + // mockInference installs a stub that yields the given responses sequentially. + // After all responses are consumed, the last one is repeated. + mockInference := func(responses []backend.LLMResponse) { + idx := 0 + backend.ModelInferenceFunc = func( + ctx context.Context, s string, messages schema.Messages, + images, videos, audios []string, + loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, + o *config.ApplicationConfig, + tokenCallback func(string, backend.TokenUsage) bool, + tools, toolChoice string, + logprobs, topLogprobs *int, + logitBias map[string]float64, + metadata map[string]string, + ) (func() (backend.LLMResponse, error), error) { + predFunc := func() (backend.LLMResponse, error) { + resp := responses[idx] + if idx < len(responses)-1 { + idx++ + } + return resp, nil + } + return predFunc, nil + } + } + + BeforeEach(func() { + origInference = backend.ModelInferenceFunc + cfg = &config.ModelConfig{} + appCfg = config.NewApplicationConfig() + }) + + AfterEach(func() { + backend.ModelInferenceFunc = origInference + }) + + makeReq := func() *schema.OpenAIRequest { + ctx, cancel := context.WithCancel(context.Background()) + _ = cancel + return &schema.OpenAIRequest{ + Context: ctx, + Cancel: cancel, + } + } + + Context("normal response (no retry needed)", func() { + It("should return choices on first attempt", func() { + mockInference([]backend.LLMResponse{ + {Response: "Hello world", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}}, + }) + + var captured string + choices, usage, _, err := ComputeChoices( + makeReq(), "test prompt", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + captured = s + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(captured).To(Equal("Hello world")) + Expect(usage.Prompt).To(Equal(10)) + Expect(usage.Completion).To(Equal(5)) + }) + }) + + Context("empty response triggers built-in retry", func() { + It("should retry and eventually return non-empty response", func() { + mockInference([]backend.LLMResponse{ + {Response: ""}, // attempt 0: empty + {Response: " "}, // attempt 1: whitespace-only + {Response: "Got it", Usage: backend.TokenUsage{Prompt: 8, Completion: 3}}, // attempt 2: success + }) + + choices, usage, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(Equal("Got it")) + Expect(usage.Prompt).To(Equal(8)) + Expect(usage.Completion).To(Equal(3)) + }) + }) + + Context("all retries exhausted on empty response", func() { + It("should return the empty response after max retries", func() { + mockInference([]backend.LLMResponse{ + {Response: ""}, // always empty + }) + + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + // After maxRetries, it proceeds with the empty response + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(BeEmpty()) + }) + }) + + Context("shouldRetry callback", func() { + It("should call shouldRetry and retry when it returns true", func() { + callCount := 0 + mockInference([]backend.LLMResponse{ + {Response: "reasoning-only", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}}, + {Response: "actual-answer", Usage: backend.TokenUsage{Prompt: 5, Completion: 4}}, + }) + + retryAttempts := []int{} + choices, usage, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + callCount++ + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { + retryAttempts = append(retryAttempts, attempt) + // Retry on first attempt only + return attempt == 0 + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(Equal("actual-answer")) + // shouldRetry was called twice: once returning true (retry), once returning false (proceed) + Expect(retryAttempts).To(Equal([]int{0, 1})) + // cb was called twice (once per attempt) + Expect(callCount).To(Equal(2)) + // Token usage should be from the LATEST attempt + Expect(usage.Prompt).To(Equal(5)) + Expect(usage.Completion).To(Equal(4)) + }) + + It("should not retry when shouldRetry returns false", func() { + mockInference([]backend.LLMResponse{ + {Response: "first-response"}, + }) + + shouldRetryCalled := false + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { + shouldRetryCalled = true + return false + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(Equal("first-response")) + Expect(shouldRetryCalled).To(BeTrue()) + }) + }) + + Context("shouldRetry not provided (variadic omitted)", func() { + It("should work without shouldRetry parameter", func() { + mockInference([]backend.LLMResponse{ + {Response: "works"}, + }) + + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(Equal("works")) + }) + }) + + Context("token usage from latest attempt", func() { + It("should use token usage from the last attempt, not accumulated", func() { + mockInference([]backend.LLMResponse{ + {Response: "retry-me", Usage: backend.TokenUsage{Prompt: 100, Completion: 50}}, + {Response: "final", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}}, + }) + + _, usage, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { return attempt == 0 }, + ) + Expect(err).ToNot(HaveOccurred()) + // Should be the LATEST attempt's usage, not accumulated + Expect(usage.Prompt).To(Equal(10)) + Expect(usage.Completion).To(Equal(5)) + }) + }) + + Context("chat deltas from latest attempt", func() { + It("should return chat deltas from the last attempt only", func() { + mockInference([]backend.LLMResponse{ + { + Response: "retry-me", + ChatDeltas: []*pb.ChatDelta{{Content: "old"}}, + }, + { + Response: "final", + ChatDeltas: []*pb.ChatDelta{{Content: "new"}}, + }, + }) + + _, _, deltas, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { return attempt == 0 }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(deltas).To(HaveLen(1)) + Expect(deltas[0].Content).To(Equal("new")) + }) + }) + + Context("result choices cleared on retry", func() { + It("should only contain choices from the final attempt", func() { + mockInference([]backend.LLMResponse{ + {Response: "bad-choice"}, + {Response: "good-choice"}, + }) + + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { return attempt == 0 }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(choices[0].Text).To(Equal("good-choice")) + }) + }) + + Context("shouldRetry with max retries cap", func() { + It("should stop retrying after maxRetries even if shouldRetry returns true", func() { + attempts := 0 + mockInference([]backend.LLMResponse{ + {Response: "always-retry"}, + }) + + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + func(attempt int) bool { + attempts++ + return true // always want to retry + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + // maxRetries is 5, so shouldRetry is called for attempts 0..4, + // but attempt 5 is the final one where shouldRetry can't trigger continue + Expect(attempts).To(BeNumerically("<=", 6)) + }) + }) + + Context("N > 1 completions", func() { + It("should produce N separate completions", func() { + callIdx := 0 + responses := []string{"first", "second", "third"} + backend.ModelInferenceFunc = func( + ctx context.Context, s string, messages schema.Messages, + images, videos, audios []string, + loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, + o *config.ApplicationConfig, + tokenCallback func(string, backend.TokenUsage) bool, + tools, toolChoice string, + logprobs, topLogprobs *int, + logitBias map[string]float64, + metadata map[string]string, + ) (func() (backend.LLMResponse, error), error) { + predFunc := func() (backend.LLMResponse, error) { + resp := backend.LLMResponse{Response: responses[callIdx]} + if callIdx < len(responses)-1 { + callIdx++ + } + return resp, nil + } + return predFunc, nil + } + + req := makeReq() + req.N = 3 + choices, _, _, err := ComputeChoices( + req, "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + nil, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(3)) + Expect(choices[0].Text).To(Equal("first")) + Expect(choices[1].Text).To(Equal("second")) + Expect(choices[2].Text).To(Equal("third")) + }) + }) + + Context("with streaming token callback", func() { + It("should call tokenCallback for streaming responses", func() { + var streamedTokens []string + backend.ModelInferenceFunc = func( + ctx context.Context, s string, messages schema.Messages, + images, videos, audios []string, + loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, + o *config.ApplicationConfig, + tokenCallback func(string, backend.TokenUsage) bool, + tools, toolChoice string, + logprobs, topLogprobs *int, + logitBias map[string]float64, + metadata map[string]string, + ) (func() (backend.LLMResponse, error), error) { + predFunc := func() (backend.LLMResponse, error) { + if tokenCallback != nil { + tokenCallback("Hello", backend.TokenUsage{Prompt: 5}) + tokenCallback(" world", backend.TokenUsage{Prompt: 5, Completion: 2}) + } + return backend.LLMResponse{ + Response: "Hello world", + Usage: backend.TokenUsage{Prompt: 5, Completion: 2}, + }, nil + } + return predFunc, nil + } + + choices, _, _, err := ComputeChoices( + makeReq(), "test", cfg, nil, appCfg, nil, + func(s string, c *[]schema.Choice) { + *c = append(*c, schema.Choice{Text: s}) + }, + func(s string, usage backend.TokenUsage) bool { + streamedTokens = append(streamedTokens, s) + return true + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(choices).To(HaveLen(1)) + Expect(streamedTokens).To(Equal([]string{"Hello", " world"})) + }) + }) +}) diff --git a/core/http/endpoints/openresponses/responses.go b/core/http/endpoints/openresponses/responses.go index 9b0ae2a23..09f3b82e8 100644 --- a/core/http/endpoints/openresponses/responses.go +++ b/core/http/endpoints/openresponses/responses.go @@ -12,6 +12,7 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" + openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" @@ -879,34 +880,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon xlog.Debug("Background MCP re-templating", "iteration", mcpIteration) } - images := []string{} - videos := []string{} - audios := []string{} - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } - - toolsJSON := serializeToolsForBackend(input.Tools) - toolChoiceJSON := "" - if input.ToolChoice != nil { - toolChoiceBytes, err := json.Marshal(input.ToolChoice) - if err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } - } - - var logprobs *int + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) + openAIReq.ToolsChoice = input.ToolChoice if input.TopLogprobs != nil && *input.TopLogprobs > 0 { - logprobs = input.TopLogprobs - } - - predFunc, err := backend.ModelInference( - ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil) - if err != nil { - return nil, fmt.Errorf("model inference failed: %w", err) + openAIReq.TopLogprobs = input.TopLogprobs + openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } + openAIReq.LogitBias = input.LogitBias select { case <-ctx.Done(): @@ -914,24 +895,19 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon default: } - const maxEmptyRetries = 5 - var prediction backend.LLMResponse var result string - for attempt := 0; attempt <= maxEmptyRetries; attempt++ { - prediction, err = predFunc() - if err != nil { - return nil, fmt.Errorf("prediction failed: %w", err) - } - result = backend.Finetune(*cfg, predInput, prediction.Response) - if result != "" || !shouldUseFn { - break - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - xlog.Warn("Open Responses background: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries) + cb := func(s string, c *[]schema.Choice) { + result = s + } + choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil) + if err != nil { + return nil, fmt.Errorf("model inference failed: %w", err) + } + + // Extract logprobs from choices if available + var resultLogprobs *schema.Logprobs + if len(choices) > 0 { + resultLogprobs = choices[0].Logprobs } // Parse tool calls @@ -939,9 +915,9 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon var textContent string if shouldUseFn { - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { funcCallResults = deltaToolCalls - textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas) + textContent = functions.ContentFromChatDeltas(chatDeltas) } else { cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) @@ -1021,7 +997,7 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)}, }) } for _, tc := range toolCalls { @@ -1034,22 +1010,22 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)}, }) } } else { allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)}, }) } now := time.Now().Unix() return buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, allOutputItems, &schema.ORUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, - TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + InputTokens: tokenUsage.Prompt, + OutputTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, }, true), nil } // end MCP iteration loop @@ -1058,23 +1034,14 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon // handleBackgroundStream handles background streaming responses with event buffering func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { - images := []string{} - videos := []string{} - audios := []string{} - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } - - toolsJSON := serializeToolsForBackend(input.Tools) - toolChoiceJSON := "" - if input.ToolChoice != nil { - toolChoiceBytes, err := json.Marshal(input.ToolChoice) - if err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) + openAIReq.ToolsChoice = input.ToolChoice + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + openAIReq.TopLogprobs = input.TopLogprobs + openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } + openAIReq.LogitBias = input.LogitBias sequenceNumber := 0 @@ -1105,20 +1072,13 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI } hasMCPTools := len(mcpToolInfos) > 0 - var prediction backend.LLMResponse + var lastTokenUsage backend.TokenUsage + var lastLogprobs *schema.Logprobs for mcpIter := 0; mcpIter <= mcpBgStreamMaxIterations; mcpIter++ { if mcpIter > 0 { predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Background stream MCP re-templating", "iteration", mcpIter) - images = images[:0] - videos = videos[:0] - audios = audios[:0] - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } } accumulatedText = "" @@ -1177,28 +1137,23 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI return true } - var streamLogprobs *int - if input.TopLogprobs != nil && *input.TopLogprobs > 0 { - streamLogprobs = input.TopLogprobs + var result string + cb := func(s string, c *[]schema.Choice) { + result = s } - - predFunc, err := backend.ModelInference( - ctx, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil) + choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, tokenCallback) if err != nil { return nil, fmt.Errorf("model inference failed: %w", err) } - - prediction, err = predFunc() - if err != nil { - return nil, fmt.Errorf("prediction failed: %w", err) + lastTokenUsage = tokenUsage + if len(choices) > 0 { + lastLogprobs = choices[0].Logprobs } - result := backend.Finetune(*cfg, predInput, prediction.Response) - // Check for MCP tool calls in the streamed result if shouldUseFn && hasMCPTools { var funcCallResults []functions.FuncCallResults - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { funcCallResults = deltaToolCalls } else { cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) @@ -1315,7 +1270,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI } // No MCP tools — close the message and break - streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + streamEventLogprobs := convertLogprobsForStreaming(lastLogprobs) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, @@ -1327,7 +1282,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI }) sequenceNumber++ - textPart := makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs) + textPart := makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, @@ -1343,7 +1298,7 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI ID: currentMessageID, Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)}, } bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.done", @@ -1360,9 +1315,9 @@ func handleBackgroundStream(ctx context.Context, store *ResponseStore, responseI // Build final response now := time.Now().Unix() response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, - TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + InputTokens: lastTokenUsage.Prompt, + OutputTokens: lastTokenUsage.Completion, + TotalTokens: lastTokenUsage.Prompt + lastTokenUsage.Completion, }, true) // Emit response.completed @@ -1391,52 +1346,27 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i if mcpIteration > mcpMaxIterations { return sendOpenResponsesError(c, 500, "server_error", "MCP iteration limit reached", "") } - images := []string{} - videos := []string{} - audios := []string{} - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } - - // Convert and serialize tools to OpenAI format for the backend - toolsJSON := serializeToolsForBackend(input.Tools) - toolChoiceJSON := "" - if input.ToolChoice != nil { - toolChoiceBytes, err := json.Marshal(input.ToolChoice) - if err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } - } - - // Pass logprobs and logit_bias parameters if requested - var logprobs *int + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) + openAIReq.ToolsChoice = input.ToolChoice if input.TopLogprobs != nil && *input.TopLogprobs > 0 { - logprobs = input.TopLogprobs + openAIReq.TopLogprobs = input.TopLogprobs + openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } + openAIReq.LogitBias = input.LogitBias - predFunc, err := backend.ModelInference( - input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, nil, toolsJSON, toolChoiceJSON, logprobs, input.TopLogprobs, input.LogitBias, nil) + var result string + cb := func(s string, c *[]schema.Choice) { + result = s + } + choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil) if err != nil { xlog.Error("Open Responses model inference failed", "error", err) return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "") } - - const maxEmptyRetries = 5 - var prediction backend.LLMResponse - var result string - for attempt := 0; attempt <= maxEmptyRetries; attempt++ { - prediction, err = predFunc() - if err != nil { - xlog.Error("Open Responses prediction failed", "error", err) - return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("prediction failed: %v", err), "") - } - result = backend.Finetune(*cfg, predInput, prediction.Response) - if result != "" || !shouldUseFn { - break - } - xlog.Warn("Open Responses: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries) + var resultLogprobs *schema.Logprobs + if len(choices) > 0 { + resultLogprobs = choices[0].Logprobs } xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn) @@ -1473,10 +1403,10 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i var textContent string // Try pre-parsed tool calls from C++ autoparser first - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls)) funcCallResults = deltaToolCalls - textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas) + textContent = functions.ContentFromChatDeltas(chatDeltas) } else { xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing") // Clean up the result (already extracted reasoning above) @@ -1574,7 +1504,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)}, }) } @@ -1605,7 +1535,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, }) } } else { @@ -1615,7 +1545,7 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, } outputItems = append(outputItems, messageItem) } @@ -1633,9 +1563,9 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i // Build response with all required fields now := time.Now().Unix() response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, - TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + InputTokens: tokenUsage.Prompt, + OutputTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, @@ -1675,24 +1605,14 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 }) sequenceNumber++ - images := []string{} - videos := []string{} - audios := []string{} - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } - - // Convert and serialize tools to OpenAI format for the backend - toolsJSON := serializeToolsForBackend(input.Tools) - toolChoiceJSON := "" - if input.ToolChoice != nil { - toolChoiceBytes, err := json.Marshal(input.ToolChoice) - if err == nil { - toolChoiceJSON = string(toolChoiceBytes) - } + // Populate openAIReq fields for ComputeChoices + openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) + openAIReq.ToolsChoice = input.ToolChoice + if input.TopLogprobs != nil && *input.TopLogprobs > 0 { + openAIReq.TopLogprobs = input.TopLogprobs + openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } + openAIReq.LogitBias = input.LogitBias // Detect if thinking token is already in prompt or template var template string @@ -1714,10 +1634,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 // Track reasoning state for streaming var currentReasoningID string var currentReasoningContentIndex int - var accumulatedContent string - var lastEmittedReasoning string - var lastEmittedCleanedContent string var reasoningTokens int + extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig) // Collect all output items for storage var collectedOutputItems []schema.ORItemField @@ -1729,34 +1647,25 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 } hasMCPToolsStream := len(mcpToolInfos) > 0 - var prediction backend.LLMResponse var result, finalReasoning, finalCleanedResult string var textContent string var parsedToolCalls []functions.FuncCallResults var toolCalls []functions.FuncCallResults + var lastStreamTokenUsage backend.TokenUsage + var lastStreamLogprobs *schema.Logprobs for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ { if mcpStreamIter > 0 { // Reset reasoning and tool-call state for re-inference so reasoning // extraction runs again on subsequent iterations inToolCallMode = false - accumulatedContent = "" - lastEmittedReasoning = "" - lastEmittedCleanedContent = "" + extractor.Reset() currentMessageID = "" lastEmittedToolCallCount = 0 currentReasoningID = "" predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Open Responses stream MCP re-templating", "iteration", mcpStreamIter) - images = images[:0] - videos = videos[:0] - audios = audios[:0] - for _, m := range openAIReq.Messages { - images = append(images, m.StringImages...) - videos = append(videos, m.StringVideos...) - audios = append(audios, m.StringAudios...) - } } // For tool calls, we need to track accumulated result and parse incrementally @@ -1911,11 +1820,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 // If no tool calls detected yet, handle reasoning and text if !inToolCallMode { - accumulatedContent += token - currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig) + reasoningDelta, contentDelta := extractor.ProcessToken(token) // Handle reasoning item - if currentReasoning != "" { + if extractor.Reasoning() != "" { // Check if we need to create reasoning item if currentReasoningID == "" { outputIndex++ @@ -1947,16 +1855,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 sequenceNumber++ } - // Calculate reasoning delta - var reasoningDelta string - if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) { - reasoningDelta = currentReasoning[len(lastEmittedReasoning):] - lastEmittedReasoning = currentReasoning - } else if currentReasoning != lastEmittedReasoning { - reasoningDelta = currentReasoning - lastEmittedReasoning = currentReasoning - } - // Emit reasoning delta if there's new content if reasoningDelta != "" { sendSSEEvent(c, &schema.ORStreamEvent{ @@ -1973,23 +1871,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 } } - // Handle message content (cleaned content without reasoning tags) - var deltaContent string - if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) { - deltaContent = cleanedContent[len(lastEmittedCleanedContent):] - lastEmittedCleanedContent = cleanedContent - } else if cleanedContent != lastEmittedCleanedContent { - if lastEmittedCleanedContent == "" { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } else { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } - } - // Only emit message content if there's actual content (not just reasoning) - if deltaContent != "" { + if contentDelta != "" { if currentMessageID == "" { // Emit output_item.added for message outputIndex++ @@ -2030,7 +1913,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, - Delta: strPtr(deltaContent), + Delta: strPtr(contentDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ @@ -2040,14 +1923,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 return true } - // Pass logprobs and logit_bias parameters if requested - var streamLogprobs *int - if input.TopLogprobs != nil && *input.TopLogprobs > 0 { - streamLogprobs = input.TopLogprobs + var ccResult string + ccCb := func(s string, c *[]schema.Choice) { + ccResult = s } - - predFunc, err := backend.ModelInference( - input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, streamLogprobs, input.TopLogprobs, input.LogitBias, nil) + choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback) if err != nil { xlog.Error("Open Responses stream model inference failed", "error", err) sendSSEEvent(c, &schema.ORStreamEvent{ @@ -2071,36 +1951,27 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 c.Response().Flush() return nil } - - prediction, err = predFunc() - if err != nil { - xlog.Error("Open Responses stream prediction failed", "error", err) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "error", - SequenceNumber: sequenceNumber, - Error: &schema.ORErrorPayload{ - Type: "model_error", - Message: fmt.Sprintf("prediction failed: %v", err), - }, - }) - sequenceNumber++ - responseFailed := responseCreated - responseFailed.Status = "failed" - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.failed", - SequenceNumber: sequenceNumber, - Response: responseFailed, - }) - // Send [DONE] even on error - fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") - c.Response().Flush() - return nil + result = ccResult + lastStreamTokenUsage = ccTokenUsage + if len(choices) > 0 { + lastStreamLogprobs = choices[0].Logprobs } - result = backend.Finetune(*cfg, predInput, prediction.Response) - - // Extract reasoning from final result - finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) + // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's + // streaming state, (3) final extraction from the finetuned result. + if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" { + finalReasoning = chatDeltaReasoning + finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas) + if finalCleanedResult == "" { + finalCleanedResult = extractor.CleanedContent() + } + } else { + finalReasoning = extractor.Reasoning() + finalCleanedResult = extractor.CleanedContent() + } + if finalReasoning == "" && finalCleanedResult == "" { + finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) + } // Close reasoning item if it exists and wasn't closed yet if currentReasoningID != "" && finalReasoning != "" { @@ -2157,10 +2028,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 textContent = "" // Try pre-parsed tool calls from C++ autoparser first - if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 { + if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls)) parsedToolCalls = deltaToolCalls - textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas) + textContent = functions.ContentFromChatDeltas(chatDeltas) } else { xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing") cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig) @@ -2279,8 +2150,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 } - // Convert prediction logprobs for streaming events - streamEventLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + // Convert logprobs for streaming events + streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs) // If we have no output but the model did produce something, use the cleaned result (without reasoning tags) if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" { @@ -2303,7 +2174,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 sequenceNumber++ // Emit content_part.done (with actual logprobs) - textPart := makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs) + textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, @@ -2320,7 +2191,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 ID: currentMessageID, Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", @@ -2389,7 +2260,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 ID: currentMessageID, Status: "completed", Role: "assistant", - Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, prediction.Logprobs)}, + Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, }) } // Add tool call items @@ -2408,9 +2279,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 // Emit response.completed now := time.Now().Unix() responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, &schema.ORUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, - TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + InputTokens: lastStreamTokenUsage.Prompt, + OutputTokens: lastStreamTokenUsage.Completion, + TotalTokens: lastStreamTokenUsage.Prompt + lastStreamTokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, @@ -2469,12 +2340,10 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 // Stream text deltas with reasoning extraction tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { accumulatedText += token - accumulatedContent += token - // Prepend thinking token if needed, then extract reasoning - currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, cfg.ReasoningConfig) + reasoningDelta, contentDelta := extractor.ProcessToken(token) // Handle reasoning item - if currentReasoning != "" { + if extractor.Reasoning() != "" { // Check if we need to create reasoning item if currentReasoningID == "" { outputIndex++ @@ -2506,16 +2375,6 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 sequenceNumber++ } - // Calculate reasoning delta - var reasoningDelta string - if len(currentReasoning) > len(lastEmittedReasoning) && strings.HasPrefix(currentReasoning, lastEmittedReasoning) { - reasoningDelta = currentReasoning[len(lastEmittedReasoning):] - lastEmittedReasoning = currentReasoning - } else if currentReasoning != lastEmittedReasoning { - reasoningDelta = currentReasoning - lastEmittedReasoning = currentReasoning - } - // Emit reasoning delta if there's new content if reasoningDelta != "" { sendSSEEvent(c, &schema.ORStreamEvent{ @@ -2532,23 +2391,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 } } - // Handle message content (cleaned content without reasoning tags) - var deltaContent string - if len(cleanedContent) > len(lastEmittedCleanedContent) && strings.HasPrefix(cleanedContent, lastEmittedCleanedContent) { - deltaContent = cleanedContent[len(lastEmittedCleanedContent):] - lastEmittedCleanedContent = cleanedContent - } else if cleanedContent != lastEmittedCleanedContent { - if lastEmittedCleanedContent == "" { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } else { - deltaContent = cleanedContent - lastEmittedCleanedContent = cleanedContent - } - } - // Only emit message content if there's actual content (not just reasoning) - if deltaContent != "" { + if contentDelta != "" { // Emit text delta sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", @@ -2556,7 +2400,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, - Delta: strPtr(deltaContent), + Delta: strPtr(contentDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ @@ -2565,14 +2409,11 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 return true } - // Pass logprobs and logit_bias parameters if requested - var mcpLogprobs *int - if input.TopLogprobs != nil && *input.TopLogprobs > 0 { - mcpLogprobs = input.TopLogprobs + var noToolResult string + noToolCb := func(s string, c *[]schema.Choice) { + noToolResult = s } - - predFunc, err := backend.ModelInference( - input.Context, predInput, openAIReq.Messages, images, videos, audios, ml, cfg, cl, appConfig, tokenCallback, toolsJSON, toolChoiceJSON, mcpLogprobs, input.TopLogprobs, input.LogitBias, nil) + noToolChoices, noToolTokenUsage, noToolChatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, noToolCb, tokenCallback) if err != nil { xlog.Error("Open Responses stream model inference failed", "error", err) sendSSEEvent(c, &schema.ORStreamEvent{ @@ -2596,36 +2437,28 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 c.Response().Flush() return nil } - - prediction, err := predFunc() - if err != nil { - xlog.Error("Open Responses stream prediction failed", "error", err) - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "error", - SequenceNumber: sequenceNumber, - Error: &schema.ORErrorPayload{ - Type: "model_error", - Message: fmt.Sprintf("prediction failed: %v", err), - }, - }) - sequenceNumber++ - responseFailed := responseCreated - responseFailed.Status = "failed" - sendSSEEvent(c, &schema.ORStreamEvent{ - Type: "response.failed", - SequenceNumber: sequenceNumber, - Response: responseFailed, - }) - // Send [DONE] even on error - fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") - c.Response().Flush() - return nil + result := noToolResult + var noToolLogprobs *schema.Logprobs + if len(noToolChoices) > 0 { + noToolLogprobs = noToolChoices[0].Logprobs } - result := backend.Finetune(*cfg, predInput, prediction.Response) - - // Extract reasoning from final result for non-tool-call path - finalReasoning, finalCleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) + // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's + // streaming state, (3) final extraction from the finetuned result. + var finalReasoning, finalCleanedResult string + if chatDeltaReasoning := functions.ReasoningFromChatDeltas(noToolChatDeltas); chatDeltaReasoning != "" { + finalReasoning = chatDeltaReasoning + finalCleanedResult = functions.ContentFromChatDeltas(noToolChatDeltas) + if finalCleanedResult == "" { + finalCleanedResult = extractor.CleanedContent() + } + } else { + finalReasoning = extractor.Reasoning() + finalCleanedResult = extractor.CleanedContent() + } + if finalReasoning == "" && finalCleanedResult == "" { + finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) + } // Close reasoning item if it exists and wasn't closed yet if currentReasoningID != "" && finalReasoning != "" { @@ -2680,8 +2513,8 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 result = finalCleanedResult - // Convert prediction logprobs for streaming events - mcpStreamLogprobs := convertLogprobsForStreaming(prediction.Logprobs) + // Convert logprobs for streaming events + mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs) // Emit output_text.done sendSSEEvent(c, &schema.ORStreamEvent{ @@ -2696,7 +2529,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 sequenceNumber++ // Emit content_part.done (with actual logprobs) - resultPart := makeOutputTextPartWithLogprobs(result, prediction.Logprobs) + resultPart := makeOutputTextPartWithLogprobs(result, noToolLogprobs) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, @@ -2709,7 +2542,7 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 // Emit output_item.done (with actual logprobs) messageItem.Status = "completed" - messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, prediction.Logprobs)} + messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, noToolLogprobs)} sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, @@ -2744,9 +2577,9 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6 finalOutputItems = append(finalOutputItems, *messageItem) } responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{ - InputTokens: prediction.Usage.Prompt, - OutputTokens: prediction.Usage.Completion, - TotalTokens: prediction.Usage.Prompt + prediction.Usage.Completion, + InputTokens: noToolTokenUsage.Prompt, + OutputTokens: noToolTokenUsage.Completion, + TotalTokens: noToolTokenUsage.Prompt + noToolTokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, @@ -3035,19 +2868,6 @@ func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.T return result } -// serializeToolsForBackend converts and serializes Open Responses tools to JSON for the backend -func serializeToolsForBackend(orTools []schema.ORFunctionTool) string { - if len(orTools) == 0 { - return "" - } - openAITools := convertORToolsToOpenAIFormat(orTools) - toolsBytes, err := json.Marshal(openAITools) - if err != nil { - return "" - } - return string(toolsBytes) -} - // GetResponseEndpoint returns a handler for GET /responses/:id // This endpoint is used for polling background responses or resuming streaming // @Summary Get a response by ID diff --git a/pkg/reasoning/extractor.go b/pkg/reasoning/extractor.go new file mode 100644 index 000000000..f5b5c6c82 --- /dev/null +++ b/pkg/reasoning/extractor.go @@ -0,0 +1,104 @@ +package reasoning + +import "strings" + +// ReasoningExtractor tracks streaming reasoning extraction state, computing +// incremental deltas so callers don't need to duplicate the ~30-line +// accumulated-content / last-emitted tracking logic. +// +// Usage: +// +// extractor := NewReasoningExtractor(thinkingStartToken, cfg) +// // In your streaming token callback: +// reasoningDelta, contentDelta := extractor.ProcessToken(token) +// // After streaming completes: +// finalReasoning := extractor.Reasoning() +// finalContent := extractor.CleanedContent() +type ReasoningExtractor struct { + thinkingStartToken string + config Config + accumulated string + lastReasoning string + lastCleaned string + suppressReasoning bool +} + +// NewReasoningExtractor creates a new extractor for the given thinking token and config. +func NewReasoningExtractor(thinkingStartToken string, cfg Config) *ReasoningExtractor { + return &ReasoningExtractor{ + thinkingStartToken: thinkingStartToken, + config: cfg, + } +} + +// ProcessToken processes a new streaming token and returns the reasoning +// and content deltas (the new portions not yet emitted). +func (e *ReasoningExtractor) ProcessToken(token string) (reasoningDelta, contentDelta string) { + e.accumulated += token + currentReasoning, cleanedContent := ExtractReasoningWithConfig(e.accumulated, e.thinkingStartToken, e.config) + + // Calculate reasoning delta + if currentReasoning != e.lastReasoning { + if len(currentReasoning) > len(e.lastReasoning) && strings.HasPrefix(currentReasoning, e.lastReasoning) { + reasoningDelta = currentReasoning[len(e.lastReasoning):] + } else if currentReasoning != "" { + // Reasoning changed in a non-append way, emit the full current reasoning + reasoningDelta = currentReasoning + } + e.lastReasoning = currentReasoning + } + + // Calculate content delta + if len(cleanedContent) > len(e.lastCleaned) && strings.HasPrefix(cleanedContent, e.lastCleaned) { + contentDelta = cleanedContent[len(e.lastCleaned):] + e.lastCleaned = cleanedContent + } else if cleanedContent != e.lastCleaned { + contentDelta = cleanedContent + e.lastCleaned = cleanedContent + } + + if e.suppressReasoning { + reasoningDelta = "" + } + + return reasoningDelta, contentDelta +} + +// Reasoning returns the total accumulated reasoning after streaming. +func (e *ReasoningExtractor) Reasoning() string { + return e.lastReasoning +} + +// CleanedContent returns the total accumulated content (reasoning stripped). +func (e *ReasoningExtractor) CleanedContent() string { + return e.lastCleaned +} + +// Accumulated returns the total raw accumulated content. +func (e *ReasoningExtractor) Accumulated() string { + return e.accumulated +} + +// Reset clears the extractor state for reuse. +func (e *ReasoningExtractor) Reset() { + e.accumulated = "" + e.lastReasoning = "" + e.lastCleaned = "" +} + +// ResetAndSuppressReasoning clears state and suppresses future reasoning deltas. +// ProcessToken() still extracts reasoning internally (CleanedContent works), +// but returns empty reasoningDelta — reasoning is not surfaced to the caller. +// This is used on retry after streaming: reasoning from the first attempt was +// already sent to the client; re-streaming it would cause duplicates. +func (e *ReasoningExtractor) ResetAndSuppressReasoning() { + e.accumulated = "" + e.lastReasoning = "" + e.lastCleaned = "" + e.suppressReasoning = true +} + +// Suppressed returns whether reasoning delta suppression is active. +func (e *ReasoningExtractor) Suppressed() bool { + return e.suppressReasoning +} diff --git a/pkg/reasoning/extractor_test.go b/pkg/reasoning/extractor_test.go new file mode 100644 index 000000000..854f59cf0 --- /dev/null +++ b/pkg/reasoning/extractor_test.go @@ -0,0 +1,198 @@ +package reasoning_test + +import ( + . "github.com/mudler/LocalAI/pkg/reasoning" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ReasoningExtractor", func() { + Context("basic streaming with tags", func() { + It("should extract reasoning and content deltas incrementally", func() { + ext := NewReasoningExtractor("", Config{}) + + // Simulate tokens arriving one at a time + tokens := []string{"", "I need", " to think", "", "Hello", " world"} + var allReasoningDeltas, allContentDeltas string + + for _, tok := range tokens { + rDelta, cDelta := ext.ProcessToken(tok) + allReasoningDeltas += rDelta + allContentDeltas += cDelta + } + + Expect(ext.Reasoning()).To(Equal("I need to think")) + Expect(ext.CleanedContent()).To(Equal("Hello world")) + Expect(allReasoningDeltas).To(Equal("I need to think")) + Expect(allContentDeltas).To(Equal("Hello world")) + }) + }) + + Context("no reasoning tags", func() { + It("should pass all content through as content deltas", func() { + ext := NewReasoningExtractor("", Config{}) + + rDelta1, cDelta1 := ext.ProcessToken("Hello") + rDelta2, cDelta2 := ext.ProcessToken(" world") + + Expect(rDelta1).To(BeEmpty()) + Expect(cDelta1).To(Equal("Hello")) + Expect(rDelta2).To(BeEmpty()) + Expect(cDelta2).To(Equal(" world")) + Expect(ext.Reasoning()).To(BeEmpty()) + Expect(ext.CleanedContent()).To(Equal("Hello world")) + }) + }) + + Context("unclosed thinking tags", func() { + It("should treat content after unclosed tag as reasoning", func() { + ext := NewReasoningExtractor("", Config{}) + + ext.ProcessToken("") + ext.ProcessToken("still thinking") + // No closing tag - reasoning is extracted from unclosed tag + + Expect(ext.Reasoning()).To(Equal("still thinking")) + Expect(ext.CleanedContent()).To(BeEmpty()) + }) + }) + + Context("empty tokens", func() { + It("should handle empty tokens gracefully", func() { + ext := NewReasoningExtractor("", Config{}) + + rDelta, cDelta := ext.ProcessToken("") + Expect(rDelta).To(BeEmpty()) + Expect(cDelta).To(BeEmpty()) + + rDelta, cDelta = ext.ProcessToken("Hello") + Expect(rDelta).To(BeEmpty()) + Expect(cDelta).To(Equal("Hello")) + }) + }) + + Context("Reset", func() { + It("should clear all state", func() { + ext := NewReasoningExtractor("", Config{}) + + ext.ProcessToken("reasoncontent") + Expect(ext.Reasoning()).ToNot(BeEmpty()) + Expect(ext.CleanedContent()).ToNot(BeEmpty()) + + ext.Reset() + Expect(ext.Reasoning()).To(BeEmpty()) + Expect(ext.CleanedContent()).To(BeEmpty()) + Expect(ext.Accumulated()).To(BeEmpty()) + }) + }) + + Context("disabled reasoning", func() { + It("should pass all content through when reasoning is disabled", func() { + disabled := true + ext := NewReasoningExtractor("", Config{DisableReasoning: &disabled}) + + rDelta, cDelta := ext.ProcessToken("reasoncontent") + Expect(rDelta).To(BeEmpty()) + Expect(cDelta).To(Equal("reasoncontent")) + Expect(ext.Reasoning()).To(BeEmpty()) + }) + }) + + Context("split tags across tokens", func() { + It("should handle tags split across multiple tokens", func() { + ext := NewReasoningExtractor("", Config{}) + + // Tag arrives in pieces + ext.ProcessToken("reasoning herefinal answer") + + Expect(ext.Reasoning()).To(Equal("reasoning here")) + Expect(ext.CleanedContent()).To(Equal("final answer")) + }) + }) + + Context("ResetAndSuppressReasoning", func() { + It("should suppress reasoning deltas but still extract reasoning internally", func() { + ext := NewReasoningExtractor("", Config{}) + + // First pass: reasoning is emitted normally + rDelta1, cDelta1 := ext.ProcessToken("first reasoningfirst content") + Expect(rDelta1).To(Equal("first reasoning")) + Expect(cDelta1).To(Equal("first content")) + Expect(ext.Suppressed()).To(BeFalse()) + + // Simulate retry: suppress reasoning + ext.ResetAndSuppressReasoning() + Expect(ext.Suppressed()).To(BeTrue()) + Expect(ext.Reasoning()).To(BeEmpty()) + Expect(ext.CleanedContent()).To(BeEmpty()) + Expect(ext.Accumulated()).To(BeEmpty()) + + // Second pass: reasoning deltas suppressed, content still works + rDelta2, cDelta2 := ext.ProcessToken("retry reasoningretry content") + Expect(rDelta2).To(BeEmpty(), "reasoning delta should be suppressed after ResetAndSuppressReasoning") + Expect(cDelta2).To(Equal("retry content")) + + // Internal state still tracks reasoning (for CleanedContent to work) + Expect(ext.Reasoning()).To(Equal("retry reasoning")) + Expect(ext.CleanedContent()).To(Equal("retry content")) + }) + + It("should suppress reasoning across multiple streaming tokens", func() { + ext := NewReasoningExtractor("", Config{}) + ext.ResetAndSuppressReasoning() + + tokens := []string{"", "suppressed", " thought", "", "visible", " answer"} + var allReasoningDeltas, allContentDeltas string + + for _, tok := range tokens { + rDelta, cDelta := ext.ProcessToken(tok) + allReasoningDeltas += rDelta + allContentDeltas += cDelta + } + + Expect(allReasoningDeltas).To(BeEmpty(), "no reasoning deltas should be emitted when suppressed") + Expect(allContentDeltas).To(Equal("visible answer")) + Expect(ext.Reasoning()).To(Equal("suppressed thought")) + Expect(ext.CleanedContent()).To(Equal("visible answer")) + }) + }) + + Context("Accumulated", func() { + It("should return all raw tokens concatenated", func() { + ext := NewReasoningExtractor("", Config{}) + + ext.ProcessToken("reason") + ext.ProcessToken("content") + + Expect(ext.Accumulated()).To(Equal("reasoncontent")) + }) + }) + + Context("with thinking start token prefill", func() { + It("should prepend thinking token when prefill is not disabled", func() { + ext := NewReasoningExtractor("", Config{}) + + // Content without explicit tag - extractor should prepend it + ext.ProcessToken("I am thinking") + ext.ProcessToken("") + ext.ProcessToken("Answer here") + + Expect(ext.Reasoning()).To(Equal("I am thinking")) + Expect(ext.CleanedContent()).To(Equal("Answer here")) + }) + }) + + Context("strip reasoning only", func() { + It("should strip reasoning from content but not return it", func() { + strip := true + ext := NewReasoningExtractor("", Config{StripReasoningOnly: &strip}) + + ext.ProcessToken("secret reasoningvisible content") + + Expect(ext.Reasoning()).To(BeEmpty()) + Expect(ext.CleanedContent()).To(Equal("visible content")) + }) + }) +})