mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-16 21:08:16 -04:00
fix(chat): do not retry if we had chatdeltas or tooldeltas from backend (#9244)
* fix(chat): do not retry if we had chatdeltas or tooldeltas from backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: use oai compat for llama.cpp Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: apply to non-streaming path too Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * map also other fields Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
06fbe48b3f
commit
773489eeb1
@@ -1309,6 +1309,7 @@ public:
|
||||
|
||||
body_json["messages"] = messages_json;
|
||||
body_json["stream"] = true; // PredictStream is always streaming
|
||||
body_json["stream_options"] = {{"include_usage", true}}; // Ensure token counts in final chunk
|
||||
|
||||
// Check if grammar is provided from Go layer (NoGrammar=false)
|
||||
// If grammar is provided, we must use it and NOT let template generate grammar from tools
|
||||
@@ -1616,8 +1617,11 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
// Without this, the PEG parser never produces diffs and the Go side
|
||||
// cannot detect tool calls or separate reasoning from content.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -1642,19 +1646,47 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result.
|
||||
// Handles both native format ({"content": "..."}) and OAI chat format
|
||||
// ({"choices": [{"delta": {"content": "...", "reasoning": "..."}}]}).
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
std::string completion_text;
|
||||
|
||||
if (res_json.contains("choices")) {
|
||||
// OAI chat format — extract content from choices[0].delta
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & delta = choices[0].value("delta", json::object());
|
||||
if (delta.contains("content") && !delta.at("content").is_null()) {
|
||||
completion_text = delta.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = res_json.value("content", "");
|
||||
}
|
||||
|
||||
reply.set_message(completion_text);
|
||||
|
||||
// Token counts: native format has top-level fields,
|
||||
// OAI format has them in "usage" (final chunk only)
|
||||
if (res_json.contains("usage")) {
|
||||
const auto & usage = res_json.at("usage");
|
||||
reply.set_tokens(usage.value("completion_tokens", 0));
|
||||
reply.set_prompt_tokens(usage.value("prompt_tokens", 0));
|
||||
} else {
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
}
|
||||
|
||||
// Timings: present as top-level "timings" in both formats
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
@@ -1663,21 +1695,17 @@ public:
|
||||
return reply;
|
||||
};
|
||||
|
||||
// Attach chat deltas from the autoparser to a Reply.
|
||||
// When diffs are available, populate ChatDeltas on the reply.
|
||||
// The raw message is always preserved so the Go side can use it
|
||||
// for reasoning extraction and tool call parsing as a fallback
|
||||
// (important in distributed mode where ChatDeltas may not be
|
||||
// the primary parsing path).
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
if (partial) {
|
||||
if (!partial->oaicompat_msg_diffs.empty()) {
|
||||
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
|
||||
} else if (partial->is_updated) {
|
||||
// Autoparser is active but hasn't classified this chunk yet
|
||||
// (PEG parser warming up). Clear the raw message so the Go
|
||||
// side doesn't try to parse partial tag tokens (e.g. "<|channel>"
|
||||
// before the full "<|channel>thought\n" is received).
|
||||
// This matches llama.cpp server behavior which only emits SSE
|
||||
// chunks when the parser produces diffs.
|
||||
reply.set_message("");
|
||||
}
|
||||
if (partial && !partial->oaicompat_msg_diffs.empty()) {
|
||||
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
|
||||
return;
|
||||
}
|
||||
// Try final result
|
||||
@@ -2357,8 +2385,9 @@ public:
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
// OAI-compat: enable autoparser (PEG-based chat parsing) so that
|
||||
// reasoning, tool calls, and content are classified into ChatDeltas.
|
||||
task.params.res_type = TASK_RESPONSE_TYPE_OAI_CHAT;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
@@ -2389,25 +2418,48 @@ public:
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
// Handle both native format ({"content": "...", "tokens_predicted": N})
|
||||
// and OAI chat format ({"choices": [{"message": {"content": "..."}}],
|
||||
// "usage": {"completion_tokens": N, "prompt_tokens": N}}).
|
||||
std::string completion_text;
|
||||
int32_t tokens_predicted = 0;
|
||||
int32_t tokens_evaluated = 0;
|
||||
|
||||
if (result_json.contains("choices")) {
|
||||
// OAI chat format
|
||||
const auto & choices = result_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
completion_text = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
if (result_json.contains("usage")) {
|
||||
const auto & usage = result_json.at("usage");
|
||||
tokens_predicted = usage.value("completion_tokens", 0);
|
||||
tokens_evaluated = usage.value("prompt_tokens", 0);
|
||||
}
|
||||
} else {
|
||||
// Native llama.cpp format
|
||||
completion_text = result_json.value("content", "");
|
||||
tokens_predicted = result_json.value("tokens_predicted", 0);
|
||||
tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
}
|
||||
reply->set_message(completion_text);
|
||||
reply->set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
|
||||
reply->set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
// Timings: present in both formats as a top-level "timings" object
|
||||
if (result_json.contains("timings")) {
|
||||
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply->set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply->set_timing_token_generation(timing_token_generation);
|
||||
reply->set_timing_prompt_processing(result_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply->set_timing_token_generation(result_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
// Logprobs: extract_logprobs_from_json handles both formats
|
||||
json logprobs_json = extract_logprobs_from_json(result_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply->set_logprobs(logprobs_str);
|
||||
reply->set_logprobs(logprobs_json.dump());
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
@@ -2423,7 +2475,20 @@ public:
|
||||
for (auto & res : all_results.results) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
json res_json = res->to_json();
|
||||
arr.push_back(res_json.value("content", ""));
|
||||
// Handle both native and OAI chat formats
|
||||
std::string result_content;
|
||||
if (res_json.contains("choices")) {
|
||||
const auto & choices = res_json.at("choices");
|
||||
if (!choices.empty()) {
|
||||
const auto & msg = choices[0].value("message", json::object());
|
||||
if (msg.contains("content") && !msg.at("content").is_null()) {
|
||||
result_content = msg.at("content").get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result_content = res_json.value("content", "");
|
||||
}
|
||||
arr.push_back(result_content);
|
||||
|
||||
// Extract logprobs for each result
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
|
||||
@@ -147,10 +147,23 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
sentInitialRole := false
|
||||
hasChatDeltaToolCalls := false
|
||||
hasChatDeltaContent := false
|
||||
|
||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
|
||||
// Track whether ChatDeltas from the C++ autoparser contain
|
||||
// tool calls or content, so the retry decision can account for them.
|
||||
for _, d := range usage.ChatDeltas {
|
||||
if len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaToolCalls = true
|
||||
}
|
||||
if d.Content != "" {
|
||||
hasChatDeltaContent = true
|
||||
}
|
||||
}
|
||||
|
||||
var reasoningDelta, contentDelta string
|
||||
|
||||
goReasoning, goContent := extractor.ProcessToken(s)
|
||||
@@ -309,15 +322,22 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// After streaming completes: check if we got actionable content
|
||||
cleaned := extractor.CleanedContent()
|
||||
// Check for tool calls from chat deltas (will be re-checked after ComputeChoices,
|
||||
// but we need to know here whether to retry)
|
||||
hasToolCalls := lastEmittedCount > 0
|
||||
if cleaned == "" && !hasToolCalls {
|
||||
// but we need to know here whether to retry).
|
||||
// Also check ChatDelta flags — when the C++ autoparser is active,
|
||||
// tool calls and content are delivered via ChatDeltas while the
|
||||
// raw message is cleared. Without this check, we'd retry
|
||||
// unnecessarily, losing valid results and concatenating output.
|
||||
hasToolCalls := lastEmittedCount > 0 || hasChatDeltaToolCalls
|
||||
hasContent := cleaned != "" || hasChatDeltaContent
|
||||
if !hasContent && !hasToolCalls {
|
||||
xlog.Warn("Streaming: backend produced only reasoning, retrying",
|
||||
"reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1)
|
||||
extractor.ResetAndSuppressReasoning()
|
||||
result = ""
|
||||
lastEmittedCount = 0
|
||||
sentInitialRole = false
|
||||
hasChatDeltaToolCalls = false
|
||||
hasChatDeltaContent = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -113,11 +113,23 @@ func ComputeChoices(
|
||||
}
|
||||
prediction = p
|
||||
|
||||
// Built-in: retry on truly empty response (no tokens at all)
|
||||
// Built-in: retry on truly empty response (no tokens at all).
|
||||
// However, when the C++ autoparser is active, it clears the raw
|
||||
// message and delivers content via ChatDeltas instead. Do NOT
|
||||
// retry if ChatDeltas contain tool calls or content.
|
||||
if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
hasChatDeltaData := false
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
hasChatDeltaData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasChatDeltaData {
|
||||
xlog.Warn("Backend returned empty response, retrying",
|
||||
"attempt", attempt+1, "maxRetries", maxRetries)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tokenUsage.Prompt = prediction.Usage.Prompt
|
||||
@@ -130,8 +142,21 @@ func ComputeChoices(
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.)
|
||||
if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller-driven retry (tool parsing, reasoning-only, etc.).
|
||||
// When the C++ autoparser is active, it clears the raw response
|
||||
// and delivers data via ChatDeltas. If the response is empty but
|
||||
// ChatDeltas contain actionable data, skip the caller retry —
|
||||
// the autoparser already parsed the response successfully.
|
||||
skipCallerRetry := false
|
||||
if strings.TrimSpace(prediction.Response) == "" && len(prediction.ChatDeltas) > 0 {
|
||||
for _, d := range prediction.ChatDeltas {
|
||||
if d.Content != "" || len(d.ToolCalls) > 0 {
|
||||
skipCallerRetry = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldRetryFn != nil && !skipCallerRetry && shouldRetryFn(attempt) && attempt < maxRetries {
|
||||
// Caller has already reset its state inside shouldRetry
|
||||
result = result[:0]
|
||||
allChatDeltas = nil
|
||||
|
||||
@@ -101,6 +101,25 @@ var _ = BeforeSuite(func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed())
|
||||
|
||||
// Create model config for autoparser tests (NoGrammar so tool calls
|
||||
// are driven entirely by the backend's ChatDeltas, not grammar enforcement)
|
||||
autoparserConfig := map[string]any{
|
||||
"name": "mock-model-autoparser",
|
||||
"backend": "mock-backend",
|
||||
"parameters": map[string]any{
|
||||
"model": "mock-model.bin",
|
||||
},
|
||||
"function": map[string]any{
|
||||
"grammar": map[string]any{
|
||||
"disable": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
autoparserPath := filepath.Join(modelsPath, "mock-model-autoparser.yaml")
|
||||
autoparserYAML, err := yaml.Marshal(autoparserConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(autoparserPath, autoparserYAML, 0644)).To(Succeed())
|
||||
|
||||
// Start mock MCP server and create MCP-enabled model config
|
||||
mcpServerURL, mcpServerShutdown = startMockMCPServer()
|
||||
mcpConfig := mcpModelConfig(mcpServerURL)
|
||||
|
||||
@@ -55,6 +55,46 @@ func (m *MockBackend) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.R
|
||||
if strings.Contains(in.Prompt, "MOCK_ERROR") {
|
||||
return nil, fmt.Errorf("mock backend predict error: simulated failure")
|
||||
}
|
||||
|
||||
// Simulate C++ autoparser: tool call via ChatDeltas, empty message
|
||||
if strings.Contains(in.Prompt, "AUTOPARSER_TOOL_CALL") {
|
||||
toolName := mockToolNameFromRequest(in)
|
||||
if toolName == "" {
|
||||
toolName = "search_collections"
|
||||
}
|
||||
return &pb.Reply{
|
||||
Message: []byte{},
|
||||
Tokens: 10,
|
||||
PromptTokens: 5,
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "I need to search for information."},
|
||||
{
|
||||
ToolCalls: []*pb.ToolCallDelta{
|
||||
{
|
||||
Index: 0,
|
||||
Id: "call_mock_123",
|
||||
Name: toolName,
|
||||
Arguments: `{"query":"localai"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Simulate C++ autoparser: content via ChatDeltas, empty message
|
||||
if strings.Contains(in.Prompt, "AUTOPARSER_CONTENT") {
|
||||
return &pb.Reply{
|
||||
Message: []byte{},
|
||||
Tokens: 10,
|
||||
PromptTokens: 5,
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: "Let me compose a response."},
|
||||
{Content: "LocalAI is an open-source AI platform."},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var response string
|
||||
toolName := mockToolNameFromRequest(in)
|
||||
if toolName != "" && !promptHasToolResults(in.Prompt) {
|
||||
@@ -88,6 +128,77 @@ func (m *MockBackend) PredictStream(in *pb.PredictOptions, stream pb.Backend_Pre
|
||||
}
|
||||
return fmt.Errorf("mock backend stream error: simulated mid-stream failure")
|
||||
}
|
||||
|
||||
// Simulate C++ autoparser behavior: tool calls delivered via ChatDeltas
|
||||
// with empty message (autoparser clears raw message during parsing).
|
||||
if strings.Contains(in.Prompt, "AUTOPARSER_TOOL_CALL") {
|
||||
toolName := mockToolNameFromRequest(in)
|
||||
if toolName == "" {
|
||||
toolName = "search_collections"
|
||||
}
|
||||
// Phase 1: Stream reasoning tokens with empty message (autoparser active)
|
||||
reasoning := "I need to search for information."
|
||||
for _, r := range reasoning {
|
||||
if err := stream.Send(&pb.Reply{
|
||||
Message: []byte{}, // autoparser clears raw message
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: string(r)},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Phase 2: Emit tool call via ChatDeltas (no raw message)
|
||||
if err := stream.Send(&pb.Reply{
|
||||
Message: []byte{}, // autoparser clears raw message
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{
|
||||
ToolCalls: []*pb.ToolCallDelta{
|
||||
{
|
||||
Index: 0,
|
||||
Id: "call_mock_123",
|
||||
Name: toolName,
|
||||
Arguments: `{"query":"localai"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Simulate C++ autoparser behavior: content delivered via ChatDeltas
|
||||
// with empty message (autoparser clears raw message during parsing).
|
||||
if strings.Contains(in.Prompt, "AUTOPARSER_CONTENT") {
|
||||
// Phase 1: Stream reasoning via ChatDeltas
|
||||
reasoning := "Let me compose a response."
|
||||
for _, r := range reasoning {
|
||||
if err := stream.Send(&pb.Reply{
|
||||
Message: []byte{},
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{ReasoningContent: string(r)},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Phase 2: Stream content via ChatDeltas (no raw message)
|
||||
content := "LocalAI is an open-source AI platform."
|
||||
for _, r := range content {
|
||||
if err := stream.Send(&pb.Reply{
|
||||
Message: []byte{},
|
||||
ChatDeltas: []*pb.ChatDelta{
|
||||
{Content: string(r)},
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var toStream string
|
||||
toolName := mockToolNameFromRequest(in)
|
||||
if toolName != "" && !promptHasToolResults(in.Prompt) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package e2e_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -265,4 +266,201 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Autoparser ChatDelta Streaming", Label("Autoparser"), func() {
|
||||
// These tests verify that when the C++ autoparser handles tool calls
|
||||
// and content via ChatDeltas (with empty raw message), the streaming
|
||||
// endpoint does NOT unnecessarily retry. This is a regression test for
|
||||
// the bug where the retry logic only checked Go-side parsing, ignoring
|
||||
// ChatDelta results, causing up to 6 retries and concatenated output.
|
||||
|
||||
Context("Streaming with tools and ChatDelta tool calls", func() {
|
||||
It("should return tool calls without unnecessary retries", func() {
|
||||
body := `{
|
||||
"model": "mock-model-autoparser",
|
||||
"messages": [{"role": "user", "content": "AUTOPARSER_TOOL_CALL"}],
|
||||
"tools": [{"type": "function", "function": {"name": "search_collections", "description": "Search documents", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}],
|
||||
"stream": true
|
||||
}`
|
||||
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := httpClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
bodyStr := string(data)
|
||||
|
||||
// Parse all SSE events
|
||||
lines := strings.Split(bodyStr, "\n")
|
||||
var toolCallChunks int
|
||||
var reasoningChunks int
|
||||
hasFinishReason := false
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
jsonData := strings.TrimPrefix(line, "data: ")
|
||||
var chunk map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonData), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
choices, ok := chunk["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
choice := choices[0].(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if delta == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := delta["tool_calls"]; ok {
|
||||
toolCallChunks++
|
||||
}
|
||||
if _, ok := delta["reasoning"]; ok {
|
||||
reasoningChunks++
|
||||
}
|
||||
if fr, ok := choice["finish_reason"].(string); ok && fr != "" {
|
||||
hasFinishReason = true
|
||||
}
|
||||
}
|
||||
|
||||
// The key assertion: tool calls from ChatDeltas should be present
|
||||
Expect(toolCallChunks).To(BeNumerically(">", 0),
|
||||
"Expected tool_calls in streaming response from ChatDeltas, but got none. "+
|
||||
"This likely means the retry logic discarded ChatDelta tool calls.")
|
||||
|
||||
// Should have a finish reason
|
||||
Expect(hasFinishReason).To(BeTrue(), "Expected a finish_reason in the streaming response")
|
||||
|
||||
// Reasoning should be present (from ChatDelta reasoning)
|
||||
Expect(reasoningChunks).To(BeNumerically(">", 0),
|
||||
"Expected reasoning deltas from ChatDeltas")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Streaming with tools and ChatDelta content (no tool calls)", func() {
|
||||
It("should return content without retrying and without concatenation", func() {
|
||||
body := `{
|
||||
"model": "mock-model-autoparser",
|
||||
"messages": [{"role": "user", "content": "AUTOPARSER_CONTENT"}],
|
||||
"tools": [{"type": "function", "function": {"name": "search_collections", "description": "Search documents", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}],
|
||||
"stream": true
|
||||
}`
|
||||
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := httpClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
bodyStr := string(data)
|
||||
|
||||
// Parse all SSE events and collect content
|
||||
lines := strings.Split(bodyStr, "\n")
|
||||
var contentParts []string
|
||||
var reasoningParts []string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
|
||||
continue
|
||||
}
|
||||
jsonData := strings.TrimPrefix(line, "data: ")
|
||||
var chunk map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonData), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
choices, ok := chunk["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
continue
|
||||
}
|
||||
choice := choices[0].(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if delta == nil {
|
||||
continue
|
||||
}
|
||||
if content, ok := delta["content"].(string); ok && content != "" {
|
||||
contentParts = append(contentParts, content)
|
||||
}
|
||||
if reasoning, ok := delta["reasoning"].(string); ok && reasoning != "" {
|
||||
reasoningParts = append(reasoningParts, reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
fullContent := strings.Join(contentParts, "")
|
||||
fullReasoning := strings.Join(reasoningParts, "")
|
||||
|
||||
// Content should be present and match the expected answer
|
||||
Expect(fullContent).To(ContainSubstring("LocalAI"),
|
||||
"Expected content from ChatDeltas to contain 'LocalAI'. "+
|
||||
"The retry logic may have discarded ChatDelta content.")
|
||||
|
||||
// Content should NOT be duplicated (no retry concatenation)
|
||||
occurrences := strings.Count(fullContent, "LocalAI is an open-source AI platform.")
|
||||
Expect(occurrences).To(Equal(1),
|
||||
"Expected content to appear exactly once, but found %d occurrences. "+
|
||||
"This indicates unnecessary retries are concatenating output.", occurrences)
|
||||
|
||||
// Reasoning should be present
|
||||
Expect(fullReasoning).To(ContainSubstring("compose"),
|
||||
"Expected reasoning content from ChatDeltas")
|
||||
})
|
||||
})
|
||||
|
||||
Context("Non-streaming with tools and ChatDelta tool calls", func() {
|
||||
It("should return tool calls from ChatDeltas", func() {
|
||||
body := `{
|
||||
"model": "mock-model-autoparser",
|
||||
"messages": [{"role": "user", "content": "AUTOPARSER_TOOL_CALL"}],
|
||||
"tools": [{"type": "function", "function": {"name": "search_collections", "description": "Search documents", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}]
|
||||
}`
|
||||
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := httpClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var result map[string]any
|
||||
Expect(json.Unmarshal(data, &result)).To(Succeed())
|
||||
|
||||
choices, ok := result["choices"].([]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(choices).To(HaveLen(1))
|
||||
|
||||
choice := choices[0].(map[string]any)
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
Expect(msg).ToNot(BeNil())
|
||||
|
||||
toolCalls, ok := msg["tool_calls"].([]any)
|
||||
Expect(ok).To(BeTrue(),
|
||||
"Expected tool_calls in non-streaming response from ChatDeltas, "+
|
||||
"but got: %s", string(data))
|
||||
Expect(toolCalls).To(HaveLen(1))
|
||||
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
fn, _ := tc["function"].(map[string]any)
|
||||
Expect(fn["name"]).To(Equal("search_collections"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user