feat(functions): add peg-based parsing and allow backends to return tool calls directly (#8838)
* feat(functions): add peg-based parsing
* feat: support returning toolcalls directly from backends
* chore: run PEG only if the backend didn't send deltas

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
commit b2f81bfa2e, parent b57a6e42f1, committed via GitHub (https://github.com/mudler/LocalAI.git)
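The third bullet is the behavioral contract introduced here: PEG-based parsing is a fallback, used only when the backend did not already return structured tool-call deltas. A minimal consumer-side sketch of that rule, assuming hypothetical sinks forward_deltas and run_peg_parser (the real fallback logic lives in LocalAI's core, not in this file):

#include <iostream>
#include <string>
#include "backend.pb.h"  // generated protobuf types used throughout the diff below

// Hypothetical sinks, for illustration only.
static void forward_deltas(const backend::Reply & reply) {
    std::cout << "structured deltas: " << reply.chat_deltas_size() << "\n";
}
static void run_peg_parser(const std::string & raw) {
    std::cout << "PEG fallback on: " << raw << "\n";
}

// PEG runs only when the backend sent no chat deltas with the reply.
void handle_reply(const backend::Reply & reply) {
    if (reply.chat_deltas_size() > 0) {
        forward_deltas(reply);                // backend already parsed tool calls
    } else {
        run_peg_parser(reply.message());      // fall back to PEG-based parsing
    }
}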
@@ -17,6 +17,7 @@
#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include "common.h"
#include "chat-auto-parser.h"
#include <getopt.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
@@ -866,6 +867,56 @@ public:
        return logprobs_json;
    }

    // Helper: populate chat_deltas on a Reply from oaicompat_msg_diffs (streaming chunks)
    static void populate_chat_deltas_from_diffs(backend::Reply & reply,
                                                const std::vector<common_chat_msg_diff> & diffs) {
        for (const auto & diff : diffs) {
            auto* delta = reply.add_chat_deltas();
            if (!diff.content_delta.empty()) {
                delta->set_content(diff.content_delta);
            }
            if (!diff.reasoning_content_delta.empty()) {
                delta->set_reasoning_content(diff.reasoning_content_delta);
            }
            if (diff.tool_call_index != std::string::npos) {
                auto* tc = delta->add_tool_calls();
                tc->set_index(static_cast<int32_t>(diff.tool_call_index));
                if (!diff.tool_call_delta.id.empty()) {
                    tc->set_id(diff.tool_call_delta.id);
                }
                if (!diff.tool_call_delta.name.empty()) {
                    tc->set_name(diff.tool_call_delta.name);
                }
                if (!diff.tool_call_delta.arguments.empty()) {
                    tc->set_arguments(diff.tool_call_delta.arguments);
                }
            }
        }
    }

    // Helper: populate chat_deltas on a Reply from final oaicompat_msg (non-streaming)
    static void populate_chat_deltas_from_final(backend::Reply & reply,
                                                const common_chat_msg & msg) {
        // Content delta
        if (!msg.content.empty() || !msg.reasoning_content.empty() || !msg.tool_calls.empty()) {
            auto* delta = reply.add_chat_deltas();
            if (!msg.content.empty()) {
                delta->set_content(msg.content);
            }
            if (!msg.reasoning_content.empty()) {
                delta->set_reasoning_content(msg.reasoning_content);
            }
            // Tool calls as individual deltas within the same ChatDelta
            for (size_t i = 0; i < msg.tool_calls.size(); i++) {
                auto* tc = delta->add_tool_calls();
                tc->set_index(static_cast<int32_t>(i));
                tc->set_id(msg.tool_calls[i].id);
                tc->set_name(msg.tool_calls[i].name);
                tc->set_arguments(msg.tool_calls[i].arguments);
            }
        }
    }

    grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
        if (params_base.model.path.empty()) {
            return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
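A worked illustration of the streaming helper above; a sketch assuming the common_chat_msg_diff layout from llama.cpp's common/chat.h, with an invented tool name and arguments:

backend::Reply reply;
common_chat_msg_diff diff;
diff.content_delta             = "Checking that now. ";
diff.tool_call_index           = 0;                     // first tool call of the message
diff.tool_call_delta.name      = "get_weather";         // invented for the example
diff.tool_call_delta.arguments = "{\"city\":\"Rome\"";  // partial JSON mid-stream is normal
populate_chat_deltas_from_diffs(reply, { diff });
// reply now holds one ChatDelta carrying the content chunk plus a tool_calls
// entry at index 0; id/name/arguments are set only when non-empty.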
@@ -1484,127 +1535,76 @@ public:
            return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
        }

        // Lambda to build a Reply from JSON + attach chat deltas from a result
        auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
            backend::Reply reply;
            std::string completion_text = res_json.value("content", "");
            reply.set_message(completion_text);
            reply.set_tokens(res_json.value("tokens_predicted", 0));
            reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));

            if (res_json.contains("timings")) {
                reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
                reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
            }

            json logprobs_json = extract_logprobs_from_json(res_json);
            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
                reply.set_logprobs(logprobs_json.dump());
            }

            return reply;
        };

        auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
            // Try streaming partial result first
            auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
            if (partial && !partial->oaicompat_msg_diffs.empty()) {
                populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
                return;
            }
            // Try final result
            auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(raw_result);
            if (final_res && final_res->is_updated) {
                populate_chat_deltas_from_diffs(reply, final_res->oaicompat_msg_diffs);
            }
        };

        // Process first result
        json first_res_json = first_result->to_json();
        if (first_res_json.is_array()) {
            for (const auto & res : first_res_json) {
                auto reply = build_reply_from_json(res, first_result.get());
                attach_chat_deltas(reply, first_result.get());
                writer->Write(reply);
            }
        } else {
            auto reply = build_reply_from_json(first_res_json, first_result.get());
            attach_chat_deltas(reply, first_result.get());
            writer->Write(reply);
        }

        // Process subsequent results
        while (rd.has_next()) {
            // Check if context is cancelled before processing result
            if (context->IsCancelled()) {
                break;
            }

            auto result = rd.next([&context]() { return context->IsCancelled(); });
            if (result == nullptr) {
                // connection is closed
                break;
            }

            json res_json = result->to_json();
            if (res_json.is_array()) {
                for (const auto & res : res_json) {
                    auto reply = build_reply_from_json(res, result.get());
                    attach_chat_deltas(reply, result.get());
                    writer->Write(reply);
                }
            } else {
                auto reply = build_reply_from_json(res_json, result.get());
                attach_chat_deltas(reply, result.get());
                writer->Write(reply);
            }
        }
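After this refactor each streamed Reply still carries the raw text in message, while chat_deltas is filled only when the backend-side parser produced structured output. A text-format sketch of one chunk, rendered as comments (field names follow the setters above, values invented):

// message: "Let me check. "        // raw text, unchanged behavior
// chat_deltas {
//   content: "Let me check. "
//   tool_calls { index: 0 name: "get_weather" arguments: "{\"city\":" }
// }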
@@ -2264,7 +2264,8 @@ public:
        std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
        if (all_results.results.size() == 1) {
            // single result
            auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
            GGML_ASSERT(final_res != nullptr);
            json result_json = all_results.results[0]->to_json();
            reply->set_message(result_json.value("content", ""));
@@ -2287,6 +2288,11 @@ public:
                reply->set_logprobs(logprobs_str);
            }

            // Populate chat deltas from the autoparser's final parsed message
            if (final_res->is_updated) {
                populate_chat_deltas_from_final(*reply, final_res->oaicompat_msg);
            }

        } else {
            // multiple results (multitask)
            json arr = json::array();
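The non-streaming counterpart flattens the fully parsed message into a single ChatDelta. A minimal sketch, assuming the common_chat_msg and common_chat_tool_call layouts from llama.cpp (values invented):

common_chat_msg msg;
msg.content = "Fetching the forecast.";
msg.tool_calls.push_back({
    /* name      */ "get_weather",
    /* arguments */ "{\"city\":\"Rome\"}",
    /* id        */ "call_0",
});
backend::Reply reply;
populate_chat_deltas_from_final(reply, msg);
// Unlike the streaming path, everything lands in one ChatDelta: the content
// plus a tool_calls entry with index 0.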
@@ -2609,6 +2615,113 @@ public:
        response->set_rendered_template(rendered_template);

        // Run differential template analysis to detect tool format markers
        if (params_base.use_jinja) {
            try {
                // Get template source and reconstruct a common_chat_template for analysis
                std::string tmpl_src = common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get());
                if (!tmpl_src.empty()) {
                    const auto * vocab = llama_model_get_vocab(ctx_server.impl->model);
                    std::string token_bos, token_eos;
                    if (vocab) {
                        auto bos_id = llama_vocab_bos(vocab);
                        auto eos_id = llama_vocab_eos(vocab);
                        if (bos_id != LLAMA_TOKEN_NULL) {
                            token_bos = common_token_to_piece(vocab, bos_id, true);
                        }
                        if (eos_id != LLAMA_TOKEN_NULL) {
                            token_eos = common_token_to_piece(vocab, eos_id, true);
                        }
                    }
                    common_chat_template tmpl(tmpl_src, token_bos, token_eos);
                    struct autoparser::autoparser ap;
                    ap.analyze_template(tmpl);

                    if (ap.analysis_complete && ap.tools.format.mode != autoparser::tool_format::NONE) {
                        auto * tf = response->mutable_tool_format();

                        // Format type
                        switch (ap.tools.format.mode) {
                            case autoparser::tool_format::JSON_NATIVE:
                                tf->set_format_type("json_native");
                                break;
                            case autoparser::tool_format::TAG_WITH_JSON:
                                tf->set_format_type("tag_with_json");
                                break;
                            case autoparser::tool_format::TAG_WITH_TAGGED:
                                tf->set_format_type("tag_with_tagged");
                                break;
                            default:
                                break;
                        }

                        // Tool section markers
                        tf->set_section_start(ap.tools.format.section_start);
                        tf->set_section_end(ap.tools.format.section_end);
                        tf->set_per_call_start(ap.tools.format.per_call_start);
                        tf->set_per_call_end(ap.tools.format.per_call_end);

                        // Function markers
                        tf->set_func_name_prefix(ap.tools.function.name_prefix);
                        tf->set_func_name_suffix(ap.tools.function.name_suffix);
                        tf->set_func_close(ap.tools.function.close);

                        // Argument markers
                        tf->set_arg_name_prefix(ap.tools.arguments.name_prefix);
                        tf->set_arg_name_suffix(ap.tools.arguments.name_suffix);
                        tf->set_arg_value_prefix(ap.tools.arguments.value_prefix);
                        tf->set_arg_value_suffix(ap.tools.arguments.value_suffix);
                        tf->set_arg_separator(ap.tools.arguments.separator);
                        tf->set_args_start(ap.tools.arguments.start);
                        tf->set_args_end(ap.tools.arguments.end);

                        // JSON format fields
                        tf->set_name_field(ap.tools.format.name_field);
                        tf->set_args_field(ap.tools.format.args_field);
                        tf->set_id_field(ap.tools.format.id_field);
                        tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
                        tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
                        tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
                        tf->set_function_field(ap.tools.format.function_field);

                        tf->set_gen_id_field(ap.tools.format.gen_id_field);

                        for (const auto & p : ap.tools.format.parameter_order) {
                            tf->add_parameter_order(p);
                        }

                        // Call ID markers
                        switch (ap.tools.call_id.pos) {
                            case autoparser::call_id_position::NONE:
                                tf->set_call_id_position("none");
                                break;
                            case autoparser::call_id_position::PRE_FUNC_NAME:
                                tf->set_call_id_position("pre_func_name");
                                break;
                            case autoparser::call_id_position::BETWEEN_FUNC_AND_ARGS:
                                tf->set_call_id_position("between_func_and_args");
                                break;
                            case autoparser::call_id_position::POST_ARGS:
                                tf->set_call_id_position("post_args");
                                break;
                        }
                        tf->set_call_id_prefix(ap.tools.call_id.prefix);
                        tf->set_call_id_suffix(ap.tools.call_id.suffix);

                        // Reasoning markers
                        tf->set_reasoning_start(ap.reasoning.start);
                        tf->set_reasoning_end(ap.reasoning.end);

                        // Content markers
                        tf->set_content_start(ap.content.start);
                        tf->set_content_end(ap.content.end);
                    }
                }
            } catch (const std::exception & e) {
                SRV_WRN("ModelMetadata: failed to run autoparser analysis: %s\n", e.what());
            }
        }

        return grpc::Status::OK;
    }
};
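To make the detected markers concrete: for a hypothetical Hermes-style template that wraps calls as <tool_call>{"name": ..., "arguments": ...}</tool_call>, the analysis would plausibly report format_type "tag_with_json" with the section markers set to those tags. A client-side read might look like this (scan_between is a hypothetical helper; the getters mirror the setters above):

const auto & tf = response.tool_format();
if (tf.format_type() == "tag_with_json") {
    // e.g. section_start == "<tool_call>", section_end == "</tool_call>",
    //      name_field == "name", args_field == "arguments"
    scan_between(tf.section_start(), tf.section_end());  // hypothetical
}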