feat(functions): add peg-based parsing and allow backends to return tool calls directly (#8838)

* feat(functions): add peg-based parsing

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: support returning toolcalls directly from backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: run PEG parsing only when the backend didn't send deltas

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-08 22:21:57 +01:00
committed by GitHub
parent b57a6e42f1
commit b2f81bfa2e
25 changed files with 6204 additions and 1090 deletions

View File

@@ -17,6 +17,7 @@
#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include "common.h"
#include "chat-auto-parser.h"
#include <getopt.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
@@ -866,6 +867,56 @@ public:
return logprobs_json;
}
// Helper: populate chat_deltas on a Reply from oaicompat_msg_diffs (streaming chunks).
// One ChatDelta is appended per diff; only the fields a diff actually carries are set.
static void populate_chat_deltas_from_diffs(backend::Reply & reply,
                                            const std::vector<common_chat_msg_diff> & diffs) {
    for (const auto & d : diffs) {
        auto * chat_delta = reply.add_chat_deltas();
        if (!d.content_delta.empty()) {
            chat_delta->set_content(d.content_delta);
        }
        if (!d.reasoning_content_delta.empty()) {
            chat_delta->set_reasoning_content(d.reasoning_content_delta);
        }
        // npos marks "no tool-call payload" for this diff.
        if (d.tool_call_index == std::string::npos) {
            continue;
        }
        const auto & tcd = d.tool_call_delta;
        auto * call = chat_delta->add_tool_calls();
        call->set_index(static_cast<int32_t>(d.tool_call_index));
        if (!tcd.id.empty()) {
            call->set_id(tcd.id);
        }
        if (!tcd.name.empty()) {
            call->set_name(tcd.name);
        }
        if (!tcd.arguments.empty()) {
            call->set_arguments(tcd.arguments);
        }
    }
}
// Helper: populate chat_deltas on a Reply from final oaicompat_msg (non-streaming).
// Emits at most one ChatDelta: content, reasoning content, and every tool call
// (indexed in order) are all packed into that single delta.
static void populate_chat_deltas_from_final(backend::Reply & reply,
                                            const common_chat_msg & msg) {
    const bool has_payload = !msg.content.empty()
                          || !msg.reasoning_content.empty()
                          || !msg.tool_calls.empty();
    if (!has_payload) {
        // Nothing to report — leave chat_deltas untouched.
        return;
    }
    auto * delta = reply.add_chat_deltas();
    if (!msg.content.empty()) {
        delta->set_content(msg.content);
    }
    if (!msg.reasoning_content.empty()) {
        delta->set_reasoning_content(msg.reasoning_content);
    }
    // Tool calls as individual entries within the same ChatDelta.
    int32_t idx = 0;
    for (const auto & call : msg.tool_calls) {
        auto * tc = delta->add_tool_calls();
        tc->set_index(idx++);
        tc->set_id(call.id);
        tc->set_name(call.name);
        tc->set_arguments(call.arguments);
    }
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
@@ -1484,127 +1535,76 @@ public:
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
}
// Lambda to build a Reply from JSON + attach chat deltas from a result
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
backend::Reply reply;
std::string completion_text = res_json.value("content", "");
reply.set_message(completion_text);
reply.set_tokens(res_json.value("tokens_predicted", 0));
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
if (res_json.contains("timings")) {
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
}
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
reply.set_logprobs(logprobs_json.dump());
}
return reply;
};
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
// Try streaming partial result first
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
if (partial && !partial->oaicompat_msg_diffs.empty()) {
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
return;
}
// Try final result
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(raw_result);
if (final_res && final_res->is_updated) {
populate_chat_deltas_from_diffs(reply, final_res->oaicompat_msg_diffs);
}
};
// Process first result
json first_res_json = first_result->to_json();
if (first_res_json.is_array()) {
for (const auto & res : first_res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res, first_result.get());
attach_chat_deltas(reply, first_result.get());
writer->Write(reply);
}
} else {
std::string completion_text = first_res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (first_res_json.contains("timings")) {
double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(first_res_json, first_result.get());
attach_chat_deltas(reply, first_result.get());
writer->Write(reply);
}
// Process subsequent results
while (rd.has_next()) {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
break;
}
auto result = rd.next([&context]() { return context->IsCancelled(); });
if (result == nullptr) {
// connection is closed
break;
}
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res, result.get());
attach_chat_deltas(reply, result.get());
writer->Write(reply);
}
} else {
std::string completion_text = res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res_json.contains("timings")) {
double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res_json, result.get());
attach_chat_deltas(reply, result.get());
writer->Write(reply);
}
}
@@ -2264,7 +2264,8 @@ public:
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
GGML_ASSERT(final_res != nullptr);
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));
@@ -2287,6 +2288,11 @@ public:
reply->set_logprobs(logprobs_str);
}
// Populate chat deltas from the autoparser's final parsed message
if (final_res->is_updated) {
populate_chat_deltas_from_final(*reply, final_res->oaicompat_msg);
}
} else {
// multiple results (multitask)
json arr = json::array();
@@ -2609,6 +2615,113 @@ public:
response->set_rendered_template(rendered_template);
// Run differential template analysis to detect tool format markers
if (params_base.use_jinja) {
try {
// Get template source and reconstruct a common_chat_template for analysis
std::string tmpl_src = common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get());
if (!tmpl_src.empty()) {
const auto * vocab = llama_model_get_vocab(ctx_server.impl->model);
std::string token_bos, token_eos;
if (vocab) {
auto bos_id = llama_vocab_bos(vocab);
auto eos_id = llama_vocab_eos(vocab);
if (bos_id != LLAMA_TOKEN_NULL) {
token_bos = common_token_to_piece(vocab, bos_id, true);
}
if (eos_id != LLAMA_TOKEN_NULL) {
token_eos = common_token_to_piece(vocab, eos_id, true);
}
}
common_chat_template tmpl(tmpl_src, token_bos, token_eos);
struct autoparser::autoparser ap;
ap.analyze_template(tmpl);
if (ap.analysis_complete && ap.tools.format.mode != autoparser::tool_format::NONE) {
auto * tf = response->mutable_tool_format();
// Format type
switch (ap.tools.format.mode) {
case autoparser::tool_format::JSON_NATIVE:
tf->set_format_type("json_native");
break;
case autoparser::tool_format::TAG_WITH_JSON:
tf->set_format_type("tag_with_json");
break;
case autoparser::tool_format::TAG_WITH_TAGGED:
tf->set_format_type("tag_with_tagged");
break;
default:
break;
}
// Tool section markers
tf->set_section_start(ap.tools.format.section_start);
tf->set_section_end(ap.tools.format.section_end);
tf->set_per_call_start(ap.tools.format.per_call_start);
tf->set_per_call_end(ap.tools.format.per_call_end);
// Function markers
tf->set_func_name_prefix(ap.tools.function.name_prefix);
tf->set_func_name_suffix(ap.tools.function.name_suffix);
tf->set_func_close(ap.tools.function.close);
// Argument markers
tf->set_arg_name_prefix(ap.tools.arguments.name_prefix);
tf->set_arg_name_suffix(ap.tools.arguments.name_suffix);
tf->set_arg_value_prefix(ap.tools.arguments.value_prefix);
tf->set_arg_value_suffix(ap.tools.arguments.value_suffix);
tf->set_arg_separator(ap.tools.arguments.separator);
tf->set_args_start(ap.tools.arguments.start);
tf->set_args_end(ap.tools.arguments.end);
// JSON format fields
tf->set_name_field(ap.tools.format.name_field);
tf->set_args_field(ap.tools.format.args_field);
tf->set_id_field(ap.tools.format.id_field);
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
tf->set_function_field(ap.tools.format.function_field);
tf->set_gen_id_field(ap.tools.format.gen_id_field);
for (const auto & p : ap.tools.format.parameter_order) {
tf->add_parameter_order(p);
}
// Call ID markers
switch (ap.tools.call_id.pos) {
case autoparser::call_id_position::NONE:
tf->set_call_id_position("none");
break;
case autoparser::call_id_position::PRE_FUNC_NAME:
tf->set_call_id_position("pre_func_name");
break;
case autoparser::call_id_position::BETWEEN_FUNC_AND_ARGS:
tf->set_call_id_position("between_func_and_args");
break;
case autoparser::call_id_position::POST_ARGS:
tf->set_call_id_position("post_args");
break;
}
tf->set_call_id_prefix(ap.tools.call_id.prefix);
tf->set_call_id_suffix(ap.tools.call_id.suffix);
// Reasoning markers
tf->set_reasoning_start(ap.reasoning.start);
tf->set_reasoning_end(ap.reasoning.end);
// Content markers
tf->set_content_start(ap.content.start);
tf->set_content_end(ap.content.end);
}
}
} catch (const std::exception & e) {
SRV_WRN("ModelMetadata: failed to run autoparser analysis: %s\n", e.what());
}
}
return grpc::Status::OK;
}
};