mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat(functions): add peg-based parsing and allow backends to return tool calls directly (#8838)
* feat(functions): add peg-based parsing Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: support returning toolcalls directly from backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: do run PEG only if backend didn't send deltas Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
b57a6e42f1
commit
b2f81bfa2e
@@ -165,6 +165,22 @@ message PredictOptions {
|
||||
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
|
||||
}
|
||||
|
||||
// ToolCallDelta represents an incremental tool call update from the C++ parser.
|
||||
// Used for both streaming (partial diffs) and non-streaming (final tool calls).
|
||||
// ToolCallDelta is one incremental tool-call update produced by the C++
// chat auto-parser. The same message serves both transports: in streaming
// replies `arguments` holds only the newly-generated chunk, while in
// non-streaming replies it holds the complete argument string.
message ToolCallDelta {
  // Zero-based position of the tool call within the assistant message.
  int32 index = 1;
  // Backend-assigned call identifier (e.g. "call_abc123"); may be empty
  // on chunks that only extend a previously announced call.
  string id = 2;
  // Function name; populated on the chunk where the call first appears.
  string name = 3;
  // Argument text: a partial diff when streaming, the full JSON otherwise.
  string arguments = 4;
}
|
||||
|
||||
// ChatDelta represents incremental content/reasoning/tool_call updates parsed by the C++ backend.
|
||||
// ChatDelta carries one incremental update of parsed assistant output —
// plain content, reasoning/thinking text, and/or tool-call fragments —
// as produced by the C++ backend's chat auto-parser.
message ChatDelta {
  // Delta of the assistant's visible content text.
  string content = 1;
  // Delta of the reasoning ("thinking") text, when the model emits it.
  string reasoning_content = 2;
  // Tool-call updates contained in this delta; see ToolCallDelta.
  repeated ToolCallDelta tool_calls = 3;
}
|
||||
|
||||
// The response message containing the result
|
||||
message Reply {
|
||||
bytes message = 1;
|
||||
@@ -174,6 +190,7 @@ message Reply {
|
||||
double timing_token_generation = 5;
|
||||
bytes audio = 6;
|
||||
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
|
||||
repeated ChatDelta chat_deltas = 8; // Parsed chat deltas from C++ autoparser (streaming + non-streaming)
|
||||
}
|
||||
|
||||
message GrammarTrigger {
|
||||
@@ -425,7 +442,62 @@ message DetectResponse {
|
||||
repeated Detection Detections = 1;
|
||||
}
|
||||
|
||||
// ToolFormatMarkers describes the tool-calling syntax a chat template uses,
// as detected by differential template analysis in the C++ autoparser.
// Empty strings mean "marker not used by this template".
message ToolFormatMarkers {
  // Overall shape of tool calls: "json_native", "tag_with_json",
  // or "tag_with_tagged".
  string format_type = 1;

  // Markers delimiting the tool-call section and each individual call.
  string section_start = 2;   // e.g. "<tool_call>", "[TOOL_CALLS]"
  string section_end = 3;     // e.g. "</tool_call>"
  string per_call_start = 4;  // e.g. "<|tool_call_begin|>"
  string per_call_end = 5;    // e.g. "<|tool_call_end|>"

  // Function-name delimiters (TAG_WITH_JSON / TAG_WITH_TAGGED formats).
  string func_name_prefix = 6;  // e.g. "<function="
  string func_name_suffix = 7;  // e.g. ">"
  string func_close = 8;        // e.g. "</function>"

  // Per-argument delimiters (TAG_WITH_TAGGED format).
  string arg_name_prefix = 9;    // e.g. "<param="
  string arg_name_suffix = 10;   // e.g. ">"
  string arg_value_prefix = 11;
  string arg_value_suffix = 12;  // e.g. "</param>"
  string arg_separator = 13;     // e.g. "\n"

  // Key names and structural flags for JSON_NATIVE payloads.
  string name_field = 14;         // e.g. "name"
  string args_field = 15;         // e.g. "arguments"
  string id_field = 16;           // e.g. "id"
  bool fun_name_is_key = 17;      // function name is the JSON object key
  bool tools_array_wrapped = 18;  // calls wrapped in a top-level array
  bool uses_python_dicts = 19;    // python-dict syntax instead of JSON

  // Reasoning ("thinking") section delimiters.
  string reasoning_start = 20;  // e.g. "<think>"
  string reasoning_end = 21;    // e.g. "</think>"

  // Plain-content section delimiters.
  string content_start = 22;
  string content_end = 23;

  // Wrapper delimiters around the arguments blob.
  string args_start = 24;  // e.g. "<args>"
  string args_end = 25;    // e.g. "</args>"

  // JSON wrapper key and required parameter ordering.
  string function_field = 26;  // e.g. "function"
  repeated string parameter_order = 27;

  // Alternative field name used for generated call IDs.
  string gen_id_field = 28;  // e.g. "call_id"

  // Where the call ID sits relative to name/args, and its delimiters:
  // "none", "pre_func_name", "between_func_and_args", or "post_args".
  string call_id_position = 29;
  string call_id_prefix = 30;  // e.g. "[CALL_ID]"
  string call_id_suffix = 31;  // e.g. ""
}
|
||||
|
||||
// ModelMetadataResponse reports capabilities discovered from the model's
// chat template: thinking support, the rendered template itself, and the
// auto-detected tool-call format.
message ModelMetadataResponse {
  // Whether the template supports a thinking/reasoning mode.
  bool supports_thinking = 1;
  // Chat template rendered with enable_thinking=true; empty when the
  // template has no such switch.
  string rendered_template = 2;
  // Tool format markers detected by differential template analysis;
  // unset when no tool-calling syntax was found.
  ToolFormatMarkers tool_format = 3;
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "backend.pb.h"
|
||||
#include "backend.grpc.pb.h"
|
||||
#include "common.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include <getopt.h>
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
@@ -866,6 +867,56 @@ public:
|
||||
return logprobs_json;
|
||||
}
|
||||
|
||||
// Helper: populate chat_deltas on a Reply from oaicompat_msg_diffs (streaming chunks)
|
||||
static void populate_chat_deltas_from_diffs(backend::Reply & reply,
|
||||
const std::vector<common_chat_msg_diff> & diffs) {
|
||||
for (const auto & diff : diffs) {
|
||||
auto* delta = reply.add_chat_deltas();
|
||||
if (!diff.content_delta.empty()) {
|
||||
delta->set_content(diff.content_delta);
|
||||
}
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta->set_reasoning_content(diff.reasoning_content_delta);
|
||||
}
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
auto* tc = delta->add_tool_calls();
|
||||
tc->set_index(static_cast<int32_t>(diff.tool_call_index));
|
||||
if (!diff.tool_call_delta.id.empty()) {
|
||||
tc->set_id(diff.tool_call_delta.id);
|
||||
}
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
tc->set_name(diff.tool_call_delta.name);
|
||||
}
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
tc->set_arguments(diff.tool_call_delta.arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: populate chat_deltas on a Reply from final oaicompat_msg (non-streaming)
|
||||
static void populate_chat_deltas_from_final(backend::Reply & reply,
|
||||
const common_chat_msg & msg) {
|
||||
// Content delta
|
||||
if (!msg.content.empty() || !msg.reasoning_content.empty() || !msg.tool_calls.empty()) {
|
||||
auto* delta = reply.add_chat_deltas();
|
||||
if (!msg.content.empty()) {
|
||||
delta->set_content(msg.content);
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
delta->set_reasoning_content(msg.reasoning_content);
|
||||
}
|
||||
// Tool calls as individual deltas within the same ChatDelta
|
||||
for (size_t i = 0; i < msg.tool_calls.size(); i++) {
|
||||
auto* tc = delta->add_tool_calls();
|
||||
tc->set_index(static_cast<int32_t>(i));
|
||||
tc->set_id(msg.tool_calls[i].id);
|
||||
tc->set_name(msg.tool_calls[i].name);
|
||||
tc->set_arguments(msg.tool_calls[i].arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||
if (params_base.model.path.empty()) {
|
||||
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
|
||||
@@ -1484,127 +1535,76 @@ public:
|
||||
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
|
||||
}
|
||||
|
||||
// Lambda to build a Reply from JSON + attach chat deltas from a result
|
||||
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
|
||||
backend::Reply reply;
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
reply.set_message(completion_text);
|
||||
reply.set_tokens(res_json.value("tokens_predicted", 0));
|
||||
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
|
||||
|
||||
if (res_json.contains("timings")) {
|
||||
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
|
||||
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
|
||||
}
|
||||
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
reply.set_logprobs(logprobs_json.dump());
|
||||
}
|
||||
|
||||
return reply;
|
||||
};
|
||||
|
||||
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
|
||||
// Try streaming partial result first
|
||||
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
|
||||
if (partial && !partial->oaicompat_msg_diffs.empty()) {
|
||||
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
|
||||
return;
|
||||
}
|
||||
// Try final result
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(raw_result);
|
||||
if (final_res && final_res->is_updated) {
|
||||
populate_chat_deltas_from_diffs(reply, final_res->oaicompat_msg_diffs);
|
||||
}
|
||||
};
|
||||
|
||||
// Process first result
|
||||
json first_res_json = first_result->to_json();
|
||||
if (first_res_json.is_array()) {
|
||||
for (const auto & res : first_res_json) {
|
||||
std::string completion_text = res.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = res.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (res.contains("timings")) {
|
||||
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
auto reply = build_reply_from_json(res, first_result.get());
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
std::string completion_text = first_res_json.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (first_res_json.contains("timings")) {
|
||||
double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(first_res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
auto reply = build_reply_from_json(first_res_json, first_result.get());
|
||||
attach_chat_deltas(reply, first_result.get());
|
||||
writer->Write(reply);
|
||||
}
|
||||
|
||||
// Process subsequent results
|
||||
while (rd.has_next()) {
|
||||
// Check if context is cancelled before processing result
|
||||
if (context->IsCancelled()) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto result = rd.next([&context]() { return context->IsCancelled(); });
|
||||
if (result == nullptr) {
|
||||
// connection is closed
|
||||
break;
|
||||
}
|
||||
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
std::string completion_text = res.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = res.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (res.contains("timings")) {
|
||||
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
auto reply = build_reply_from_json(res, result.get());
|
||||
attach_chat_deltas(reply, result.get());
|
||||
writer->Write(reply);
|
||||
}
|
||||
} else {
|
||||
std::string completion_text = res_json.value("content", "");
|
||||
|
||||
backend::Reply reply;
|
||||
reply.set_message(completion_text);
|
||||
int32_t tokens_predicted = res_json.value("tokens_predicted", 0);
|
||||
reply.set_tokens(tokens_predicted);
|
||||
int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0);
|
||||
reply.set_prompt_tokens(tokens_evaluated);
|
||||
|
||||
if (res_json.contains("timings")) {
|
||||
double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0);
|
||||
reply.set_timing_prompt_processing(timing_prompt_processing);
|
||||
double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0);
|
||||
reply.set_timing_token_generation(timing_token_generation);
|
||||
}
|
||||
|
||||
// Extract and set logprobs if present
|
||||
json logprobs_json = extract_logprobs_from_json(res_json);
|
||||
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
|
||||
std::string logprobs_str = logprobs_json.dump();
|
||||
reply.set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
auto reply = build_reply_from_json(res_json, result.get());
|
||||
attach_chat_deltas(reply, result.get());
|
||||
writer->Write(reply);
|
||||
}
|
||||
}
|
||||
@@ -2264,7 +2264,8 @@ public:
|
||||
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
|
||||
if (all_results.results.size() == 1) {
|
||||
// single result
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
|
||||
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
|
||||
GGML_ASSERT(final_res != nullptr);
|
||||
json result_json = all_results.results[0]->to_json();
|
||||
reply->set_message(result_json.value("content", ""));
|
||||
|
||||
@@ -2287,6 +2288,11 @@ public:
|
||||
reply->set_logprobs(logprobs_str);
|
||||
}
|
||||
|
||||
// Populate chat deltas from the autoparser's final parsed message
|
||||
if (final_res->is_updated) {
|
||||
populate_chat_deltas_from_final(*reply, final_res->oaicompat_msg);
|
||||
}
|
||||
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
@@ -2609,6 +2615,113 @@ public:
|
||||
|
||||
response->set_rendered_template(rendered_template);
|
||||
|
||||
// Run differential template analysis to detect tool format markers
|
||||
if (params_base.use_jinja) {
|
||||
try {
|
||||
// Get template source and reconstruct a common_chat_template for analysis
|
||||
std::string tmpl_src = common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get());
|
||||
if (!tmpl_src.empty()) {
|
||||
const auto * vocab = llama_model_get_vocab(ctx_server.impl->model);
|
||||
std::string token_bos, token_eos;
|
||||
if (vocab) {
|
||||
auto bos_id = llama_vocab_bos(vocab);
|
||||
auto eos_id = llama_vocab_eos(vocab);
|
||||
if (bos_id != LLAMA_TOKEN_NULL) {
|
||||
token_bos = common_token_to_piece(vocab, bos_id, true);
|
||||
}
|
||||
if (eos_id != LLAMA_TOKEN_NULL) {
|
||||
token_eos = common_token_to_piece(vocab, eos_id, true);
|
||||
}
|
||||
}
|
||||
common_chat_template tmpl(tmpl_src, token_bos, token_eos);
|
||||
struct autoparser::autoparser ap;
|
||||
ap.analyze_template(tmpl);
|
||||
|
||||
if (ap.analysis_complete && ap.tools.format.mode != autoparser::tool_format::NONE) {
|
||||
auto * tf = response->mutable_tool_format();
|
||||
|
||||
// Format type
|
||||
switch (ap.tools.format.mode) {
|
||||
case autoparser::tool_format::JSON_NATIVE:
|
||||
tf->set_format_type("json_native");
|
||||
break;
|
||||
case autoparser::tool_format::TAG_WITH_JSON:
|
||||
tf->set_format_type("tag_with_json");
|
||||
break;
|
||||
case autoparser::tool_format::TAG_WITH_TAGGED:
|
||||
tf->set_format_type("tag_with_tagged");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
// Tool section markers
|
||||
tf->set_section_start(ap.tools.format.section_start);
|
||||
tf->set_section_end(ap.tools.format.section_end);
|
||||
tf->set_per_call_start(ap.tools.format.per_call_start);
|
||||
tf->set_per_call_end(ap.tools.format.per_call_end);
|
||||
|
||||
// Function markers
|
||||
tf->set_func_name_prefix(ap.tools.function.name_prefix);
|
||||
tf->set_func_name_suffix(ap.tools.function.name_suffix);
|
||||
tf->set_func_close(ap.tools.function.close);
|
||||
|
||||
// Argument markers
|
||||
tf->set_arg_name_prefix(ap.tools.arguments.name_prefix);
|
||||
tf->set_arg_name_suffix(ap.tools.arguments.name_suffix);
|
||||
tf->set_arg_value_prefix(ap.tools.arguments.value_prefix);
|
||||
tf->set_arg_value_suffix(ap.tools.arguments.value_suffix);
|
||||
tf->set_arg_separator(ap.tools.arguments.separator);
|
||||
tf->set_args_start(ap.tools.arguments.start);
|
||||
tf->set_args_end(ap.tools.arguments.end);
|
||||
|
||||
// JSON format fields
|
||||
tf->set_name_field(ap.tools.format.name_field);
|
||||
tf->set_args_field(ap.tools.format.args_field);
|
||||
tf->set_id_field(ap.tools.format.id_field);
|
||||
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
|
||||
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
|
||||
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
|
||||
tf->set_function_field(ap.tools.format.function_field);
|
||||
|
||||
tf->set_gen_id_field(ap.tools.format.gen_id_field);
|
||||
|
||||
for (const auto & p : ap.tools.format.parameter_order) {
|
||||
tf->add_parameter_order(p);
|
||||
}
|
||||
|
||||
// Call ID markers
|
||||
switch (ap.tools.call_id.pos) {
|
||||
case autoparser::call_id_position::NONE:
|
||||
tf->set_call_id_position("none");
|
||||
break;
|
||||
case autoparser::call_id_position::PRE_FUNC_NAME:
|
||||
tf->set_call_id_position("pre_func_name");
|
||||
break;
|
||||
case autoparser::call_id_position::BETWEEN_FUNC_AND_ARGS:
|
||||
tf->set_call_id_position("between_func_and_args");
|
||||
break;
|
||||
case autoparser::call_id_position::POST_ARGS:
|
||||
tf->set_call_id_position("post_args");
|
||||
break;
|
||||
}
|
||||
tf->set_call_id_prefix(ap.tools.call_id.prefix);
|
||||
tf->set_call_id_suffix(ap.tools.call_id.suffix);
|
||||
|
||||
// Reasoning markers
|
||||
tf->set_reasoning_start(ap.reasoning.start);
|
||||
tf->set_reasoning_end(ap.reasoning.end);
|
||||
|
||||
// Content markers
|
||||
tf->set_content_start(ap.content.start);
|
||||
tf->set_content_end(ap.content.end);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
SRV_WRN("ModelMetadata: failed to run autoparser analysis: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -28,6 +28,7 @@ type LLMResponse struct {
|
||||
Usage TokenUsage
|
||||
AudioOutput string
|
||||
Logprobs *schema.Logprobs // Logprobs from the backend response
|
||||
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
|
||||
}
|
||||
|
||||
type TokenUsage struct {
|
||||
@@ -142,6 +143,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
|
||||
ss := ""
|
||||
var logprobs *schema.Logprobs
|
||||
var allChatDeltas []*proto.ChatDelta
|
||||
|
||||
var partialRune []byte
|
||||
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
|
||||
@@ -153,6 +155,11 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
|
||||
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
|
||||
|
||||
// Collect chat deltas from C++ autoparser
|
||||
if len(reply.ChatDeltas) > 0 {
|
||||
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
|
||||
}
|
||||
|
||||
// Parse logprobs from reply if present (collect from last chunk that has them)
|
||||
if len(reply.Logprobs) > 0 {
|
||||
var parsedLogprobs schema.Logprobs
|
||||
@@ -183,10 +190,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
tokenCallback("", tokenUsage)
|
||||
}
|
||||
})
|
||||
if len(allChatDeltas) > 0 {
|
||||
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
|
||||
}
|
||||
return LLMResponse{
|
||||
Response: ss,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
Response: ss,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
ChatDeltas: allChatDeltas,
|
||||
}, err
|
||||
} else {
|
||||
// TODO: Is the chicken bit the only way to get here? is that acceptable?
|
||||
@@ -218,10 +229,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
|
||||
}
|
||||
}
|
||||
|
||||
if len(reply.ChatDeltas) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-streaming Predict received deltas from C++ autoparser", "total_deltas", len(reply.ChatDeltas))
|
||||
}
|
||||
return LLMResponse{
|
||||
Response: response,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
Response: response,
|
||||
Usage: tokenUsage,
|
||||
Logprobs: logprobs,
|
||||
ChatDeltas: reply.ChatDeltas,
|
||||
}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
@@ -118,5 +119,46 @@ func DetectThinkingSupportFromBackend(ctx context.Context, cfg *ModelConfig, bac
|
||||
cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(true)
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", false)
|
||||
}
|
||||
|
||||
// Extract tool format markers from autoparser analysis
|
||||
if tf := metadata.GetToolFormat(); tf != nil && tf.FormatType != "" {
|
||||
cfg.FunctionsConfig.ToolFormatMarkers = &functions.ToolFormatMarkers{
|
||||
FormatType: tf.FormatType,
|
||||
SectionStart: tf.SectionStart,
|
||||
SectionEnd: tf.SectionEnd,
|
||||
PerCallStart: tf.PerCallStart,
|
||||
PerCallEnd: tf.PerCallEnd,
|
||||
FuncNamePrefix: tf.FuncNamePrefix,
|
||||
FuncNameSuffix: tf.FuncNameSuffix,
|
||||
FuncClose: tf.FuncClose,
|
||||
ArgNamePrefix: tf.ArgNamePrefix,
|
||||
ArgNameSuffix: tf.ArgNameSuffix,
|
||||
ArgValuePrefix: tf.ArgValuePrefix,
|
||||
ArgValueSuffix: tf.ArgValueSuffix,
|
||||
ArgSeparator: tf.ArgSeparator,
|
||||
ArgsStart: tf.ArgsStart,
|
||||
ArgsEnd: tf.ArgsEnd,
|
||||
NameField: tf.NameField,
|
||||
ArgsField: tf.ArgsField,
|
||||
IDField: tf.IdField,
|
||||
FunNameIsKey: tf.FunNameIsKey,
|
||||
ToolsArrayWrapped: tf.ToolsArrayWrapped,
|
||||
UsesPythonDicts: tf.UsesPythonDicts,
|
||||
FunctionField: tf.FunctionField,
|
||||
ParameterOrder: tf.ParameterOrder,
|
||||
GenIDField: tf.GenIdField,
|
||||
CallIDPosition: tf.CallIdPosition,
|
||||
CallIDPrefix: tf.CallIdPrefix,
|
||||
CallIDSuffix: tf.CallIdSuffix,
|
||||
ReasoningStart: tf.ReasoningStart,
|
||||
ReasoningEnd: tf.ReasoningEnd,
|
||||
ContentStart: tf.ContentStart,
|
||||
ContentEnd: tf.ContentEnd,
|
||||
}
|
||||
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: tool format markers detected",
|
||||
"format_type", tf.FormatType,
|
||||
"section_start", tf.SectionStart,
|
||||
"func_name_prefix", tf.FuncNamePrefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,8 +141,15 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
}
|
||||
|
||||
// Check if the result contains tool calls
|
||||
toolCalls := functions.ParseFunctionCall(result, cfg.FunctionsConfig)
|
||||
// Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing
|
||||
var toolCalls []functions.FuncCallResults
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
toolCalls = deltaToolCalls
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] Anthropic: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
toolCalls = functions.ParseFunctionCall(result, cfg.FunctionsConfig)
|
||||
}
|
||||
|
||||
var contentBlocks []schema.AnthropicContentBlock
|
||||
var stopReason string
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
@@ -55,7 +56,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
lastEmittedReasoning := ""
|
||||
lastEmittedCleanedContent := ""
|
||||
|
||||
_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
accumulatedContent += s
|
||||
|
||||
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
|
||||
@@ -141,7 +142,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
result := ""
|
||||
lastEmittedCount := 0
|
||||
_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
// Try incremental XML parsing for streaming support using iterative parser
|
||||
// This allows emitting partial tool calls as they're being generated
|
||||
@@ -250,13 +251,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Prepend thinking token if needed, then extract reasoning before processing tool calls
|
||||
reasoning, result := reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
|
||||
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
|
||||
var functionResults []functions.FuncCallResults
|
||||
var reasoning string
|
||||
|
||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
xlog.Debug("Text content to return", "text", textContentToReturn)
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
|
||||
functionResults = deltaToolCalls
|
||||
// Use content/reasoning from deltas too
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text (no chat deltas from backend)
|
||||
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
|
||||
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
|
||||
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
|
||||
functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn)
|
||||
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
|
||||
|
||||
switch {
|
||||
@@ -308,6 +321,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
default:
|
||||
for i, ss := range functionResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
@@ -319,7 +336,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
@@ -345,7 +362,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
Index: i,
|
||||
ID: id,
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Arguments: args,
|
||||
@@ -656,10 +673,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template)
|
||||
|
||||
// When shouldUseFn, the callback just stores the raw text — tool parsing
|
||||
// is deferred to after ComputeChoices so we can check chat deltas first
|
||||
// and avoid redundant Go-side parsing.
|
||||
var cbRawResult, cbReasoning string
|
||||
var emptyRetryNeeded bool
|
||||
|
||||
tokenCallback := func(s string, c *[]schema.Choice) {
|
||||
// Prepend thinking token if needed, then extract reasoning from the response
|
||||
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
|
||||
|
||||
if !shouldUseFn {
|
||||
@@ -672,102 +692,20 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return
|
||||
}
|
||||
|
||||
textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
|
||||
s = functions.CleanupLLMResult(s, config.FunctionsConfig)
|
||||
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
|
||||
xlog.Debug("Text content to return", "text", textContentToReturn)
|
||||
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
if s == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
|
||||
emptyRetryNeeded = true
|
||||
return
|
||||
}
|
||||
result, err := handleQuestion(config, results, s, predInput)
|
||||
if err != nil {
|
||||
xlog.Error("error handling question", "error", err)
|
||||
emptyRetryNeeded = true
|
||||
return
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &result}
|
||||
if reasoning != "" {
|
||||
message.Reasoning = &reasoning
|
||||
}
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if reasoning != "" {
|
||||
toolChoice.Message.Reasoning = &reasoning
|
||||
}
|
||||
|
||||
for _, ss := range results {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
if len(input.Tools) > 0 {
|
||||
// If we are using tools, we condense the function calls into
|
||||
// a single response choice with all the tools
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// otherwise we return more choices directly (deprecated)
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if reasoning != "" {
|
||||
message.Reasoning = &reasoning
|
||||
}
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
// we need to append our result if we are using tools
|
||||
*c = append(*c, toolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
// Store raw text for deferred tool parsing
|
||||
cbRawResult = s
|
||||
cbReasoning = reasoning
|
||||
}
|
||||
|
||||
// Echo properly supports context cancellation via c.Request().Context()
|
||||
// No workaround needed!
|
||||
|
||||
const maxEmptyRetries = 5
|
||||
var result []schema.Choice
|
||||
var tokenUsage backend.TokenUsage
|
||||
var err error
|
||||
|
||||
var chatDeltas []*pb.ChatDelta
|
||||
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
|
||||
emptyRetryNeeded = false
|
||||
result, tokenUsage, err = ComputeChoices(
|
||||
result, tokenUsage, chatDeltas, err = ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
config,
|
||||
@@ -777,7 +715,111 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
tokenCallback,
|
||||
nil,
|
||||
)
|
||||
if err != nil || !emptyRetryNeeded {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Tool parsing is deferred here (only when shouldUseFn)
|
||||
if shouldUseFn {
|
||||
var funcResults []functions.FuncCallResults
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
|
||||
funcResults = deltaToolCalls
|
||||
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
|
||||
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
|
||||
} else {
|
||||
// Fallback: parse tool calls from raw text
|
||||
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
|
||||
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
|
||||
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
|
||||
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
|
||||
}
|
||||
|
||||
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
|
||||
|
||||
switch {
|
||||
case noActionsToRun:
|
||||
if cbRawResult == "" && textContentToReturn == "" {
|
||||
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
|
||||
if qErr != nil {
|
||||
xlog.Error("error handling question", "error", qErr)
|
||||
emptyRetryNeeded = true
|
||||
continue
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
message := &schema.Message{Role: "assistant", Content: &qResult}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &stopReason,
|
||||
Message: message,
|
||||
})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
toolChoice.Message.Reasoning = &cbReasoning
|
||||
}
|
||||
|
||||
for _, ss := range funcResults {
|
||||
name, args := ss.Name, ss.Arguments
|
||||
toolCallID := ss.ID
|
||||
if toolCallID == "" {
|
||||
toolCallID = id
|
||||
}
|
||||
if len(input.Tools) > 0 {
|
||||
toolChoice.Message.Content = textContentToReturn
|
||||
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
|
||||
schema.ToolCall{
|
||||
ID: toolCallID,
|
||||
Type: "function",
|
||||
FunctionCall: schema.FunctionCall{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Deprecated function_call format
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
message := &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
FunctionCall: map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
}
|
||||
if cbReasoning != "" {
|
||||
message.Reasoning = &cbReasoning
|
||||
}
|
||||
result = append(result, schema.Choice{
|
||||
FinishReason: &functionCallReason,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(input.Tools) > 0 {
|
||||
result = append(result, toolChoice)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !emptyRetryNeeded {
|
||||
break
|
||||
}
|
||||
xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
|
||||
@@ -796,6 +838,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
Message: &schema.Message{Role: "assistant", Content: &empty},
|
||||
})
|
||||
}
|
||||
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
CompletionTokens: tokenUsage.Completion,
|
||||
|
||||
@@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
responses <- resp
|
||||
return true
|
||||
}
|
||||
_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
_, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
|
||||
close(responses)
|
||||
return err
|
||||
}
|
||||
@@ -216,7 +216,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
xlog.Debug("Template found, input modified", "input", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(
|
||||
r, tokenUsage, _, err := ComputeChoices(
|
||||
input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
stopReason := FinishReasonStop
|
||||
*c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})
|
||||
|
||||
@@ -58,7 +58,7 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
xlog.Debug("Template found, input modified", "input", i)
|
||||
}
|
||||
|
||||
r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
r, tokenUsage, _, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
|
||||
*c = append(*c, schema.Choice{Text: s})
|
||||
}, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
@@ -18,7 +19,7 @@ func ComputeChoices(
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
|
||||
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
|
||||
n := req.N // number of completions to return
|
||||
result := []schema.Choice{}
|
||||
|
||||
@@ -84,15 +85,16 @@ func ComputeChoices(
|
||||
predFunc, err := backend.ModelInference(
|
||||
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
}
|
||||
|
||||
tokenUsage := backend.TokenUsage{}
|
||||
var allChatDeltas []*pb.ChatDelta
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
prediction, err := predFunc()
|
||||
if err != nil {
|
||||
return result, backend.TokenUsage{}, err
|
||||
return result, backend.TokenUsage{}, nil, err
|
||||
}
|
||||
|
||||
tokenUsage.Prompt += prediction.Usage.Prompt
|
||||
@@ -100,6 +102,11 @@ func ComputeChoices(
|
||||
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
|
||||
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
|
||||
|
||||
// Collect chat deltas from C++ autoparser
|
||||
if len(prediction.ChatDeltas) > 0 {
|
||||
allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...)
|
||||
}
|
||||
|
||||
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
|
||||
cb(finetunedResponse, &result)
|
||||
|
||||
@@ -111,5 +118,5 @@ func ComputeChoices(
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
}
|
||||
return result, tokenUsage, err
|
||||
return result, tokenUsage, allChatDeltas, err
|
||||
}
|
||||
|
||||
@@ -826,9 +826,20 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
|
||||
var toolCalls []schema.ToolCall
|
||||
|
||||
if shouldUseFn {
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
var funcCallResults []functions.FuncCallResults
|
||||
var textContent string
|
||||
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
|
||||
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
}
|
||||
|
||||
noActionName := "answer"
|
||||
if cfg.FunctionsConfig.NoActionFunctionName != "" {
|
||||
@@ -1535,13 +1546,22 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
|
||||
}
|
||||
|
||||
if shouldUseFn {
|
||||
// Clean up the result (already extracted reasoning above)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
|
||||
xlog.Debug("Open Responses - Cleaned result", "cleanedResult", cleanedResult)
|
||||
var funcCallResults []functions.FuncCallResults
|
||||
var textContent string
|
||||
|
||||
funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
xlog.Debug("Open Responses - Parsed function calls", "count", len(funcCallResults), "textContent", textContent)
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
funcCallResults = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
// Clean up the result (already extracted reasoning above)
|
||||
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
|
||||
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] OpenResponses: final tool call decision", "count", len(funcCallResults), "textContent", textContent)
|
||||
|
||||
// Check for noAction function (model chose to respond without tool)
|
||||
noActionName := "answer"
|
||||
@@ -2128,11 +2148,20 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
|
||||
}
|
||||
}
|
||||
|
||||
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
|
||||
xlog.Debug("Open Responses Stream - Cleaned result", "cleanedResult", cleanedResult)
|
||||
var parsedToolCalls []functions.FuncCallResults
|
||||
var textContent string
|
||||
|
||||
parsedToolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
// Try pre-parsed tool calls from C++ autoparser first
|
||||
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls))
|
||||
parsedToolCalls = deltaToolCalls
|
||||
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
|
||||
} else {
|
||||
xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing")
|
||||
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
|
||||
parsedToolCalls = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
|
||||
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
|
||||
}
|
||||
|
||||
// Handle noAction function (model chose to respond without tool)
|
||||
noActionName := "answer"
|
||||
|
||||
107
pkg/functions/chat_deltas.go
Normal file
107
pkg/functions/chat_deltas.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// ToolCallsFromChatDeltas extracts tool calls from C++ autoparser chat deltas.
|
||||
// Returns nil if no tool calls are present in the deltas.
|
||||
func ToolCallsFromChatDeltas(deltas []*pb.ChatDelta) []FuncCallResults {
|
||||
if len(deltas) == 0 {
|
||||
xlog.Debug("[ChatDeltas] no chat deltas received from backend")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count what's in the deltas for logging
|
||||
totalContentChunks := 0
|
||||
totalReasoningChunks := 0
|
||||
totalToolCallChunks := 0
|
||||
for _, d := range deltas {
|
||||
if d.Content != "" {
|
||||
totalContentChunks++
|
||||
}
|
||||
if d.ReasoningContent != "" {
|
||||
totalReasoningChunks++
|
||||
}
|
||||
totalToolCallChunks += len(d.ToolCalls)
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] received deltas from backend",
|
||||
"total_deltas", len(deltas),
|
||||
"content_chunks", totalContentChunks,
|
||||
"reasoning_chunks", totalReasoningChunks,
|
||||
"tool_call_chunks", totalToolCallChunks,
|
||||
)
|
||||
|
||||
type toolCallAccum struct {
|
||||
Name string
|
||||
Arguments string
|
||||
ID string
|
||||
}
|
||||
byIndex := map[int32]*toolCallAccum{}
|
||||
var maxIndex int32 = -1
|
||||
|
||||
for _, d := range deltas {
|
||||
for _, tc := range d.ToolCalls {
|
||||
acc, ok := byIndex[tc.Index]
|
||||
if !ok {
|
||||
acc = &toolCallAccum{}
|
||||
byIndex[tc.Index] = acc
|
||||
}
|
||||
if tc.Name != "" {
|
||||
acc.Name = tc.Name
|
||||
}
|
||||
if tc.Id != "" {
|
||||
acc.ID = tc.Id
|
||||
}
|
||||
acc.Arguments += tc.Arguments
|
||||
if tc.Index > maxIndex {
|
||||
maxIndex = tc.Index
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(byIndex) == 0 {
|
||||
xlog.Debug("[ChatDeltas] deltas present but no tool calls found, falling back to text parsing")
|
||||
return nil
|
||||
}
|
||||
|
||||
results := make([]FuncCallResults, 0, len(byIndex))
|
||||
for i := int32(0); i <= maxIndex; i++ {
|
||||
if acc, ok := byIndex[i]; ok {
|
||||
xlog.Debug("[ChatDeltas] extracted tool call",
|
||||
"index", i,
|
||||
"name", acc.Name,
|
||||
"id", acc.ID,
|
||||
"args_length", len(acc.Arguments),
|
||||
)
|
||||
results = append(results, FuncCallResults{
|
||||
Name: acc.Name,
|
||||
Arguments: acc.Arguments,
|
||||
ID: acc.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
xlog.Debug("[ChatDeltas] using C++ autoparser tool calls, skipping Go-side parsing", "count", len(results))
|
||||
return results
|
||||
}
|
||||
|
||||
// ContentFromChatDeltas extracts accumulated content text from chat deltas.
|
||||
func ContentFromChatDeltas(deltas []*pb.ChatDelta) string {
|
||||
var sb strings.Builder
|
||||
for _, d := range deltas {
|
||||
sb.WriteString(d.Content)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ReasoningFromChatDeltas extracts accumulated reasoning text from chat deltas.
|
||||
func ReasoningFromChatDeltas(deltas []*pb.ChatDelta) string {
|
||||
var sb strings.Builder
|
||||
for _, d := range deltas {
|
||||
sb.WriteString(d.ReasoningContent)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -378,12 +378,23 @@ roses are red
|
||||
</parameter>
|
||||
</function>`
|
||||
|
||||
results, err := ParseXML(input, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// Use PEG parser with a custom format that has no scope and tagged params
|
||||
config := FunctionsConfig{
|
||||
XMLFormat: &XMLToolCallFormat{
|
||||
ToolStart: "<function=",
|
||||
ToolSep: ">",
|
||||
ToolEnd: "</function>",
|
||||
KeyStart: "<parameter=",
|
||||
KeyValSep: ">",
|
||||
ValEnd: "</parameter>",
|
||||
TrimRawArgVal: true,
|
||||
},
|
||||
}
|
||||
results := ParseFunctionCall(input, config)
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Name).To(Equal("add"))
|
||||
// JSON parsing converts numeric strings to numbers (matching llama.cpp behavior)
|
||||
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
|
||||
Expect(results[0].Arguments).To(ContainSubstring(`"x"`))
|
||||
Expect(results[0].Arguments).To(ContainSubstring(`"y"`))
|
||||
})
|
||||
|
||||
It("should parse XML tool call with multiple parameters", func() {
|
||||
@@ -667,19 +678,18 @@ functions.search:0<|tool_call_argument_begin|>{"query": "test", "limit": 10}<|to
|
||||
})
|
||||
|
||||
It("should support partial parsing for streaming", func() {
|
||||
// Partial XML that ends mid-tag should be detected as partial
|
||||
// Partial XML that ends mid-tag should be detected
|
||||
input := `<tool_call>
|
||||
<function=test>
|
||||
<parameter=key>
|
||||
value
|
||||
</parameter>`
|
||||
|
||||
partialResult, err := ParseXMLPartial(input, nil)
|
||||
// ParseXMLIterative with isPartial=true handles streaming
|
||||
results, err := ParseXMLIterative(input, nil, true)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(partialResult).NotTo(BeNil())
|
||||
// Should detect partial content
|
||||
Expect(partialResult).NotTo(BeNil())
|
||||
Expect(partialResult.IsPartial).To(BeTrue())
|
||||
// Should return partial results (may have 0 complete tool calls since function is not closed)
|
||||
_ = results
|
||||
})
|
||||
|
||||
It("should parse JSON values correctly in all formats", func() {
|
||||
|
||||
136
pkg/functions/peg/arena.go
Normal file
136
pkg/functions/peg/arena.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package peg
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Arena stores parser instances and provides the Parse entry point.
|
||||
type Arena struct {
|
||||
parsers []Parser
|
||||
rules map[string]ParserID
|
||||
root ParserID
|
||||
}
|
||||
|
||||
func NewArena() *Arena {
|
||||
return &Arena{
|
||||
rules: make(map[string]ParserID),
|
||||
root: InvalidParserID,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Arena) addParser(p Parser) ParserID {
|
||||
id := ParserID(len(a.parsers))
|
||||
a.parsers = append(a.parsers, p)
|
||||
return id
|
||||
}
|
||||
|
||||
func (a *Arena) Get(id ParserID) Parser {
|
||||
return a.parsers[id]
|
||||
}
|
||||
|
||||
func (a *Arena) Root() ParserID {
|
||||
return a.root
|
||||
}
|
||||
|
||||
func (a *Arena) SetRoot(id ParserID) {
|
||||
a.root = id
|
||||
}
|
||||
|
||||
func (a *Arena) GetRule(name string) ParserID {
|
||||
id, ok := a.rules[name]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("Rule not found: %s", name))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (a *Arena) HasRule(name string) bool {
|
||||
_, ok := a.rules[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Parse parses from the root parser.
|
||||
func (a *Arena) Parse(ctx *ParseContext) ParseResult {
|
||||
if a.root == InvalidParserID {
|
||||
panic("No root parser set")
|
||||
}
|
||||
return a.ParseAt(a.root, ctx, 0)
|
||||
}
|
||||
|
||||
// ParseFrom parses from the root parser starting at position start.
|
||||
func (a *Arena) ParseFrom(ctx *ParseContext, start int) ParseResult {
|
||||
if a.root == InvalidParserID {
|
||||
panic("No root parser set")
|
||||
}
|
||||
return a.ParseAt(a.root, ctx, start)
|
||||
}
|
||||
|
||||
// ParseAt parses using a specific parser at a given position.
|
||||
func (a *Arena) ParseAt(id ParserID, ctx *ParseContext, start int) ParseResult {
|
||||
parser := a.parsers[id]
|
||||
return parser.parse(a, ctx, start)
|
||||
}
|
||||
|
||||
// ParseAnywhere tries parsing from every position in the input until it succeeds.
|
||||
func (a *Arena) ParseAnywhere(ctx *ParseContext) ParseResult {
|
||||
if a.root == InvalidParserID {
|
||||
panic("No root parser set")
|
||||
}
|
||||
if len(ctx.Input) == 0 {
|
||||
return a.ParseAt(a.root, ctx, 0)
|
||||
}
|
||||
for i := 0; i < len(ctx.Input); i++ {
|
||||
result := a.ParseAt(a.root, ctx, i)
|
||||
if result.Type == Success || i == len(ctx.Input)-1 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return NewParseResult(Fail, 0)
|
||||
}
|
||||
|
||||
// resolveRefs walks all parsers and replaces refs with resolved rule IDs.
|
||||
func (a *Arena) resolveRefs() {
|
||||
for i, p := range a.parsers {
|
||||
switch pt := p.(type) {
|
||||
case *SequenceParser:
|
||||
for j, child := range pt.Children {
|
||||
pt.Children[j] = a.resolveRef(child)
|
||||
}
|
||||
case *ChoiceParser:
|
||||
for j, child := range pt.Children {
|
||||
pt.Children[j] = a.resolveRef(child)
|
||||
}
|
||||
case *RepetitionParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *AndParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *NotParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *RuleParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *TagParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *AtomicParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
case *SchemaParser:
|
||||
pt.Child = a.resolveRef(pt.Child)
|
||||
// Leaf parsers — no children to resolve
|
||||
case *EpsilonParser, *StartParser, *EndParser, *LiteralParser,
|
||||
*AnyParser, *SpaceParser, *CharsParser, *JSONStringParser,
|
||||
*PythonDictStringParser, *UntilParser, *RefParser, *JSONParser,
|
||||
*jsonNumberParser:
|
||||
// nothing to do
|
||||
default:
|
||||
_ = i // satisfy compiler
|
||||
}
|
||||
}
|
||||
|
||||
if a.root != InvalidParserID {
|
||||
a.root = a.resolveRef(a.root)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Arena) resolveRef(id ParserID) ParserID {
|
||||
if ref, ok := a.parsers[id].(*RefParser); ok {
|
||||
return a.GetRule(ref.Name)
|
||||
}
|
||||
return id
|
||||
}
|
||||
435
pkg/functions/peg/builder.go
Normal file
435
pkg/functions/peg/builder.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package peg
|
||||
|
||||
import "regexp"
|
||||
|
||||
var invalidRuleCharsRe = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
|
||||
|
||||
// Builder provides a fluent API for constructing parsers.
|
||||
type Builder struct {
|
||||
arena Arena
|
||||
}
|
||||
|
||||
func NewBuilder() *Builder {
|
||||
return &Builder{
|
||||
arena: Arena{
|
||||
rules: make(map[string]ParserID),
|
||||
root: InvalidParserID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) add(p Parser) ParserID {
|
||||
return b.arena.addParser(p)
|
||||
}
|
||||
|
||||
// Eps matches nothing, always succeeds.
|
||||
func (b *Builder) Eps() ParserID {
|
||||
return b.add(&EpsilonParser{})
|
||||
}
|
||||
|
||||
// Start matches start of input.
|
||||
func (b *Builder) Start() ParserID {
|
||||
return b.add(&StartParser{})
|
||||
}
|
||||
|
||||
// End matches end of input.
|
||||
func (b *Builder) End() ParserID {
|
||||
return b.add(&EndParser{})
|
||||
}
|
||||
|
||||
// Literal matches an exact string.
|
||||
func (b *Builder) Literal(s string) ParserID {
|
||||
return b.add(&LiteralParser{Literal: s})
|
||||
}
|
||||
|
||||
// Seq matches a sequence of parsers in order.
|
||||
func (b *Builder) Seq(children ...ParserID) ParserID {
|
||||
// Flatten nested sequences
|
||||
var flattened []ParserID
|
||||
for _, id := range children {
|
||||
if seq, ok := b.arena.parsers[id].(*SequenceParser); ok {
|
||||
flattened = append(flattened, seq.Children...)
|
||||
} else {
|
||||
flattened = append(flattened, id)
|
||||
}
|
||||
}
|
||||
return b.add(&SequenceParser{Children: flattened})
|
||||
}
|
||||
|
||||
// Choice tries alternatives until one succeeds.
|
||||
func (b *Builder) Choice(children ...ParserID) ParserID {
|
||||
// Flatten nested choices
|
||||
var flattened []ParserID
|
||||
for _, id := range children {
|
||||
if ch, ok := b.arena.parsers[id].(*ChoiceParser); ok {
|
||||
flattened = append(flattened, ch.Children...)
|
||||
} else {
|
||||
flattened = append(flattened, id)
|
||||
}
|
||||
}
|
||||
return b.add(&ChoiceParser{Children: flattened})
|
||||
}
|
||||
|
||||
// Optional matches zero or one occurrence.
|
||||
func (b *Builder) Optional(child ParserID) ParserID {
|
||||
return b.Repeat(child, 0, 1)
|
||||
}
|
||||
|
||||
// ZeroOrMore matches zero or more occurrences.
|
||||
func (b *Builder) ZeroOrMore(child ParserID) ParserID {
|
||||
return b.Repeat(child, 0, -1)
|
||||
}
|
||||
|
||||
// OneOrMore matches one or more occurrences.
|
||||
func (b *Builder) OneOrMore(child ParserID) ParserID {
|
||||
return b.Repeat(child, 1, -1)
|
||||
}
|
||||
|
||||
// Repeat matches between min and max times. Use -1 for unbounded max.
|
||||
func (b *Builder) Repeat(child ParserID, min, max int) ParserID {
|
||||
return b.add(&RepetitionParser{Child: child, MinCount: min, MaxCount: max})
|
||||
}
|
||||
|
||||
// Peek is a positive lookahead — succeeds if child succeeds, consumes nothing.
|
||||
func (b *Builder) Peek(child ParserID) ParserID {
|
||||
return b.add(&AndParser{Child: child})
|
||||
}
|
||||
|
||||
// Negate is a negative lookahead — succeeds if child fails, consumes nothing.
|
||||
func (b *Builder) Negate(child ParserID) ParserID {
|
||||
return b.add(&NotParser{Child: child})
|
||||
}
|
||||
|
||||
// Any matches a single UTF-8 codepoint.
|
||||
func (b *Builder) Any() ParserID {
|
||||
return b.add(&AnyParser{})
|
||||
}
|
||||
|
||||
// Space matches zero or more whitespace characters.
|
||||
func (b *Builder) Space() ParserID {
|
||||
return b.add(&SpaceParser{})
|
||||
}
|
||||
|
||||
// Chars matches characters from a character class expression like "[a-z]".
|
||||
func (b *Builder) Chars(classes string, min, max int) ParserID {
|
||||
ranges, negated := parseCharClasses(classes)
|
||||
return b.add(&CharsParser{
|
||||
Pattern: classes,
|
||||
Ranges: ranges,
|
||||
Negated: negated,
|
||||
MinCount: min,
|
||||
MaxCount: max,
|
||||
})
|
||||
}
|
||||
|
||||
// Until matches all characters until a delimiter is found (not consumed).
|
||||
func (b *Builder) Until(delimiter string) ParserID {
|
||||
return b.add(&UntilParser{Delimiters: []string{delimiter}})
|
||||
}
|
||||
|
||||
// UntilOneOf matches until any of the delimiters is found.
|
||||
func (b *Builder) UntilOneOf(delimiters ...string) ParserID {
|
||||
return b.add(&UntilParser{Delimiters: delimiters})
|
||||
}
|
||||
|
||||
// Rest matches everything to end of input.
|
||||
func (b *Builder) Rest() ParserID {
|
||||
return b.add(&UntilParser{Delimiters: nil})
|
||||
}
|
||||
|
||||
// JSONString matches JSON string content (without surrounding quotes).
|
||||
func (b *Builder) JSONString() ParserID {
|
||||
return b.add(&JSONStringParser{})
|
||||
}
|
||||
|
||||
// JSON matches a complete JSON value.
|
||||
func (b *Builder) JSON() ParserID {
|
||||
return b.add(&JSONParser{})
|
||||
}
|
||||
|
||||
// JSONNumber matches a JSON number.
|
||||
func (b *Builder) JSONNumber() ParserID {
|
||||
// We implement this as a dedicated parser entry that delegates to parseJSONNumber
|
||||
return b.add(&jsonNumberParser{})
|
||||
}
|
||||
|
||||
// PythonDictString matches single-quoted string content (without quotes).
|
||||
func (b *Builder) PythonDictString() ParserID {
|
||||
return b.add(&PythonDictStringParser{})
|
||||
}
|
||||
|
||||
// DoubleQuotedString matches a double-quoted string: "content" + space
|
||||
func (b *Builder) DoubleQuotedString() ParserID {
|
||||
return b.LazyRule("dq-string", func() ParserID {
|
||||
return b.Seq(b.Literal(`"`), b.JSONString(), b.Literal(`"`), b.Space())
|
||||
})
|
||||
}
|
||||
|
||||
// SingleQuotedString matches a single-quoted string: 'content' + space
|
||||
func (b *Builder) SingleQuotedString() ParserID {
|
||||
return b.LazyRule("sq-string", func() ParserID {
|
||||
return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"), b.Space())
|
||||
})
|
||||
}
|
||||
|
||||
// FlexibleString matches either a double or single-quoted string.
|
||||
func (b *Builder) FlexibleString() ParserID {
|
||||
return b.LazyRule("flexible-string", func() ParserID {
|
||||
return b.Choice(b.DoubleQuotedString(), b.SingleQuotedString())
|
||||
})
|
||||
}
|
||||
|
||||
// Marker matches <...> or [...] delimited text.
|
||||
func (b *Builder) Marker() ParserID {
|
||||
return b.Choice(
|
||||
b.Seq(b.Literal("<"), b.Until(">"), b.Literal(">")),
|
||||
b.Seq(b.Literal("["), b.Until("]"), b.Literal("]")),
|
||||
)
|
||||
}
|
||||
|
||||
// PythonValue matches a Python-style value (dict, array, string, number, bool, None).
|
||||
func (b *Builder) PythonValue() ParserID {
|
||||
return b.LazyRule("python-value", func() ParserID {
|
||||
return b.Choice(
|
||||
b.PythonDict(), b.PythonArray(), b.PythonString(),
|
||||
b.JSONNumber(), b.PythonBool(), b.PythonNull(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// PythonString matches a Python string (double or single-quoted).
|
||||
func (b *Builder) PythonString() ParserID {
|
||||
return b.LazyRule("python-string", func() ParserID {
|
||||
return b.Choice(b.DoubleQuotedString(), b.SingleQuotedString())
|
||||
})
|
||||
}
|
||||
|
||||
// PythonBool matches True or False.
|
||||
func (b *Builder) PythonBool() ParserID {
|
||||
return b.LazyRule("python-bool", func() ParserID {
|
||||
return b.Seq(b.Choice(b.Literal("True"), b.Literal("False")), b.Space())
|
||||
})
|
||||
}
|
||||
|
||||
// PythonNull matches None.
|
||||
func (b *Builder) PythonNull() ParserID {
|
||||
return b.LazyRule("python-none", func() ParserID {
|
||||
return b.Seq(b.Literal("None"), b.Space())
|
||||
})
|
||||
}
|
||||
|
||||
// PythonDict matches a Python dictionary {key: value, ...}.
|
||||
func (b *Builder) PythonDict() ParserID {
|
||||
return b.LazyRule("python-dict", func() ParserID {
|
||||
member := b.Seq(b.PythonString(), b.Space(), b.Literal(":"), b.Space(), b.PythonValue())
|
||||
return b.Seq(
|
||||
b.Literal("{"), b.Space(),
|
||||
b.Optional(b.Seq(member, b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), member)))),
|
||||
b.Space(), b.Literal("}"), b.Space(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// PythonArray matches a Python array [value, ...].
|
||||
func (b *Builder) PythonArray() ParserID {
|
||||
return b.LazyRule("python-array", func() ParserID {
|
||||
return b.Seq(
|
||||
b.Literal("["), b.Space(),
|
||||
b.Optional(b.Seq(b.PythonValue(), b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), b.PythonValue())))),
|
||||
b.Space(), b.Literal("]"), b.Space(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// LazyRule creates a named rule with deferred construction to support recursion.
// If the rule already exists, returns a ref to it. Otherwise, creates a placeholder,
// builds the child, and replaces the placeholder.
func (b *Builder) LazyRule(name string, builderFn func() ParserID) ParserID {
	// Sanitize the name so it is a valid rule identifier.
	cleanName := invalidRuleCharsRe.ReplaceAllString(name, "-")
	// Already registered: just hand back another reference to it.
	if _, exists := b.arena.rules[cleanName]; exists {
		return b.add(&RefParser{Name: cleanName})
	}

	// Create placeholder rule to allow recursive references.
	// The rule is registered BEFORE builderFn runs, so a recursive call to
	// LazyRule with the same name inside builderFn hits the branch above
	// and resolves to this rule instead of recursing forever.
	placeholderChild := b.add(&AnyParser{})
	ruleID := b.add(&RuleParser{Name: cleanName, Child: placeholderChild})
	b.arena.rules[cleanName] = ruleID

	// Build the actual parser
	child := builderFn()

	// Update the rule with the real child (the placeholder AnyParser slot
	// stays in the arena but is no longer referenced by this rule).
	b.arena.parsers[ruleID] = &RuleParser{Name: cleanName, Child: child}

	return b.add(&RefParser{Name: cleanName})
}
|
||||
|
||||
// Rule creates a named rule and returns a ref to it.
|
||||
func (b *Builder) Rule(name string, child ParserID) ParserID {
|
||||
cleanName := invalidRuleCharsRe.ReplaceAllString(name, "-")
|
||||
ruleID := b.add(&RuleParser{Name: cleanName, Child: child})
|
||||
b.arena.rules[cleanName] = ruleID
|
||||
return b.add(&RefParser{Name: cleanName})
|
||||
}
|
||||
|
||||
// TriggerRule creates a named rule marked as a trigger (for lazy grammar generation).
|
||||
func (b *Builder) TriggerRule(name string, child ParserID) ParserID {
|
||||
cleanName := invalidRuleCharsRe.ReplaceAllString(name, "-")
|
||||
ruleID := b.add(&RuleParser{Name: cleanName, Child: child, Trigger: true})
|
||||
b.arena.rules[cleanName] = ruleID
|
||||
return b.add(&RefParser{Name: cleanName})
|
||||
}
|
||||
|
||||
// Ref creates a forward reference to a named rule.
|
||||
func (b *Builder) Ref(name string) ParserID {
|
||||
return b.add(&RefParser{Name: name})
|
||||
}
|
||||
|
||||
// Atomic creates a parser that suppresses partial AST nodes.
|
||||
func (b *Builder) Atomic(child ParserID) ParserID {
|
||||
return b.add(&AtomicParser{Child: child})
|
||||
}
|
||||
|
||||
// Tag creates a semantic tag in the AST.
|
||||
func (b *Builder) Tag(tag string, child ParserID) ParserID {
|
||||
return b.add(&TagParser{Child: child, Tag: tag})
|
||||
}
|
||||
|
||||
// Schema wraps a parser with schema metadata (pass-through at parse time).
|
||||
func (b *Builder) Schema(child ParserID, name string) ParserID {
|
||||
return b.add(&SchemaParser{Child: child, Name: name})
|
||||
}
|
||||
|
||||
// SetRoot sets the root parser.
|
||||
func (b *Builder) SetRoot(id ParserID) {
|
||||
b.arena.root = id
|
||||
}
|
||||
|
||||
// Build resolves references and returns the arena.
|
||||
func (b *Builder) Build() *Arena {
|
||||
b.arena.resolveRefs()
|
||||
arena := b.arena
|
||||
// Reset builder
|
||||
b.arena = Arena{
|
||||
rules: make(map[string]ParserID),
|
||||
root: InvalidParserID,
|
||||
}
|
||||
return &arena
|
||||
}
|
||||
|
||||
// parseCharClasses parses a character class expression and returns ranges and negation.
// The input may be wrapped in '[' ']' and may begin with '^' for negation,
// mirroring regex character-class syntax (e.g. "[a-z0-9]", "[^\\n]").
func parseCharClasses(classes string) ([]CharRange, bool) {
	content := classes
	negated := false

	// Strip an optional leading '[' and trailing ']'.
	if len(content) > 0 && content[0] == '[' {
		content = content[1:]
	}
	if len(content) > 0 && content[len(content)-1] == ']' {
		content = content[:len(content)-1]
	}
	// A leading '^' (after bracket stripping) negates the class.
	if len(content) > 0 && content[0] == '^' {
		negated = true
		content = content[1:]
	}

	var ranges []CharRange
	i := 0
	for i < len(content) {
		startChar, startLen := ParseCharClassChar(content, i)
		i += startLen

		// "a-z" style range: the '-' must have at least one character after it;
		// a trailing '-' falls through and is handled as a literal on the next pass.
		if i+1 < len(content) && content[i] == '-' {
			endChar, endLen := ParseCharClassChar(content, i+1)
			ranges = append(ranges, CharRange{Start: startChar, End: endChar})
			i += 1 + endLen
		} else {
			// Single character: stored as a degenerate one-element range.
			ranges = append(ranges, CharRange{Start: startChar, End: startChar})
		}
	}

	return ranges, negated
}
|
||||
|
||||
// ParseCharClassChar decodes one (possibly escaped) character of a character
// class starting at pos and returns the decoded rune plus the number of input
// bytes consumed. Supported escapes: \n \t \r \\ \] \[ and the hex forms
// \xNN, \uNNNN, \UNNNNNNNN; any other escaped byte is returned literally.
// pos must be a valid index into content.
func ParseCharClassChar(content string, pos int) (rune, int) {
	// A backslash only starts an escape if at least one byte follows it.
	if content[pos] == '\\' && pos+1 < len(content) {
		switch content[pos+1] {
		case 'n':
			return '\n', 2
		case 't':
			return '\t', 2
		case 'r':
			return '\r', 2
		case '\\':
			return '\\', 2
		case ']':
			return ']', 2
		case '[':
			return '[', 2
		case 'x':
			// \xNN — two hex digits; malformed escapes degrade to literal 'x'.
			if r, n := parseHexEscape(content, pos+2, 2); n > 0 {
				return r, 2 + n
			}
			return 'x', 2
		case 'u':
			// \uNNNN — four hex digits.
			if r, n := parseHexEscape(content, pos+2, 4); n > 0 {
				return r, 2 + n
			}
			return 'u', 2
		case 'U':
			// \UNNNNNNNN — eight hex digits.
			if r, n := parseHexEscape(content, pos+2, 8); n > 0 {
				return r, 2 + n
			}
			return 'U', 2
		default:
			// Unknown escape: yield the escaped byte itself.
			return rune(content[pos+1]), 2
		}
	}
	// Unescaped byte. NOTE(review): this indexes bytes, not runes, so a
	// multi-byte UTF-8 character is returned one byte at a time — confirm
	// callers only supply ASCII class expressions.
	return rune(content[pos]), 1
}
|
||||
|
||||
// parseHexEscape reads exactly count hex digits of s starting at pos and
// returns the decoded rune and the number of digits consumed. It returns
// (0, 0) when fewer than count bytes remain or a non-hex byte is found.
func parseHexEscape(s string, pos, count int) (rune, int) {
	if pos+count > len(s) {
		return 0, 0
	}
	var result rune
	for idx := pos; idx < pos+count; idx++ {
		var digit rune
		switch c := s[idx]; {
		case c >= '0' && c <= '9':
			digit = rune(c - '0')
		case c >= 'a' && c <= 'f':
			digit = rune(c-'a') + 10
		case c >= 'A' && c <= 'F':
			digit = rune(c-'A') + 10
		default:
			return 0, 0
		}
		result = result<<4 | digit
	}
	return result, count
}
|
||||
|
||||
// jsonNumberParser is a dedicated parser for JSON numbers used by JSONNumber().
type jsonNumberParser struct{}

// parse attempts to match a JSON number at start, delegating the actual
// scanning to parseJSONNumber once a plausible first byte is seen.
func (p *jsonNumberParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	if start >= len(ctx.Input) {
		// At end of input: in partial (streaming) mode more bytes may still
		// arrive, so report NeedMoreInput rather than a hard failure.
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, start)
		}
		return NewParseResult(Fail, start)
	}
	// A JSON number can only begin with '-' or a decimal digit.
	if ctx.Input[start] == '-' || (ctx.Input[start] >= '0' && ctx.Input[start] <= '9') {
		return parseJSONNumber(ctx, start, start)
	}
	return NewParseResult(Fail, start)
}
|
||||
|
||||
// BuildPegParser is a helper that creates a parser using a builder function.
|
||||
func BuildPegParser(fn func(b *Builder) ParserID) *Arena {
|
||||
b := NewBuilder()
|
||||
root := fn(b)
|
||||
b.SetRoot(root)
|
||||
return b.Build()
|
||||
}
|
||||
954
pkg/functions/peg/chat.go
Normal file
954
pkg/functions/peg/chat.go
Normal file
@@ -0,0 +1,954 @@
|
||||
package peg
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Tag constants matching llama.cpp
// These semantic tags label AST nodes emitted by chat parsers so that
// ChatPegMapper can route text into content, reasoning, or tool calls.
const (
	TagReasoningBlock = "reasoning-block" // a whole reasoning section
	TagReasoning      = "reasoning"       // reasoning/thinking text
	TagContent        = "content"         // plain assistant content text
	TagTool           = "tool"            // one complete tool call
	TagToolOpen       = "tool-open"       // tool call opening marker
	TagToolClose      = "tool-close"      // tool call closing marker
	TagToolID         = "tool-id"         // tool call identifier
	TagToolName       = "tool-name"       // function name
	TagToolArgs       = "tool-args"       // full argument payload
	TagToolArg        = "tool-arg"        // one key/value argument
	TagToolArgOpen    = "tool-arg-open"   // argument opening marker
	TagToolArgClose   = "tool-arg-close"  // argument closing marker
	TagToolArgName    = "tool-arg-name"   // argument key
	TagToolArgValue   = "tool-arg-value"  // argument value
	TagToolArgStrVal  = "tool-arg-string-value" // argument string value
)
|
||||
|
||||
// ChatBuilder extends Builder with chat-specific tag helpers.
// It embeds *Builder, so all generic combinators remain available on it.
type ChatBuilder struct {
	*Builder
}
|
||||
|
||||
func NewChatBuilder() *ChatBuilder {
|
||||
return &ChatBuilder{Builder: NewBuilder()}
|
||||
}
|
||||
|
||||
// Semantic tag wrappers
// Each helper tags a child parser with one of the Tag* constants so the AST
// mapper can recognize it. The Atomic variants additionally suppress partial
// AST nodes, keeping marker-like tokens from being emitted piecemeal.

// ReasoningBlock tags a whole reasoning section.
func (cb *ChatBuilder) ReasoningBlock(child ParserID) ParserID {
	return cb.Tag(TagReasoningBlock, child)
}

// Reasoning tags reasoning/thinking text.
func (cb *ChatBuilder) Reasoning(child ParserID) ParserID {
	return cb.Tag(TagReasoning, child)
}

// Content tags plain assistant content text.
func (cb *ChatBuilder) Content(child ParserID) ParserID {
	return cb.Tag(TagContent, child)
}

// Tool tags one complete tool call.
func (cb *ChatBuilder) Tool(child ParserID) ParserID {
	return cb.Tag(TagTool, child)
}

// ToolOpen tags a tool call opening marker (atomic: no partial nodes).
func (cb *ChatBuilder) ToolOpen(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolOpen, child))
}

// ToolClose tags a tool call closing marker (atomic).
func (cb *ChatBuilder) ToolClose(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolClose, child))
}

// ToolID tags a tool call identifier (atomic).
func (cb *ChatBuilder) ToolID(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolID, child))
}

// ToolName tags the function name (atomic).
func (cb *ChatBuilder) ToolName(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolName, child))
}

// ToolArgs tags the full argument payload.
func (cb *ChatBuilder) ToolArgs(child ParserID) ParserID {
	return cb.Tag(TagToolArgs, child)
}

// ToolArg tags a single key/value argument.
func (cb *ChatBuilder) ToolArg(child ParserID) ParserID {
	return cb.Tag(TagToolArg, child)
}

// ToolArgOpen tags an argument opening marker (atomic).
func (cb *ChatBuilder) ToolArgOpen(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgOpen, child))
}

// ToolArgClose tags an argument closing marker (atomic).
func (cb *ChatBuilder) ToolArgClose(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgClose, child))
}

// ToolArgName tags an argument key (atomic).
func (cb *ChatBuilder) ToolArgName(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgName, child))
}

// ToolArgValue tags an argument value (non-atomic: may stream partially).
func (cb *ChatBuilder) ToolArgValue(child ParserID) ParserID {
	return cb.Tag(TagToolArgValue, child)
}

// ToolArgStringValue tags an argument string value (non-atomic).
func (cb *ChatBuilder) ToolArgStringValue(child ParserID) ParserID {
	return cb.Tag(TagToolArgStrVal, child)
}

// ToolArgJSONValue tags a JSON argument value atomically, so only the
// complete value is emitted.
func (cb *ChatBuilder) ToolArgJSONValue(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgValue, child))
}
|
||||
|
||||
// TagWithSafeContent creates content parsing that avoids a marker string.
// The parser p is tried first; any text that does not begin the marker is
// consumed as tagged content. With an empty marker, content is consumed one
// character at a time.
func (cb *ChatBuilder) TagWithSafeContent(tagName, marker string, p ParserID) ParserID {
	if marker == "" {
		return cb.ZeroOrMore(cb.Choice(p,
			cb.Rule(tagName, cb.Content(cb.Any())),
		))
	}
	// Consume one guaranteed non-marker character, then everything up to the
	// next potential marker start, so content is emitted in large chunks.
	contentChunk := cb.Rule(tagName,
		cb.Content(cb.Seq(
			cb.Negate(cb.Literal(marker)),
			cb.Any(),
			cb.Until(marker),
		)),
	)
	return cb.ZeroOrMore(cb.Choice(p, contentChunk))
}
|
||||
|
||||
// ToolDef holds a tool definition used to build parsers.
type ToolDef struct {
	Name       string             // function name as the model is expected to emit it
	Properties map[string]PropDef // argument name -> property definition
}

// PropDef holds a property definition for tool arguments.
type PropDef struct {
	Type string // property type name
}
|
||||
|
||||
// StandardConstructedTools builds XML/tagged-style tool parsers.
//
// markers overrides the default delimiters per key (tool_call_start_marker,
// function_opener, parameter_key_prefix, ...). tools restricts parsing to
// known function names and argument keys; with no tools a generic any-name
// parser is built. parallelToolCalls allows multiple calls inside one
// section; forceToolCalls makes the tool-call section mandatory instead of
// optional.
func (cb *ChatBuilder) StandardConstructedTools(
	markers map[string]string,
	tools []ToolDef,
	parallelToolCalls bool,
	forceToolCalls bool,
) ParserID {
	// getMarker returns the configured marker or its default.
	getMarker := func(key, defaultVal string) string {
		if v, ok := markers[key]; ok {
			return v
		}
		return defaultVal
	}

	sectionStart := getMarker("tool_call_start_marker", "<tool_call>")
	sectionEnd := getMarker("tool_call_end_marker", "</tool_call>")
	funcOpener := getMarker("function_opener", "<function=")
	funcNameSuffix := getMarker("function_name_suffix", ">")
	funcCloser := getMarker("function_closer", "</function>")
	paramKeyPrefix := getMarker("parameter_key_prefix", "<param=")
	paramKeySuffix := getMarker("parameter_key_suffix", ">")
	paramCloser := getMarker("parameter_closer", "</param>")
	callIDPrefix := getMarker("call_id_prefix", "")
	callIDSuffix := getMarker("call_id_suffix", "")

	// An empty parameter prefix means arguments arrive as one JSON payload
	// rather than as individual <param=...>...</param> tags.
	hasTaggedParams := paramKeyPrefix != ""

	var toolChoices []ParserID

	if len(tools) == 0 {
		// Generic parser: accept any function name
		var args ParserID
		if hasTaggedParams {
			// Tagged parameters: <param=key>value</param>
			argRule := cb.ToolArg(cb.Seq(
				cb.ToolArgOpen(cb.Literal(paramKeyPrefix)),
				cb.ToolArgName(cb.Until(paramKeySuffix)),
				cb.Literal(paramKeySuffix),
				cb.ToolArgValue(cb.Until(paramCloser)),
				cb.ToolArgClose(cb.Literal(paramCloser)),
			))
			args = cb.ToolArgs(cb.ZeroOrMore(cb.Seq(argRule, cb.Space())))
		} else {
			// JSON arguments: {"key": "val"}
			args = cb.ToolArgs(cb.Until(funcCloser))
		}

		// Build optional call ID section (between function name and args)
		callIDSection := cb.Eps()
		if callIDPrefix != "" && callIDSuffix != "" {
			callIDSection = cb.Optional(cb.Seq(
				cb.Literal(callIDPrefix),
				cb.ToolID(cb.Until(callIDSuffix)),
				cb.Literal(callIDSuffix),
			))
		}

		toolParser := cb.Tool(cb.Seq(
			cb.ToolOpen(cb.Seq(
				cb.Literal(funcOpener),
				cb.ToolName(cb.Until(funcNameSuffix)),
				cb.Literal(funcNameSuffix),
			)),
			callIDSection,
			cb.Space(),
			args,
			cb.Space(),
			cb.ToolClose(cb.Literal(funcCloser)),
		))

		toolChoices = append(toolChoices, cb.Rule("tool-generic", toolParser))
	} else {
		// One dedicated parser per declared tool, so only known function
		// names (and, with tagged params, known argument keys) match.
		for _, tool := range tools {
			// Build argument parsers
			args := cb.Eps()
			if hasTaggedParams && len(tool.Properties) > 0 {
				var argParsers []ParserID
				// NOTE(review): map iteration order is randomized in Go, so
				// the order of alternatives in argChoice varies between runs;
				// confirm Choice order is not semantically significant here.
				for propName := range tool.Properties {
					// Accept the key bare, double-quoted, or single-quoted.
					argNameParser := cb.Choice(
						cb.Literal(propName),
						cb.Literal("\""+propName+"\""),
						cb.Literal("'"+propName+"'"),
					)

					argRule := cb.ToolArg(cb.Seq(
						cb.ToolArgOpen(cb.Literal(paramKeyPrefix)),
						cb.ToolArgName(argNameParser),
						cb.Literal(paramKeySuffix),
						cb.ToolArgValue(cb.Until(paramCloser)),
						cb.ToolArgClose(cb.Literal(paramCloser)),
					))
					argParsers = append(argParsers, argRule)
				}
				argChoice := cb.Choice(argParsers...)
				args = cb.ZeroOrMore(cb.Seq(argChoice, cb.Space()))
			} else if !hasTaggedParams {
				// JSON arguments
				args = cb.Until(funcCloser)
			}

			// Build optional call ID section
			toolCallIDSection := cb.Eps()
			if callIDPrefix != "" && callIDSuffix != "" {
				toolCallIDSection = cb.Optional(cb.Seq(
					cb.Literal(callIDPrefix),
					cb.ToolID(cb.Until(callIDSuffix)),
					cb.Literal(callIDSuffix),
				))
			}

			// Build function parser
			toolParser := cb.Tool(cb.Seq(
				cb.ToolOpen(cb.Seq(
					cb.Literal(funcOpener),
					cb.ToolName(cb.Literal(tool.Name)),
					cb.Literal(funcNameSuffix),
				)),
				toolCallIDSection,
				cb.Space(),
				cb.ToolArgs(args),
				cb.Space(),
				cb.ToolClose(cb.Literal(funcCloser)),
			))

			toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, toolParser))
		}
	}

	toolChoice := cb.Choice(toolChoices...)

	// The section is a trigger rule (see TriggerRule: used for lazy grammar
	// generation).
	var section ParserID
	if parallelToolCalls {
		section = cb.TriggerRule("tool-call", cb.Seq(
			cb.Literal(sectionStart), cb.Space(),
			cb.OneOrMore(cb.Seq(toolChoice, cb.Space())),
			cb.Literal(sectionEnd),
		))
	} else {
		section = cb.TriggerRule("tool-call", cb.Seq(
			cb.Literal(sectionStart), cb.Space(),
			toolChoice, cb.Space(),
			cb.Literal(sectionEnd),
		))
	}

	if forceToolCalls {
		return section
	}
	return cb.Optional(section)
}
|
||||
|
||||
// StandardJSONToolsOpts holds options for building JSON tool call parsers.
type StandardJSONToolsOpts struct {
	SectionStart    string    // literal that opens the tool-call section
	SectionEnd      string    // literal that closes the tool-call section
	Tools           []ToolDef // known tools; empty disables tool parsing
	ParallelCalls   bool      // allow multiple comma-separated calls
	ForceToolCalls  bool      // make the section mandatory, not optional
	NameKey         string    // JSON key for the function name (default "name")
	ArgsKey         string    // JSON key for the arguments (default "arguments")
	ArrayWrapped    bool      // wrap the call list in a JSON array [...]
	FunctionIsKey   bool      // shape where the function name is itself the key
	CallIDKey       string    // optional key carrying a string call ID
	GenCallIDKey    string    // optional key carrying a string-or-number call ID
	ParametersOrder []string  // fixed field order inside each call object
}
|
||||
|
||||
// StandardJSONTools builds JSON-format tool call parsers.
// It returns Eps (match nothing) when no tools are defined, and otherwise
// delegates the per-call object layout to one of three shape builders,
// chosen by FunctionIsKey and by whether the name/args keys use dot notation.
func (cb *ChatBuilder) StandardJSONTools(opts StandardJSONToolsOpts) ParserID {
	if len(opts.Tools) == 0 {
		return cb.Eps()
	}

	// Apply key defaults.
	effectiveNameKey := opts.NameKey
	if effectiveNameKey == "" {
		effectiveNameKey = "name"
	}
	effectiveArgsKey := opts.ArgsKey
	if effectiveArgsKey == "" {
		effectiveArgsKey = "arguments"
	}

	var toolChoices ParserID
	if opts.FunctionIsKey {
		// {"tool_name": {...}} shape. Note the raw ArgsKey is passed through
		// so the builder can distinguish "unset" from the default.
		toolChoices = cb.buildJSONToolsFunctionIsKey(opts.Tools, opts.ArgsKey, effectiveArgsKey, opts.CallIDKey, opts.GenCallIDKey)
	} else {
		nameSpec := parseKeySpec(effectiveNameKey)
		argsSpec := parseKeySpec(effectiveArgsKey)
		if nameSpec.prefix != "" || argsSpec.prefix != "" {
			// Dot-notation keys (e.g. "function.name"): nested object shape.
			toolChoices = cb.buildJSONToolsNestedKeys(opts.Tools, effectiveNameKey, effectiveArgsKey, opts.CallIDKey, opts.GenCallIDKey)
		} else {
			// Flat {"name": ..., "arguments": ...} shape.
			toolChoices = cb.buildJSONToolsFlatKeys(opts.Tools, effectiveNameKey, effectiveArgsKey, opts.CallIDKey, opts.GenCallIDKey, opts.ParametersOrder)
		}
	}

	toolCalls := toolChoices
	if opts.ParallelCalls {
		// One call followed by any number of ", call" repetitions.
		toolCalls = cb.Seq(
			toolChoices,
			cb.ZeroOrMore(cb.Seq(cb.Space(), cb.Literal(","), cb.Space(), toolChoices)),
		)
	}

	if opts.ArrayWrapped {
		toolCalls = cb.Seq(cb.Literal("["), cb.Space(), toolCalls, cb.Space(), cb.Literal("]"))
	}

	section := cb.TriggerRule("tool-call", cb.Seq(
		cb.Literal(opts.SectionStart), cb.Space(),
		toolCalls, cb.Space(),
		cb.Literal(opts.SectionEnd),
	))

	if opts.ForceToolCalls {
		return section
	}
	return cb.Optional(section)
}
|
||||
|
||||
// buildJSONToolsFunctionIsKey builds per-tool parsers for the shape where the
// function name is itself the JSON key: {"tool_name": <value>}. When argsKey
// is empty the value is the bare arguments object; otherwise it is an object
// holding optional ID fields plus an effectiveArgsKey field.
func (cb *ChatBuilder) buildJSONToolsFunctionIsKey(
	tools []ToolDef,
	argsKey, effectiveArgsKey, callIDKey, genCallIDKey string,
) ParserID {
	var toolChoices []ParserID

	for _, tool := range tools {
		var innerFields []ParserID

		// Optional string call ID field, e.g. "id": "call_123".
		if callIDKey != "" {
			idParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\""),
			))
			innerFields = append(innerFields, cb.Optional(cb.Seq(idParser, cb.Space(), cb.Optional(cb.Seq(cb.Literal(","), cb.Space())))))
		}

		// Optional generated call ID field; accepts a string or a number.
		if genCallIDKey != "" {
			genIDParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			innerFields = append(innerFields, cb.Optional(cb.Seq(genIDParser, cb.Space(), cb.Optional(cb.Seq(cb.Literal(","), cb.Space())))))
		}

		// Arguments
		var argsParser ParserID
		if argsKey == "" {
			argsParser = cb.ToolArgs(cb.JSON())
		} else {
			argsParser = cb.Seq(
				cb.Literal("\""+effectiveArgsKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.ToolArgs(cb.JSON()),
			)
		}
		innerFields = append(innerFields, argsParser)

		// Build inner object
		var innerObject ParserID
		if argsKey == "" && len(innerFields) == 1 {
			// Bare arguments value: no wrapping braces of its own.
			innerObject = innerFields[0]
		} else {
			// Fold the fields into "{ f1 f2 ... }" with spacing between them.
			innerObject = cb.Literal("{")
			for i, f := range innerFields {
				innerObject = cb.Seq(innerObject, cb.Space(), f)
				if i < len(innerFields)-1 {
					innerObject = cb.Seq(innerObject, cb.Space())
				}
			}
			innerObject = cb.Seq(innerObject, cb.Space(), cb.Literal("}"))
		}

		toolParser := cb.Tool(cb.Seq(
			cb.ToolOpen(cb.Literal("{")), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
			cb.Space(), cb.Literal(":"), cb.Space(),
			innerObject,
			cb.Space(), cb.ToolClose(cb.Literal("}")),
		))

		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, toolParser))
	}

	return cb.Choice(toolChoices...)
}
|
||||
|
||||
// keySpec represents a dot-notation key split into prefix and field.
// "function.name" yields {prefix: "function", field: "name"}; a plain key
// yields an empty prefix with the whole key as the field.
type keySpec struct {
	prefix string
	field  string
}
|
||||
|
||||
func parseKeySpec(key string) keySpec {
|
||||
for i, c := range key {
|
||||
if c == '.' {
|
||||
return keySpec{prefix: key[:i], field: key[i+1:]}
|
||||
}
|
||||
}
|
||||
return keySpec{field: key}
|
||||
}
|
||||
|
||||
// buildJSONToolsNestedKeys builds per-tool parsers for dot-notation keys:
// e.g. name key "function.name" and args key "function.arguments" match
// {"function": {"name": ..., "arguments": ...}}, with optional call-ID
// fields at the top level.
func (cb *ChatBuilder) buildJSONToolsNestedKeys(
	tools []ToolDef,
	effectiveNameKey, effectiveArgsKey, callIDKey, genCallIDKey string,
) ParserID {
	var toolChoices []ParserID

	nameSpec := parseKeySpec(effectiveNameKey)
	argsSpec := parseKeySpec(effectiveArgsKey)

	// The shared nested-object key; whichever spec carries a prefix wins.
	nestedPrefix := nameSpec.prefix
	if nestedPrefix == "" {
		nestedPrefix = argsSpec.prefix
	}
	// A key without a prefix keeps its full name inside the nested object.
	nestedNameField := nameSpec.field
	if nameSpec.prefix == "" {
		nestedNameField = effectiveNameKey
	}
	nestedArgsField := argsSpec.field
	if argsSpec.prefix == "" {
		nestedArgsField = effectiveArgsKey
	}

	for _, tool := range tools {
		// "name": "tool.Name"
		nestedName := cb.Seq(
			cb.Literal("\""+nestedNameField+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
		)
		// "arguments": { ... }
		nestedArgs := cb.Seq(
			cb.Literal("\""+nestedArgsField+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			cb.ToolArgs(cb.JSON()),
		)
		nestedObject := cb.Seq(
			cb.Literal("{"), cb.Space(),
			nestedName, cb.Space(), cb.Literal(","), cb.Space(),
			nestedArgs,
			cb.Space(), cb.Literal("}"),
		)

		toolParserBody := cb.ToolOpen(cb.Literal("{"))

		// Optional string call ID at the top level (only for plain keys —
		// dot-notation ID keys are not supported here).
		if callIDKey != "" {
			idSpec := parseKeySpec(callIDKey)
			if idSpec.prefix == "" {
				idParser := cb.Atomic(cb.Seq(
					cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
					cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\""),
				))
				toolParserBody = cb.Seq(toolParserBody, cb.Space(),
					cb.Optional(cb.Seq(idParser, cb.Space(), cb.Literal(","), cb.Space())))
			}
		}

		// Optional generated call ID (string or number) at the top level.
		if genCallIDKey != "" {
			genIDSpec := parseKeySpec(genCallIDKey)
			if genIDSpec.prefix == "" {
				genIDParser := cb.Atomic(cb.Seq(
					cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
					cb.Choice(
						cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
						cb.ToolID(cb.JSONNumber()),
					),
				))
				toolParserBody = cb.Seq(toolParserBody, cb.Space(),
					cb.Optional(cb.Seq(genIDParser, cb.Space(), cb.Literal(","), cb.Space())))
			}
		}

		// "function": { nested object }
		nestedField := cb.Seq(
			cb.Literal("\""+nestedPrefix+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			nestedObject,
		)
		toolParserBody = cb.Seq(toolParserBody, cb.Space(), nestedField, cb.Space(), cb.ToolClose(cb.Literal("}")))

		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, cb.Tool(toolParserBody)))
	}

	return cb.Choice(toolChoices...)
}
|
||||
|
||||
// buildJSONToolsFlatKeys builds per-tool parsers for the flat object shape
// {"name": "...", "arguments": {...}} with optional call-ID fields, honoring
// an explicit field order when parametersOrder is provided.
func (cb *ChatBuilder) buildJSONToolsFlatKeys(
	tools []ToolDef,
	effectiveNameKey, effectiveArgsKey, callIDKey, genCallIDKey string,
	parametersOrder []string,
) ParserID {
	var toolChoices []ParserID
	// Key literals are shared across all tools.
	nameKeyParser := cb.Literal("\"" + effectiveNameKey + "\"")
	argsKeyParser := cb.Literal("\"" + effectiveArgsKey + "\"")

	for _, tool := range tools {
		toolNameP := cb.Seq(
			nameKeyParser, cb.Space(), cb.Literal(":"), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
		)
		toolArgsP := cb.Seq(
			argsKeyParser, cb.Space(), cb.Literal(":"), cb.Space(),
			cb.ToolArgs(cb.JSON()),
		)

		// Field parsers paired with their key names so they can be reordered.
		pairs := []parserPair{
			{toolNameP, effectiveNameKey},
			{toolArgsP, effectiveArgsKey},
		}

		// Optional call ID field (string or number value).
		if callIDKey != "" {
			idParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			pairs = append(pairs, parserPair{cb.Optional(idParser), callIDKey})
		}

		// Optional generated call ID field (string or number value).
		if genCallIDKey != "" {
			genIDParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			pairs = append(pairs, parserPair{cb.Optional(genIDParser), genCallIDKey})
		}

		// Sort by parameters_order if provided
		if len(parametersOrder) > 0 {
			sortPairsByOrder(pairs, parametersOrder)
		}

		// Emit "{ field , field , ... }" in the chosen order.
		orderedBody := cb.ToolOpen(cb.Literal("{"))
		for i, p := range pairs {
			orderedBody = cb.Seq(orderedBody, cb.Space(), p.parser)
			if i < len(pairs)-1 {
				orderedBody = cb.Seq(orderedBody, cb.Space(), cb.Literal(","), cb.Space())
			}
		}
		orderedBody = cb.Seq(orderedBody, cb.Space(), cb.ToolClose(cb.Literal("}")))

		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, cb.Tool(orderedBody)))
	}

	return cb.Choice(toolChoices...)
}
|
||||
|
||||
// parserPair couples a field parser with the JSON key it matches, so fields
// can be reordered by sortPairsByOrder.
type parserPair struct {
	parser ParserID
	key    string
}
|
||||
|
||||
func sortPairsByOrder(pairs []parserPair, order []string) {
|
||||
indexOf := func(key string) int {
|
||||
for i, o := range order {
|
||||
if o == key {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return len(order)
|
||||
}
|
||||
// Simple insertion sort (small N)
|
||||
for i := 1; i < len(pairs); i++ {
|
||||
for j := i; j > 0 && indexOf(pairs[j].key) < indexOf(pairs[j-1].key); j-- {
|
||||
pairs[j], pairs[j-1] = pairs[j-1], pairs[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BuildChatPegParser is a convenience function to build a chat parser.
|
||||
func BuildChatPegParser(fn func(cb *ChatBuilder) ParserID) *Arena {
|
||||
cb := NewChatBuilder()
|
||||
root := fn(cb)
|
||||
cb.SetRoot(root)
|
||||
return cb.Build()
|
||||
}
|
||||
|
||||
// ToolCall represents a parsed tool call.
type ToolCall struct {
	Name      string // function name
	Arguments string // JSON-encoded arguments, assembled incrementally
	ID        string // call ID, when the format carries one
}

// ChatMsg represents a parsed chat message.
type ChatMsg struct {
	Content          string     // plain assistant text
	ReasoningContent string     // reasoning/thinking text
	ToolCalls        []ToolCall // tool calls in order of appearance
}
|
||||
|
||||
// ChatPegMapper maps AST nodes to a ChatMsg.
// It is stateful: tool calls are assembled incrementally as tagged nodes are
// visited, which also supports partial (streaming) parse results.
type ChatPegMapper struct {
	Result ChatMsg // accumulated message

	// pendingToolCall is a tool call that has been opened but whose name has
	// not yet been seen; it is promoted into Result when the name arrives,
	// or flushed at the end of FromAST.
	pendingToolCall *ToolCall
	// currentTool points at the tool call currently receiving arguments
	// (either pendingToolCall or an entry already appended to Result).
	currentTool *ToolCall
	// argCount counts emitted arguments so separating commas can be added.
	argCount int
	// closingQuotePend is true while a JSON string value has been opened but
	// its closing quote deferred (keeps streamed argument text monotonic).
	closingQuotePend bool
	// argsBuffer accumulates argument text seen before the tool name is known.
	argsBuffer string
}
|
||||
|
||||
func (m *ChatPegMapper) argsTarget() *string {
|
||||
if m.currentTool != nil && m.currentTool.Name != "" {
|
||||
return &m.currentTool.Arguments
|
||||
}
|
||||
return &m.argsBuffer
|
||||
}
|
||||
|
||||
// FromAST populates the ChatMsg from parse results.
// Every AST node is routed through mapNode; afterwards any still-pending
// named tool call is flushed into Result.
func (m *ChatPegMapper) FromAST(ast *AstArena, result *ParseResult) {
	ast.VisitResult(result, func(node *AstNode) {
		m.mapNode(node)
	})

	// Flush pending tool call
	if m.pendingToolCall != nil && m.pendingToolCall.Name != "" {
		// Arguments buffered before the name was known take precedence.
		if m.argsBuffer != "" {
			m.pendingToolCall.Arguments = m.argsBuffer
		}
		// Close a string value whose closing quote was deferred.
		if m.closingQuotePend && m.pendingToolCall.Arguments != "" {
			m.pendingToolCall.Arguments += "\""
		}
		m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
		m.pendingToolCall = nil
	}
}
|
||||
|
||||
func (m *ChatPegMapper) mapNode(node *AstNode) {
|
||||
switch node.Tag {
|
||||
case TagReasoning:
|
||||
m.Result.ReasoningContent += node.Text
|
||||
|
||||
case TagContent:
|
||||
m.Result.Content += node.Text
|
||||
|
||||
case TagToolOpen:
|
||||
tc := ToolCall{}
|
||||
m.pendingToolCall = &tc
|
||||
m.currentTool = m.pendingToolCall
|
||||
m.argCount = 0
|
||||
m.argsBuffer = ""
|
||||
m.closingQuotePend = false
|
||||
|
||||
case TagToolID:
|
||||
if m.currentTool != nil {
|
||||
text := trimTrailingSpace(node.Text)
|
||||
if len(text) >= 2 && text[0] == '"' && text[len(text)-1] == '"' {
|
||||
text = text[1 : len(text)-1]
|
||||
}
|
||||
m.currentTool.ID = text
|
||||
}
|
||||
|
||||
case TagToolName:
|
||||
if m.currentTool != nil {
|
||||
m.currentTool.Name = trimTrailingSpace(node.Text)
|
||||
if m.argsBuffer != "" {
|
||||
m.currentTool.Arguments = m.argsBuffer
|
||||
m.argsBuffer = ""
|
||||
} else if m.currentTool.Arguments == "" {
|
||||
m.currentTool.Arguments = "{"
|
||||
}
|
||||
// Add tool call to results for streaming
|
||||
if m.pendingToolCall != nil {
|
||||
m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
|
||||
m.pendingToolCall = nil
|
||||
m.currentTool = &m.Result.ToolCalls[len(m.Result.ToolCalls)-1]
|
||||
}
|
||||
}
|
||||
|
||||
case TagToolArgs:
|
||||
if m.currentTool != nil {
|
||||
text := trimTrailingSpace(node.Text)
|
||||
if len(text) > 0 && text[0] == '{' {
|
||||
*m.argsTarget() = text
|
||||
}
|
||||
}
|
||||
|
||||
case TagToolArgOpen:
|
||||
m.closingQuotePend = false
|
||||
|
||||
case TagToolArgName:
|
||||
if m.currentTool != nil {
|
||||
argEntry := ""
|
||||
if m.argCount > 0 {
|
||||
argEntry = ","
|
||||
}
|
||||
trimmed := trimSpace(node.Text)
|
||||
escapedKey := escapeJSONString(trimmed)
|
||||
argEntry += escapedKey + ":"
|
||||
m.argCount++
|
||||
|
||||
target := m.argsTarget()
|
||||
if *target == "" {
|
||||
*target = "{"
|
||||
}
|
||||
*target += argEntry
|
||||
}
|
||||
|
||||
case TagToolArgStrVal:
|
||||
if m.currentTool != nil {
|
||||
content := trimOneSpace(node.Text)
|
||||
var valueToAdd string
|
||||
if content == "" {
|
||||
valueToAdd = "\""
|
||||
m.closingQuotePend = true
|
||||
} else {
|
||||
if !m.closingQuotePend {
|
||||
valueToAdd = "\""
|
||||
m.closingQuotePend = true
|
||||
}
|
||||
valueToAdd += EscapeJSONStringInner(content)
|
||||
}
|
||||
*m.argsTarget() += valueToAdd
|
||||
}
|
||||
|
||||
case TagToolArgValue:
|
||||
if m.currentTool != nil {
|
||||
content := trimOneSpace(node.Text)
|
||||
var valueToAdd string
|
||||
if content != "" {
|
||||
isPotentialContainer := content[0] == '[' || content[0] == '{'
|
||||
if isPotentialContainer {
|
||||
content = NormalizeQuotesToJSON(content)
|
||||
}
|
||||
|
||||
// Try to parse as JSON
|
||||
var parsed json.RawMessage
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err == nil {
|
||||
// Check if it's a string
|
||||
var s string
|
||||
if err2 := json.Unmarshal(parsed, &s); err2 == nil {
|
||||
// It's a string — strip closing quote for monotonic streaming
|
||||
escaped, _ := json.Marshal(s)
|
||||
str := string(escaped)
|
||||
if len(str) > 0 && str[len(str)-1] == '"' {
|
||||
str = str[:len(str)-1]
|
||||
}
|
||||
valueToAdd = str
|
||||
m.closingQuotePend = true
|
||||
} else {
|
||||
// Non-string: use raw content
|
||||
valueToAdd = content
|
||||
}
|
||||
} else {
|
||||
if node.IsPartial && isPotentialContainer {
|
||||
valueToAdd = content
|
||||
} else {
|
||||
if !m.closingQuotePend {
|
||||
valueToAdd = "\""
|
||||
m.closingQuotePend = true
|
||||
}
|
||||
valueToAdd += EscapeJSONStringInner(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
*m.argsTarget() += valueToAdd
|
||||
}
|
||||
|
||||
case TagToolArgClose:
|
||||
if m.currentTool != nil {
|
||||
if m.closingQuotePend {
|
||||
*m.argsTarget() += "\""
|
||||
m.closingQuotePend = false
|
||||
}
|
||||
}
|
||||
|
||||
case TagToolClose:
|
||||
if m.currentTool != nil {
|
||||
// Flush buffer if tool name was never seen
|
||||
if m.currentTool.Name == "" && m.argsBuffer != "" {
|
||||
m.currentTool.Arguments = m.argsBuffer
|
||||
m.argsBuffer = ""
|
||||
}
|
||||
if m.closingQuotePend {
|
||||
m.currentTool.Arguments += "\""
|
||||
m.closingQuotePend = false
|
||||
}
|
||||
// Close unclosed braces
|
||||
for depth := jsonBraceDepth(m.currentTool.Arguments); depth > 0; depth-- {
|
||||
m.currentTool.Arguments += "}"
|
||||
}
|
||||
// Add if pending and named
|
||||
if m.pendingToolCall != nil {
|
||||
if m.currentTool.Name != "" {
|
||||
m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
|
||||
}
|
||||
m.pendingToolCall = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeQuotesToJSON converts Python-style single-quoted strings to JSON double-quoted.
func NormalizeQuotesToJSON(input string) string {
	out := make([]byte, 0, len(input)+16)

	var inSingle, inDouble bool

	for i := 0; i < len(input); i++ {
		ch := input[i]

		// Backslash escapes: the backslash and (usually) the byte after it
		// are treated as a pair while inside a quoted region.
		if ch == '\\' && i+1 < len(input) {
			nxt := input[i+1]
			switch {
			case inSingle && nxt == '\'':
				// Python's \' becomes a bare apostrophe once the string is
				// re-wrapped in double quotes.
				out = append(out, '\'')
			case inSingle && nxt == '"':
				// A double quote inside a single-quoted string must stay
				// escaped after the re-wrap.
				out = append(out, '\\', '"')
			case inSingle || inDouble:
				// Any other escape is copied through verbatim.
				out = append(out, ch, nxt)
			default:
				// Bare backslash outside any string: copy it alone and do
				// not consume the following byte.
				out = append(out, ch)
				continue
			}
			i++
			continue
		}

		switch ch {
		case '"':
			if inSingle {
				// Literal double quote inside a single-quoted string needs
				// escaping in the JSON output.
				out = append(out, '\\', '"')
			} else {
				inDouble = !inDouble
				out = append(out, ch)
			}
		case '\'':
			if inDouble {
				// Apostrophes inside double-quoted strings are legal JSON.
				out = append(out, ch)
			} else {
				// Opening or closing a single-quoted string: emit the JSON
				// double-quote delimiter instead.
				inSingle = !inSingle
				out = append(out, '"')
			}
		default:
			out = append(out, ch)
		}
	}

	return string(out)
}
|
||||
|
||||
// EscapeJSONStringInner JSON-escapes a string and returns the inner content (without surrounding quotes).
func EscapeJSONStringInner(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		// Fall back to the raw input if marshaling ever fails.
		return s
	}
	quoted := string(encoded)
	// json.Marshal wraps the payload in double quotes; strip them to expose
	// only the escaped inner content.
	if n := len(quoted); n >= 2 && quoted[0] == '"' && quoted[n-1] == '"' {
		return quoted[1 : n-1]
	}
	return quoted
}
|
||||
|
||||
// escapeJSONString returns s as a fully quoted, JSON-escaped string literal.
func escapeJSONString(s string) string {
	if encoded, err := json.Marshal(s); err == nil {
		return string(encoded)
	}
	// Marshal failure fallback: wrap the raw input in quotes.
	return "\"" + s + "\""
}
|
||||
|
||||
// jsonBraceDepth returns the net number of unclosed '{' braces in s,
// ignoring braces that appear inside JSON string literals.
func jsonBraceDepth(s string) int {
	var (
		depth    int
		inStr    bool
		skipNext bool // true when the previous byte was an in-string backslash
	)
	for i := 0; i < len(s); i++ {
		if skipNext {
			skipNext = false
			continue
		}
		switch c := s[i]; {
		case c == '\\' && inStr:
			// The escaped byte (e.g. \" ) must not toggle string state.
			skipNext = true
		case c == '"':
			inStr = !inStr
		case inStr:
			// Bytes inside string literals never affect depth.
		case c == '{':
			depth++
		case c == '}':
			depth--
		}
	}
	return depth
}
|
||||
|
||||
func trimTrailingSpace(s string) string {
|
||||
end := len(s)
|
||||
for end > 0 && isWhitespace(s[end-1]) {
|
||||
end--
|
||||
}
|
||||
return s[:end]
|
||||
}
|
||||
|
||||
func trimLeadingSpace(s string, max int) string {
|
||||
start := 0
|
||||
count := 0
|
||||
for start < len(s) && isWhitespace(s[start]) {
|
||||
if max >= 0 && count >= max {
|
||||
break
|
||||
}
|
||||
start++
|
||||
count++
|
||||
}
|
||||
return s[start:]
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
s = trimLeadingSpace(s, 1)
|
||||
return trimTrailingSpace(s)
|
||||
}
|
||||
|
||||
func trimOneSpace(s string) string {
|
||||
s = trimLeadingSpace(s, 1)
|
||||
end := len(s)
|
||||
count := 0
|
||||
for end > 0 && isWhitespace(s[end-1]) && count < 1 {
|
||||
end--
|
||||
count++
|
||||
}
|
||||
return s[:end]
|
||||
}
|
||||
910
pkg/functions/peg/chat_test.go
Normal file
910
pkg/functions/peg/chat_test.go
Normal file
@@ -0,0 +1,910 @@
|
||||
package peg_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/pkg/functions/peg"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func createTools() []peg.ToolDef {
|
||||
return []peg.ToolDef{
|
||||
{
|
||||
Name: "get_current_weather",
|
||||
Properties: map[string]peg.PropDef{
|
||||
"location": {Type: "string"},
|
||||
"unit": {Type: "string"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "get_forecast",
|
||||
Properties: map[string]peg.PropDef{
|
||||
"location": {Type: "string"},
|
||||
"unit": {Type: "string"},
|
||||
"days": {Type: "integer"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "search_knowledge_base",
|
||||
Properties: map[string]peg.PropDef{
|
||||
"query": {Type: "string"},
|
||||
"max_results": {Type: "integer"},
|
||||
"category": {Type: "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// simpleTokenize splits input so that every delimiter rune begins a new
// token; the delimiter itself stays attached to the token it starts. This
// roughly mimics how an LLM emits small, punctuation-aligned token chunks.
func simpleTokenize(input string) []string {
	isDelim := func(r rune) bool {
		switch r {
		case ' ', '\n', '\t', '{', '}', ',', '[', '"', ']', '.', '<', '>', '=', '/':
			return true
		}
		return false
	}

	var tokens []string
	pending := ""
	for _, r := range input {
		// A delimiter flushes whatever was accumulated so far, then starts
		// the next token.
		if isDelim(r) && pending != "" {
			tokens = append(tokens, pending)
			pending = ""
		}
		pending += string(r)
	}
	if pending != "" {
		tokens = append(tokens, pending)
	}
	return tokens
}
|
||||
|
||||
var _ = Describe("Chat PEG Parser", func() {
|
||||
Context("ExampleNative", func() {
|
||||
type testCase struct {
|
||||
name string
|
||||
tools []peg.ToolDef
|
||||
reasoningFormat string
|
||||
parallelCalls bool
|
||||
forcedOpen bool
|
||||
forceToolCalls bool
|
||||
input string
|
||||
expectReasoning string
|
||||
expectContent string
|
||||
expectToolCalls []peg.ToolCall
|
||||
}
|
||||
|
||||
buildParser := func(tc testCase) *peg.Arena {
|
||||
return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
reasoningInContent := tc.reasoningFormat == "none"
|
||||
|
||||
var reasoning peg.ParserID
|
||||
if tc.forcedOpen {
|
||||
reasoning = p.Seq(
|
||||
p.Reasoning(p.Until("</think>")),
|
||||
p.Literal("</think>"),
|
||||
p.Space(),
|
||||
)
|
||||
} else {
|
||||
reasoning = p.Optional(p.Seq(
|
||||
p.Literal("<think>"),
|
||||
p.Reasoning(p.Until("</think>")),
|
||||
p.Literal("</think>"),
|
||||
p.Space(),
|
||||
))
|
||||
}
|
||||
|
||||
if len(tc.tools) > 0 {
|
||||
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
|
||||
SectionStart: "<tool_call>[",
|
||||
SectionEnd: "]</tool_call>",
|
||||
Tools: tc.tools,
|
||||
ParallelCalls: tc.parallelCalls,
|
||||
ForceToolCalls: tc.forceToolCalls,
|
||||
})
|
||||
|
||||
var parts []peg.ParserID
|
||||
if reasoningInContent {
|
||||
parts = append(parts, p.Eps())
|
||||
} else {
|
||||
parts = append(parts, reasoning)
|
||||
}
|
||||
parts = append(parts,
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Optional(p.Seq(p.Space(), toolCall)),
|
||||
p.Space(),
|
||||
p.End(),
|
||||
)
|
||||
return p.Seq(parts...)
|
||||
}
|
||||
|
||||
var parts []peg.ParserID
|
||||
if reasoningInContent {
|
||||
parts = append(parts, p.Eps())
|
||||
} else {
|
||||
parts = append(parts, reasoning)
|
||||
}
|
||||
parts = append(parts, p.Content(p.Rest()), p.End())
|
||||
return p.Seq(parts...)
|
||||
})
|
||||
}
|
||||
|
||||
DescribeTable("native format cases",
|
||||
func(tc testCase) {
|
||||
parser := buildParser(tc)
|
||||
ctx := peg.NewParseContext(tc.input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.Content).To(Equal(tc.expectContent))
|
||||
Expect(msg.ReasoningContent).To(Equal(tc.expectReasoning))
|
||||
Expect(msg.ToolCalls).To(HaveLen(len(tc.expectToolCalls)))
|
||||
|
||||
for i := 0; i < len(tc.expectToolCalls) && i < len(msg.ToolCalls); i++ {
|
||||
Expect(msg.ToolCalls[i].Name).To(Equal(tc.expectToolCalls[i].Name))
|
||||
Expect(msg.ToolCalls[i].Arguments).To(Equal(tc.expectToolCalls[i].Arguments))
|
||||
}
|
||||
},
|
||||
Entry("content with thinking", testCase{
|
||||
reasoningFormat: "auto",
|
||||
input: "<think>The user said hello, I must say hello back</think>\nHello",
|
||||
expectReasoning: "The user said hello, I must say hello back",
|
||||
expectContent: "Hello",
|
||||
}),
|
||||
Entry("content without thinking", testCase{
|
||||
reasoningFormat: "auto",
|
||||
input: "Hello",
|
||||
expectContent: "Hello",
|
||||
}),
|
||||
Entry("content with reasoning_format = none", testCase{
|
||||
reasoningFormat: "none",
|
||||
forcedOpen: true,
|
||||
input: "<think>The user said hello, I must say hello back</think>\nHello",
|
||||
expectContent: "<think>The user said hello, I must say hello back</think>\nHello",
|
||||
}),
|
||||
Entry("content with forced_open", testCase{
|
||||
reasoningFormat: "auto",
|
||||
forcedOpen: true,
|
||||
input: "The user said hello, I must say hello back</think>\nHello",
|
||||
expectReasoning: "The user said hello, I must say hello back",
|
||||
expectContent: "Hello",
|
||||
}),
|
||||
Entry("content with forced_open and reasoning_format = none", testCase{
|
||||
reasoningFormat: "none",
|
||||
forcedOpen: true,
|
||||
input: "The user said hello, I must say hello back</think>\nHello",
|
||||
expectContent: "The user said hello, I must say hello back</think>\nHello",
|
||||
}),
|
||||
Entry("single tool call", testCase{
|
||||
tools: createTools(),
|
||||
reasoningFormat: "auto",
|
||||
forcedOpen: true,
|
||||
input: "I must get the weather in New York</think>\n" +
|
||||
"<tool_call>[" +
|
||||
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
|
||||
"]</tool_call>",
|
||||
expectReasoning: "I must get the weather in New York",
|
||||
expectToolCalls: []peg.ToolCall{
|
||||
{
|
||||
Name: "get_current_weather",
|
||||
Arguments: `{"location": "New York City, NY", "unit": "fahrenheit"}`,
|
||||
},
|
||||
},
|
||||
}),
|
||||
Entry("parallel tool calls", testCase{
|
||||
tools: createTools(),
|
||||
reasoningFormat: "auto",
|
||||
parallelCalls: true,
|
||||
forcedOpen: true,
|
||||
input: "I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you." +
|
||||
"<tool_call>[" +
|
||||
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
|
||||
", " +
|
||||
`{"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}` +
|
||||
", " +
|
||||
`{"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}}` +
|
||||
", " +
|
||||
`{"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}}` +
|
||||
"]</tool_call>",
|
||||
expectReasoning: "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
|
||||
expectContent: "Let me search that for you.",
|
||||
expectToolCalls: []peg.ToolCall{
|
||||
{Name: "get_current_weather", Arguments: `{"location": "New York City, NY", "unit": "fahrenheit"}`},
|
||||
{Name: "get_current_weather", Arguments: `{"location": "San Francisco, CA", "unit": "fahrenheit"}`},
|
||||
{Name: "get_forecast", Arguments: `{"location": "New York City, NY", "unit": "fahrenheit", "days": 3}`},
|
||||
{Name: "get_forecast", Arguments: `{"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}`},
|
||||
},
|
||||
}),
|
||||
Entry("JSON schema response format", testCase{
|
||||
tools: createTools(),
|
||||
reasoningFormat: "auto",
|
||||
forcedOpen: true,
|
||||
forceToolCalls: false,
|
||||
input: "Thinking about the answer</think>\n" +
|
||||
`<tool_call>[{"name": "get_current_weather", "arguments": {"location": "NYC", "unit": "celsius"}}]</tool_call>`,
|
||||
expectReasoning: "Thinking about the answer",
|
||||
expectToolCalls: []peg.ToolCall{
|
||||
{Name: "get_current_weather", Arguments: `{"location": "NYC", "unit": "celsius"}`},
|
||||
},
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
Context("ExampleQwen3Coder", func() {
|
||||
It("parses tool calls with tagged parameters", func() {
|
||||
tools := createTools()
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
content := p.Rule("content", p.Content(p.Until("<tool_call>")))
|
||||
|
||||
var toolParsers []peg.ParserID
|
||||
for _, tool := range tools {
|
||||
var argChoices []peg.ParserID
|
||||
for propName, prop := range tool.Properties {
|
||||
var argValueParser peg.ParserID
|
||||
if prop.Type == "string" {
|
||||
argValueParser = p.ToolArgStringValue(p.UntilOneOf("</parameter>\n<parameter=", "</parameter>\n</function>"))
|
||||
} else {
|
||||
argValueParser = p.ToolArgJSONValue(p.JSON())
|
||||
}
|
||||
|
||||
arg := p.ToolArg(p.Seq(
|
||||
p.ToolArgOpen(p.Literal("<parameter="+propName+">")),
|
||||
argValueParser,
|
||||
p.ToolArgClose(p.Seq(
|
||||
p.Literal("</parameter>\n"),
|
||||
p.Peek(p.Choice(p.Literal("<parameter="), p.Literal("</function>"))),
|
||||
)),
|
||||
))
|
||||
argChoices = append(argChoices, arg)
|
||||
}
|
||||
|
||||
argChoice := p.Choice(argChoices...)
|
||||
args := p.ZeroOrMore(argChoice)
|
||||
|
||||
toolParser := p.Rule("tool-"+tool.Name, p.Seq(
|
||||
p.ToolOpen(p.Seq(
|
||||
p.Literal("<function="),
|
||||
p.ToolName(p.Literal(tool.Name)),
|
||||
p.Literal(">\n"),
|
||||
)),
|
||||
args,
|
||||
p.ToolClose(p.Literal("</function>")),
|
||||
))
|
||||
toolParsers = append(toolParsers, toolParser)
|
||||
}
|
||||
|
||||
toolCall := p.TriggerRule("tool-call", p.Seq(
|
||||
p.Literal("<tool_call>"), p.Space(),
|
||||
p.Choice(toolParsers...), p.Space(),
|
||||
p.Literal("</tool_call>"),
|
||||
))
|
||||
|
||||
return p.Seq(content, p.ZeroOrMore(p.Seq(p.Space(), toolCall)), p.End())
|
||||
})
|
||||
|
||||
input := "Let me search the knowledge base for cat pictures." +
|
||||
"<tool_call>\n" +
|
||||
"<function=search_knowledge_base>\n" +
|
||||
"<parameter=query>cat pictures</parameter>\n" +
|
||||
"<parameter=category>general</parameter>\n" +
|
||||
"</function>\n" +
|
||||
"</tool_call>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.Content).To(Equal("Let me search the knowledge base for cat pictures."))
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("search_knowledge_base"))
|
||||
Expect(msg.ToolCalls[0].Arguments).NotTo(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ExampleQwen3NonCoder", func() {
|
||||
It("parses JSON tool calls", func() {
|
||||
tools := createTools()
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
|
||||
SectionStart: "<tool_call>",
|
||||
SectionEnd: "</tool_call>",
|
||||
Tools: tools,
|
||||
ParallelCalls: true,
|
||||
})
|
||||
return p.Seq(
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Optional(p.Seq(p.Space(), toolCall)),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
|
||||
input := "I need to get the weather.\n" +
|
||||
"<tool_call>" +
|
||||
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
|
||||
"</tool_call>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.Content).To(Equal("I need to get the weather.\n"))
|
||||
Expect(msg.ReasoningContent).To(BeEmpty())
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
|
||||
Expect(msg.ToolCalls[0].Arguments).To(Equal(`{"location": "New York City, NY", "unit": "fahrenheit"}`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Command7", func() {
|
||||
var parser *peg.Arena
|
||||
|
||||
BeforeEach(func() {
|
||||
parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
thinking := p.ReasoningBlock(p.Seq(
|
||||
p.Literal("<|START_THINKING|>"), p.Space(),
|
||||
p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
|
||||
p.Literal("<|END_THINKING|>"),
|
||||
))
|
||||
|
||||
response := p.Seq(
|
||||
p.Literal("<|START_RESPONSE|>"), p.Space(),
|
||||
p.Content(p.Until("<|END_RESPONSE|>")), p.Space(),
|
||||
p.Literal("<|END_RESPONSE|>"),
|
||||
)
|
||||
|
||||
toolCallID := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallName := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallArgs := p.Seq(
|
||||
p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.ToolArgs(p.JSON()),
|
||||
)
|
||||
|
||||
toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
|
||||
toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
|
||||
p.ToolOpen(p.Literal("{")), p.Space(),
|
||||
toolCallFields,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
|
||||
p.Space(), p.ToolClose(p.Literal("}")),
|
||||
)))
|
||||
|
||||
toolCalls := p.Rule("tool-calls", p.Seq(
|
||||
p.Literal("<|START_ACTION|>"), p.Space(),
|
||||
p.Literal("["), p.Space(),
|
||||
toolCall,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
|
||||
p.Space(), p.Literal("]"), p.Space(),
|
||||
p.Literal("<|END_ACTION|>"),
|
||||
))
|
||||
|
||||
return p.Seq(
|
||||
p.Optional(p.Seq(thinking, p.Space())),
|
||||
p.Choice(toolCalls, response),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
It("parses tool call with reasoning", func() {
|
||||
input := "<|START_THINKING|>I need to plan a trip to Japan.\n<|END_THINKING|>" +
|
||||
"<|START_ACTION|>[" +
|
||||
`{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14}}` +
|
||||
"]<|END_ACTION|>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.ReasoningContent).To(Equal("I need to plan a trip to Japan.\n"))
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
|
||||
Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
|
||||
})
|
||||
|
||||
It("parses content-only response", func() {
|
||||
input := "<|START_RESPONSE|>Hello, how can I help you?<|END_RESPONSE|>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.Content).To(Equal("Hello, how can I help you?"))
|
||||
Expect(msg.ToolCalls).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("PrefixToolNames", func() {
|
||||
var parser *peg.Arena
|
||||
|
||||
BeforeEach(func() {
|
||||
tools := []peg.ToolDef{
|
||||
{Name: "special_function", Properties: map[string]peg.PropDef{"arg1": {Type: "string"}}},
|
||||
{Name: "special_function_with_opt", Properties: map[string]peg.PropDef{"arg1": {Type: "string"}}},
|
||||
}
|
||||
|
||||
parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
toolCall := p.StandardConstructedTools(
|
||||
map[string]string{},
|
||||
tools,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
return p.Seq(
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Optional(p.Seq(p.Space(), toolCall)),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
It("parses long tool name", func() {
|
||||
input := "Let me call the function.<tool_call><function=special_function_with_opt><param=arg1>42</param></function></tool_call>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
Expect(mapper.Result.ToolCalls).To(HaveLen(1))
|
||||
Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function_with_opt"))
|
||||
})
|
||||
|
||||
It("parses short tool name", func() {
|
||||
input := "Let me call the function.<tool_call><function=special_function><param=arg1>42</param></function></tool_call>"
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
Expect(mapper.Result.ToolCalls).To(HaveLen(1))
|
||||
Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function"))
|
||||
})
|
||||
|
||||
It("never prematurely matches during incremental parsing", func() {
|
||||
input := "Let me call the function." +
|
||||
"<tool_call>" +
|
||||
"<function=special_function_with_opt>" +
|
||||
"<param=arg1>42</param>" +
|
||||
"</function>" +
|
||||
"</tool_call>"
|
||||
|
||||
tokens := simpleTokenize(input)
|
||||
var accumulated string
|
||||
|
||||
for i, tok := range tokens {
|
||||
accumulated += tok
|
||||
isPartial := i < len(tokens)-1
|
||||
|
||||
ctx := peg.NewParseContext(accumulated, isPartial)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
|
||||
for _, tc := range mapper.Result.ToolCalls {
|
||||
Expect(tc.Name).NotTo(Equal("special_function"),
|
||||
"premature tool name match at token %d, input: %s", i, accumulated)
|
||||
}
|
||||
}
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
Expect(mapper.Result.ToolCalls).To(HaveLen(1))
|
||||
Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function_with_opt"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IncrementalParsing", func() {
|
||||
It("handles qwen3 coder format incrementally", func() {
|
||||
tools := createTools()
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
content := p.Rule("content", p.Content(p.Until("<tool_call>")))
|
||||
|
||||
var toolParsers []peg.ParserID
|
||||
for _, tool := range tools {
|
||||
var argChoices []peg.ParserID
|
||||
for propName, prop := range tool.Properties {
|
||||
var argValueParser peg.ParserID
|
||||
if prop.Type == "string" {
|
||||
argValueParser = p.ToolArgStringValue(p.UntilOneOf("</parameter>\n<parameter=", "</parameter>\n</function>"))
|
||||
} else {
|
||||
argValueParser = p.ToolArgJSONValue(p.JSON())
|
||||
}
|
||||
arg := p.ToolArg(p.Seq(
|
||||
p.ToolArgOpen(p.Literal("<parameter="+propName+">")),
|
||||
argValueParser,
|
||||
p.ToolArgClose(p.Seq(
|
||||
p.Literal("</parameter>\n"),
|
||||
p.Peek(p.Choice(p.Literal("<parameter="), p.Literal("</function>"))),
|
||||
)),
|
||||
))
|
||||
argChoices = append(argChoices, arg)
|
||||
}
|
||||
argChoice := p.Choice(argChoices...)
|
||||
args := p.ZeroOrMore(argChoice)
|
||||
toolParser := p.Rule("tool-"+tool.Name, p.Seq(
|
||||
p.ToolOpen(p.Seq(p.Literal("<function="), p.ToolName(p.Literal(tool.Name)), p.Literal(">\n"))),
|
||||
args,
|
||||
p.ToolClose(p.Literal("</function>")),
|
||||
))
|
||||
toolParsers = append(toolParsers, toolParser)
|
||||
}
|
||||
toolCall := p.TriggerRule("tool-call", p.Seq(
|
||||
p.Literal("<tool_call>"), p.Space(),
|
||||
p.Choice(toolParsers...), p.Space(),
|
||||
p.Literal("</tool_call>"),
|
||||
))
|
||||
return p.Seq(content, p.ZeroOrMore(p.Seq(p.Space(), toolCall)), p.End())
|
||||
})
|
||||
|
||||
input := "Let me search the knowledge base for cat pictures." +
|
||||
"<tool_call>\n" +
|
||||
"<function=search_knowledge_base>\n" +
|
||||
"<parameter=query>cat pictures</parameter>\n" +
|
||||
"<parameter=category>general</parameter>\n" +
|
||||
"</function>\n" +
|
||||
"</tool_call>"
|
||||
|
||||
tokens := simpleTokenize(input)
|
||||
var accumulated string
|
||||
var prevToolCalls int
|
||||
|
||||
for i, tok := range tokens {
|
||||
accumulated += tok
|
||||
isPartial := i < len(tokens)-1
|
||||
|
||||
ctx := peg.NewParseContext(accumulated, isPartial)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
|
||||
Expect(len(mapper.Result.ToolCalls)).To(BeNumerically(">=", prevToolCalls),
|
||||
"tool call count decreased at token %d", i)
|
||||
prevToolCalls = len(mapper.Result.ToolCalls)
|
||||
}
|
||||
})
|
||||
|
||||
It("handles qwen3 non-coder format incrementally", func() {
|
||||
tools := createTools()
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
|
||||
SectionStart: "<tool_call>",
|
||||
SectionEnd: "</tool_call>",
|
||||
Tools: tools,
|
||||
ParallelCalls: true,
|
||||
})
|
||||
return p.Seq(
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Optional(p.Seq(p.Space(), toolCall)),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
|
||||
input := "I need to get the weather.\n" +
|
||||
"<tool_call>" +
|
||||
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
|
||||
"</tool_call>"
|
||||
|
||||
tokens := simpleTokenize(input)
|
||||
var accumulated string
|
||||
|
||||
for i, tok := range tokens {
|
||||
accumulated += tok
|
||||
isPartial := i < len(tokens)-1
|
||||
|
||||
ctx := peg.NewParseContext(accumulated, isPartial)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
|
||||
}
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
Expect(mapper.Result.ToolCalls).To(HaveLen(1))
|
||||
Expect(mapper.Result.ToolCalls[0].Name).To(Equal("get_current_weather"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Command7 complex input", func() {
|
||||
It("parses complex reasoning and tool calls", func() {
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
thinking := p.ReasoningBlock(p.Seq(
|
||||
p.Literal("<|START_THINKING|>"), p.Space(),
|
||||
p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
|
||||
p.Literal("<|END_THINKING|>"),
|
||||
))
|
||||
|
||||
toolCallID := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallName := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallArgs := p.Seq(
|
||||
p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.ToolArgs(p.JSON()),
|
||||
)
|
||||
|
||||
toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
|
||||
toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
|
||||
p.ToolOpen(p.Literal("{")), p.Space(),
|
||||
toolCallFields,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
|
||||
p.Space(), p.ToolClose(p.Literal("}")),
|
||||
)))
|
||||
|
||||
toolCalls := p.Rule("tool-calls", p.Seq(
|
||||
p.Literal("<|START_ACTION|>"), p.Space(),
|
||||
p.Literal("["), p.Space(),
|
||||
toolCall,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
|
||||
p.Space(), p.Literal("]"), p.Space(),
|
||||
p.Literal("<|END_ACTION|>"),
|
||||
))
|
||||
|
||||
return p.Seq(
|
||||
p.Optional(p.Seq(thinking, p.Space())),
|
||||
toolCalls,
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
|
||||
reasoning := "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " +
|
||||
"budget of $4000 for a two-week stay, we need to:\n\n" +
|
||||
"1. Identify key historical sites and modern attractions in Japan.\n" +
|
||||
"2. Find affordable accommodation options that provide a balance between comfort and cost.\n" +
|
||||
"3. Determine the best modes of transportation for getting around Japan.\n" +
|
||||
"4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " +
|
||||
"overspending.\n" +
|
||||
"5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " +
|
||||
"to attractions."
|
||||
|
||||
input := "<|START_THINKING|>" + reasoning + "<|END_THINKING|>" +
|
||||
`<|START_ACTION|>[{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14, "budget": 4000, "interests": ["historical sites", "modern attractions"], "accommodation_preferences": "affordable", "transportation_preferences": "efficient", "meal_preferences": "local cuisine"}}]<|END_ACTION|>`
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.ReasoningContent).To(Equal(reasoning))
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
|
||||
Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
|
||||
Expect(msg.ToolCalls[0].Arguments).To(ContainSubstring(`"interests"`))
|
||||
Expect(msg.ToolCalls[0].Arguments).To(ContainSubstring(`"historical sites"`))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ForceToolCalls", func() {
|
||||
var parser *peg.Arena
|
||||
|
||||
BeforeEach(func() {
|
||||
tools := createTools()
|
||||
parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
|
||||
SectionStart: "<tool_call>[",
|
||||
SectionEnd: "]</tool_call>",
|
||||
Tools: tools,
|
||||
ForceToolCalls: true,
|
||||
})
|
||||
return p.Seq(
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Space(),
|
||||
toolCall,
|
||||
p.Space(),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
It("succeeds with tool call present", func() {
|
||||
input := "Let me check." +
|
||||
`<tool_call>[{"name": "get_current_weather", "arguments": {"location": "NYC", "unit": "celsius"}}]</tool_call>`
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
Expect(mapper.Result.ToolCalls).To(HaveLen(1))
|
||||
Expect(mapper.Result.ToolCalls[0].Name).To(Equal("get_current_weather"))
|
||||
})
|
||||
|
||||
It("fails without tool call", func() {
|
||||
input := "Just a response without any tool calls."
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("NestedKeysJSONTools", func() {
|
||||
It("parses nested function.name and function.arguments keys", func() {
|
||||
tools := []peg.ToolDef{
|
||||
{
|
||||
Name: "get_current_weather",
|
||||
Properties: map[string]peg.PropDef{
|
||||
"location": {Type: "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
|
||||
SectionStart: "<tool_call>",
|
||||
SectionEnd: "</tool_call>",
|
||||
Tools: tools,
|
||||
NameKey: "function.name",
|
||||
ArgsKey: "function.arguments",
|
||||
CallIDKey: "id",
|
||||
})
|
||||
return p.Seq(
|
||||
p.Content(p.Until("<tool_call>")),
|
||||
p.Optional(p.Seq(p.Space(), toolCall)),
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
|
||||
input := `Let me check.<tool_call>{"id": "call_123", "function": {"name": "get_current_weather", "arguments": {"location": "NYC"}}}</tool_call>`
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
|
||||
Expect(msg.ToolCalls[0].ID).To(Equal("call_123"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Command7 incremental", func() {
|
||||
It("handles incremental parsing without regressions", func() {
|
||||
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
|
||||
thinking := p.ReasoningBlock(p.Seq(
|
||||
p.Literal("<|START_THINKING|>"), p.Space(),
|
||||
p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
|
||||
p.Literal("<|END_THINKING|>"),
|
||||
))
|
||||
|
||||
toolCallID := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallName := p.Atomic(p.Seq(
|
||||
p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
|
||||
))
|
||||
toolCallArgs := p.Seq(
|
||||
p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
|
||||
p.ToolArgs(p.JSON()),
|
||||
)
|
||||
|
||||
toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
|
||||
toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
|
||||
p.ToolOpen(p.Literal("{")), p.Space(),
|
||||
toolCallFields,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
|
||||
p.Space(), p.ToolClose(p.Literal("}")),
|
||||
)))
|
||||
|
||||
toolCalls := p.Rule("tool-calls", p.Seq(
|
||||
p.Literal("<|START_ACTION|>"), p.Space(),
|
||||
p.Literal("["), p.Space(),
|
||||
toolCall,
|
||||
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
|
||||
p.Space(), p.Literal("]"), p.Space(),
|
||||
p.Literal("<|END_ACTION|>"),
|
||||
))
|
||||
|
||||
return p.Seq(
|
||||
p.Optional(p.Seq(thinking, p.Space())),
|
||||
toolCalls,
|
||||
p.End(),
|
||||
)
|
||||
})
|
||||
|
||||
reasoning := "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " +
|
||||
"budget of $4000 for a two-week stay, we need to:\n\n" +
|
||||
"1. Identify key historical sites and modern attractions in Japan.\n" +
|
||||
"2. Find affordable accommodation options.\n" +
|
||||
"3. Determine the best modes of transportation.\n" +
|
||||
"4. Create a day-by-day itinerary.\n" +
|
||||
"5. Provide a detailed cost breakdown."
|
||||
|
||||
input := "<|START_THINKING|>" + reasoning + "<|END_THINKING|>" +
|
||||
`<|START_ACTION|>[{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14, "budget": 4000, "interests": ["historical sites", "modern attractions"]}}]<|END_ACTION|>`
|
||||
|
||||
tokens := simpleTokenize(input)
|
||||
var accumulated string
|
||||
var prevToolCalls int
|
||||
|
||||
for i, tok := range tokens {
|
||||
accumulated += tok
|
||||
isPartial := i < len(tokens)-1
|
||||
|
||||
ctx := peg.NewParseContext(accumulated, isPartial)
|
||||
result := parser.Parse(ctx)
|
||||
|
||||
Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, accumulated length=%d", i, len(accumulated))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
|
||||
Expect(len(mapper.Result.ToolCalls)).To(BeNumerically(">=", prevToolCalls),
|
||||
"tool call count decreased at token %d", i)
|
||||
prevToolCalls = len(mapper.Result.ToolCalls)
|
||||
}
|
||||
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := parser.Parse(ctx)
|
||||
Expect(result.Type).To(Equal(peg.Success))
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
Expect(msg.ReasoningContent).To(Equal(reasoning))
|
||||
Expect(msg.ToolCalls).To(HaveLen(1))
|
||||
Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
|
||||
Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
|
||||
})
|
||||
})
|
||||
})
|
||||
831
pkg/functions/peg/parser.go
Normal file
831
pkg/functions/peg/parser.go
Normal file
@@ -0,0 +1,831 @@
|
||||
package peg
|
||||
|
||||
|
||||
// Parser is the interface all parser types implement.
// Implementations attempt a match at a byte offset and report Success,
// Fail, or (in streaming mode) NeedMoreInput.
type Parser interface {
	// parse attempts a match against ctx.Input at byte offset start.
	// arena resolves child ParserIDs; the ParseResult carries the consumed
	// range and any AST nodes produced.
	parse(arena *Arena, ctx *ParseContext, start int) ParseResult
}
|
||||
|
||||
// EpsilonParser always succeeds, consumes nothing.
// Useful as a neutral element inside sequences and choices.
type EpsilonParser struct{}

// parse succeeds at any position without consuming input.
func (p *EpsilonParser) parse(_ *Arena, _ *ParseContext, start int) ParseResult {
	return NewParseResult(Success, start)
}
|
||||
|
||||
// StartParser matches start of input.
type StartParser struct{}

// parse succeeds only at byte offset 0; consumes nothing.
func (p *StartParser) parse(_ *Arena, _ *ParseContext, start int) ParseResult {
	if start == 0 {
		return NewParseResult(Success, start)
	}
	return NewParseResult(Fail, start)
}
|
||||
|
||||
// EndParser matches end of input.
type EndParser struct{}

// parse succeeds only when start is at (or past) the end of the input;
// consumes nothing.
func (p *EndParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	if start >= len(ctx.Input) {
		return NewParseResult(Success, start)
	}
	return NewParseResult(Fail, start)
}
|
||||
|
||||
// LiteralParser matches an exact string.
|
||||
type LiteralParser struct {
|
||||
Literal string
|
||||
}
|
||||
|
||||
func (p *LiteralParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
pos := start
|
||||
for i := 0; i < len(p.Literal); i++ {
|
||||
if pos >= len(ctx.Input) {
|
||||
if !ctx.IsPartial {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
return NewParseResultRange(NeedMoreInput, start, pos)
|
||||
}
|
||||
if ctx.Input[pos] != p.Literal[i] {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return NewParseResultRange(Success, start, pos)
|
||||
}
|
||||
|
||||
// SequenceParser matches children in order.
type SequenceParser struct {
	Children []ParserID
}

// parse runs each child where the previous one ended, concatenating their
// AST nodes.
//   - Any child Fail fails the sequence; in streaming mode a Fail that
//     reached the end of input is promoted to NeedMoreInput, since the
//     child might still succeed once more bytes arrive.
//   - A child NeedMoreInput suspends the sequence, keeping the nodes
//     accumulated so far.
func (p *SequenceParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	var nodes []AstID

	for _, childID := range p.Children {
		result := arena.ParseAt(childID, ctx, pos)

		if result.Type == Fail {
			if ctx.IsPartial && result.End >= len(ctx.Input) {
				return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
			}
			return NewParseResultRange(Fail, start, result.End)
		}

		nodes = append(nodes, result.Nodes...)

		if result.Type == NeedMoreInput {
			return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
		}

		pos = result.End
	}

	return NewParseResultNodes(Success, start, pos, nodes)
}
|
||||
|
||||
// ChoiceParser tries each alternative until one succeeds.
|
||||
type ChoiceParser struct {
|
||||
Children []ParserID
|
||||
}
|
||||
|
||||
func (p *ChoiceParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
for _, childID := range p.Children {
|
||||
result := arena.ParseAt(childID, ctx, start)
|
||||
if result.Type != Fail {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
|
||||
// RepetitionParser matches min to max repetitions.
type RepetitionParser struct {
	Child    ParserID
	MinCount int
	MaxCount int // -1 for unbounded
}

// parse greedily repeats the child until it fails, the input runs out, or
// MaxCount is reached, keeping the nodes of every successful repetition.
//   - A zero-width child match stops the loop (infinite-loop guard).
//   - A child NeedMoreInput suspends the repetition with nodes so far.
//   - Falling short of MinCount fails, except at end of input in streaming
//     mode, where additional repetitions may still arrive (NeedMoreInput).
func (p *RepetitionParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	matchCount := 0
	var nodes []AstID

	for p.MaxCount == -1 || matchCount < p.MaxCount {
		if pos >= len(ctx.Input) {
			break
		}

		result := arena.ParseAt(p.Child, ctx, pos)

		if result.Type == Success {
			// Prevent infinite loop on empty matches
			if result.End == pos {
				break
			}
			nodes = append(nodes, result.Nodes...)
			pos = result.End
			matchCount++
			continue
		}

		if result.Type == NeedMoreInput {
			nodes = append(nodes, result.Nodes...)
			return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
		}

		// Child failed
		break
	}

	if p.MinCount > 0 && matchCount < p.MinCount {
		if pos >= len(ctx.Input) && ctx.IsPartial {
			return NewParseResultNodes(NeedMoreInput, start, pos, nodes)
		}
		return NewParseResultRange(Fail, start, pos)
	}

	return NewParseResultNodes(Success, start, pos, nodes)
}
|
||||
|
||||
// AndParser is a positive lookahead — succeeds if child succeeds, consumes nothing.
type AndParser struct {
	Child ParserID
}

// parse propagates the child's result type (Success/Fail/NeedMoreInput)
// but reports zero consumption and discards any AST nodes the child built.
func (p *AndParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	result := arena.ParseAt(p.Child, ctx, start)
	return NewParseResult(result.Type, start)
}
|
||||
|
||||
// NotParser is a negative lookahead — succeeds if child fails, consumes nothing.
|
||||
type NotParser struct {
|
||||
Child ParserID
|
||||
}
|
||||
|
||||
func (p *NotParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
result := arena.ParseAt(p.Child, ctx, start)
|
||||
if result.Type == Success {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
if result.Type == NeedMoreInput {
|
||||
return NewParseResult(NeedMoreInput, start)
|
||||
}
|
||||
return NewParseResult(Success, start)
|
||||
}
|
||||
|
||||
// AnyParser matches any single UTF-8 codepoint.
|
||||
type AnyParser struct{}
|
||||
|
||||
func (p *AnyParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
_, size, status := parseUTF8Codepoint(ctx.Input, start)
|
||||
if status == utf8Incomplete {
|
||||
if !ctx.IsPartial {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
return NewParseResult(NeedMoreInput, start)
|
||||
}
|
||||
if status == utf8Invalid {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
return NewParseResultRange(Success, start, start+size)
|
||||
}
|
||||
|
||||
// SpaceParser matches zero or more whitespace characters.
|
||||
type SpaceParser struct{}
|
||||
|
||||
func (p *SpaceParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
pos := start
|
||||
for pos < len(ctx.Input) {
|
||||
c := ctx.Input[pos]
|
||||
if c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\v' || c == '\f' {
|
||||
pos++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return NewParseResultRange(Success, start, pos)
|
||||
}
|
||||
|
||||
// CharRange represents a closed range of Unicode codepoints.
type CharRange struct {
	Start rune
	End   rune
}

// Contains reports whether cp lies within the inclusive range [Start, End].
func (r CharRange) Contains(cp rune) bool {
	return cp >= r.Start && cp <= r.End
}
|
||||
|
||||
// CharsParser matches characters from a character class.
type CharsParser struct {
	Pattern  string      // original class pattern (informational)
	Ranges   []CharRange // codepoint ranges making up the class
	Negated  bool        // when true, match codepoints NOT in Ranges
	MinCount int         // minimum repetitions required
	MaxCount int         // -1 for unbounded
}

// parse consumes consecutive codepoints belonging to the class (or outside
// it when Negated), between MinCount and MaxCount times.
// On truncated UTF-8 at end of input: if the minimum is already met the
// match commits; otherwise streaming mode asks for more input and strict
// mode fails. On invalid UTF-8 the match either commits what it has or
// fails outright.
func (p *CharsParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	matchCount := 0

	for p.MaxCount == -1 || matchCount < p.MaxCount {
		r, size, status := parseUTF8Codepoint(ctx.Input, pos)

		if status == utf8Incomplete {
			if matchCount >= p.MinCount {
				return NewParseResultRange(Success, start, pos)
			}
			if !ctx.IsPartial {
				return NewParseResult(Fail, start)
			}
			return NewParseResultRange(NeedMoreInput, start, pos)
		}

		if status == utf8Invalid {
			if matchCount >= p.MinCount {
				return NewParseResultRange(Success, start, pos)
			}
			return NewParseResult(Fail, start)
		}

		// Membership test against the class ranges.
		matches := false
		for _, cr := range p.Ranges {
			if cr.Contains(r) {
				matches = true
				break
			}
		}

		if p.Negated {
			matches = !matches
		}

		if matches {
			pos += size
			matchCount++
		} else {
			break
		}
	}

	if matchCount < p.MinCount {
		// Short of the minimum: at EOF in streaming mode more may arrive.
		if pos >= len(ctx.Input) && ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResultRange(Fail, start, pos)
	}

	return NewParseResultRange(Success, start, pos)
}
|
||||
|
||||
// JSONStringParser matches JSON string content (without quotes).
type JSONStringParser struct{}

// parse consumes characters up to (but not including) an unescaped double
// quote. Backslash escapes are validated via handleEscapeSequence; other
// bytes must form valid UTF-8. Running out of input before the closing
// quote yields NeedMoreInput in streaming mode, Fail otherwise.
func (p *JSONStringParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start

	for pos < len(ctx.Input) {
		c := ctx.Input[pos]

		if c == '"' {
			// Closing quote: matched content ends just before it.
			return NewParseResultRange(Success, start, pos)
		}

		if c == '\\' {
			result := handleEscapeSequence(ctx, start, pos)
			if result.Type != Success {
				return result
			}
			pos = result.End
		} else {
			_, size, status := parseUTF8Codepoint(ctx.Input, pos)
			if status == utf8Incomplete {
				if !ctx.IsPartial {
					return NewParseResult(Fail, start)
				}
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			if status == utf8Invalid {
				return NewParseResult(Fail, start)
			}
			pos += size
		}
	}

	// Input exhausted before the terminating quote.
	if !ctx.IsPartial {
		return NewParseResultRange(Fail, start, pos)
	}
	return NewParseResultRange(NeedMoreInput, start, pos)
}
|
||||
|
||||
// PythonDictStringParser matches single-quoted string content (without quotes).
|
||||
// Like JSONStringParser but terminates on single quote instead of double quote.
|
||||
type PythonDictStringParser struct{}
|
||||
|
||||
func (p *PythonDictStringParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
pos := start
|
||||
|
||||
for pos < len(ctx.Input) {
|
||||
c := ctx.Input[pos]
|
||||
|
||||
if c == '\'' {
|
||||
return NewParseResultRange(Success, start, pos)
|
||||
}
|
||||
|
||||
if c == '\\' {
|
||||
result := handleEscapeSequence(ctx, start, pos)
|
||||
if result.Type != Success {
|
||||
return result
|
||||
}
|
||||
pos = result.End
|
||||
} else {
|
||||
_, size, status := parseUTF8Codepoint(ctx.Input, pos)
|
||||
if status == utf8Incomplete {
|
||||
if !ctx.IsPartial {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
return NewParseResultRange(NeedMoreInput, start, pos)
|
||||
}
|
||||
if status == utf8Invalid {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
pos += size
|
||||
}
|
||||
}
|
||||
|
||||
if !ctx.IsPartial {
|
||||
return NewParseResultRange(Fail, start, pos)
|
||||
}
|
||||
return NewParseResultRange(NeedMoreInput, start, pos)
|
||||
}
|
||||
|
||||
// handleEscapeSequence validates a backslash escape beginning at pos (which
// must point at the '\'). Recognized: \" \' \\ \/ \b \f \n \r \t and \uXXXX.
// On success the returned range covers the whole escape; an unknown escape
// fails, and a truncated one yields NeedMoreInput in streaming mode.
func handleEscapeSequence(ctx *ParseContext, start int, pos int) ParseResult {
	pos++ // consume '\'
	if pos >= len(ctx.Input) {
		if !ctx.IsPartial {
			return NewParseResult(Fail, start)
		}
		return NewParseResultRange(NeedMoreInput, start, pos)
	}

	switch ctx.Input[pos] {
	case '"', '\'', '\\', '/', 'b', 'f', 'n', 'r', 't':
		pos++
		return NewParseResultRange(Success, start, pos)
	case 'u':
		return handleUnicodeEscape(ctx, start, pos)
	default:
		return NewParseResult(Fail, start)
	}
}
|
||||
|
||||
func handleUnicodeEscape(ctx *ParseContext, start int, pos int) ParseResult {
|
||||
pos++ // consume 'u'
|
||||
for i := 0; i < 4; i++ {
|
||||
if pos >= len(ctx.Input) {
|
||||
if !ctx.IsPartial {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
return NewParseResultRange(NeedMoreInput, start, pos)
|
||||
}
|
||||
if !isHexDigit(ctx.Input[pos]) {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return NewParseResultRange(Success, start, pos)
|
||||
}
|
||||
|
||||
// isHexDigit reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func isHexDigit(c byte) bool {
	switch {
	case '0' <= c && c <= '9':
		return true
	case 'a' <= c && c <= 'f':
		return true
	case 'A' <= c && c <= 'F':
		return true
	}
	return false
}
|
||||
|
||||
// UntilParser matches everything until one of the delimiters is found.
type UntilParser struct {
	Delimiters []string
}

// parse consumes codepoints until a delimiter (complete, or a prefix of
// one) is seen at the current position; the delimiter itself is not
// consumed. Stopping on a prefix match means a delimiter split across
// stream chunks is never swallowed as content.
// NOTE(review): the trie is rebuilt on every parse call; if this shows up
// in profiles, consider caching it per Delimiters set — confirm parsers
// are not shared across goroutines first.
func (p *UntilParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	matcher := newTrie(p.Delimiters)

	pos := start
	lastValidPos := start

	for pos < len(ctx.Input) {
		_, size, status := parseUTF8Codepoint(ctx.Input, pos)

		if status == utf8Incomplete {
			if !ctx.IsPartial {
				return NewParseResult(Fail, start)
			}
			return NewParseResultRange(NeedMoreInput, start, lastValidPos)
		}

		if status == utf8Invalid {
			return NewParseResult(Fail, start)
		}

		match := matcher.checkAt(ctx.Input, pos)

		if match == trieCompleteMatch {
			return NewParseResultRange(Success, start, pos)
		}

		// A delimiter prefix also stops the scan (see comment above).
		if match == triePartialMatch {
			return NewParseResultRange(Success, start, pos)
		}

		pos += size
		lastValidPos = pos
	}

	if lastValidPos == len(ctx.Input) && ctx.IsPartial {
		return NewParseResultRange(NeedMoreInput, start, lastValidPos)
	}
	return NewParseResultRange(Success, start, lastValidPos)
}
|
||||
|
||||
// RuleParser creates an AST node with a rule name.
type RuleParser struct {
	Name    string
	Child   ParserID
	Trigger bool // not consulted here; presumably used by the builder — TODO confirm
}

// parse delegates to the child and, unless it failed, wraps the result in
// a named AST node carrying the matched text. NeedMoreInput results
// produce a node marked as partial.
func (p *RuleParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	result := arena.ParseAt(p.Child, ctx, start)

	if result.Type != Fail {
		text := ""
		if result.Start < len(ctx.Input) {
			end := result.End
			if end > len(ctx.Input) {
				// Clamp: a suspended match may report End past the buffer.
				end = len(ctx.Input)
			}
			text = ctx.Input[result.Start:end]
		}

		nodeID := ctx.Ast.AddNode(
			p.Name, "", result.Start, result.End, text,
			result.Nodes, result.Type == NeedMoreInput,
		)

		return NewParseResultNodes(result.Type, result.Start, result.End, []AstID{nodeID})
	}

	return result
}
|
||||
|
||||
// RefParser references a named rule (resolved during Build).
type RefParser struct {
	Name string
}

// parse looks the rule up by name in the arena and delegates to it.
func (p *RefParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	ruleID := arena.GetRule(p.Name)
	return arena.ParseAt(ruleID, ctx, start)
}
|
||||
|
||||
// AtomicParser suppresses partial AST nodes.
|
||||
type AtomicParser struct {
|
||||
Child ParserID
|
||||
}
|
||||
|
||||
func (p *AtomicParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
|
||||
result := arena.ParseAt(p.Child, ctx, start)
|
||||
if result.Type == NeedMoreInput {
|
||||
result.Nodes = nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TagParser creates an AST node with a semantic tag.
type TagParser struct {
	Child ParserID
	Tag   string
}

// parse delegates to the child and, unless it failed, wraps the result in
// a tagged (unnamed) AST node carrying the matched text. NeedMoreInput
// results produce a node marked as partial.
func (p *TagParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	result := arena.ParseAt(p.Child, ctx, start)

	if result.Type != Fail {
		text := ""
		if result.Start < len(ctx.Input) {
			end := result.End
			if end > len(ctx.Input) {
				// Clamp: a suspended match may report End past the buffer.
				end = len(ctx.Input)
			}
			text = ctx.Input[result.Start:end]
		}

		nodeID := ctx.Ast.AddNode(
			"", p.Tag, result.Start, result.End, text,
			result.Nodes, result.Type == NeedMoreInput,
		)

		return NewParseResultNodes(result.Type, result.Start, result.End, []AstID{nodeID})
	}

	return result
}
|
||||
|
||||
// SchemaParser wraps a parser with schema metadata (pass-through at parse time).
type SchemaParser struct {
	Child ParserID
	Name  string // schema name; consumed by tooling, not by parsing
}

// parse simply delegates to the child; the schema name has no runtime effect.
func (p *SchemaParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	return arena.ParseAt(p.Child, ctx, start)
}
|
||||
|
||||
// JSONParser matches a complete JSON value (object, array, string, number, bool, null).
type JSONParser struct {
	// NOTE(review): arena is never read by parse; presumably set by the
	// builder — confirm before removing.
	arena *Arena
}

// parse delegates to the standalone recursive-descent JSON value parser.
func (p *JSONParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	return parseJSONValue(ctx, start, start)
}
|
||||
|
||||
// isWhitespace reports whether c is one of the four JSON whitespace bytes
// (space, tab, LF, CR).
func isWhitespace(c byte) bool {
	switch c {
	case ' ', '\t', '\n', '\r':
		return true
	}
	return false
}
|
||||
|
||||
func parseLiteralAt(ctx *ParseContext, start, pos int, lit string) ParseResult {
|
||||
for i := 0; i < len(lit); i++ {
|
||||
if pos+i >= len(ctx.Input) {
|
||||
if ctx.IsPartial {
|
||||
return NewParseResultRange(NeedMoreInput, start, pos+i)
|
||||
}
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
if ctx.Input[pos+i] != lit[i] {
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
}
|
||||
return NewParseResultRange(Success, start, pos+len(lit))
|
||||
}
|
||||
|
||||
// parseJSONString matches a complete double-quoted JSON string, including
// both quotes. pos must point at the opening '"'; start anchors the
// reported range. Escapes (\" \\ \/ \b \f \n \r \t \uXXXX) are validated
// inline; all other content must be well-formed UTF-8.
func parseJSONString(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip opening "
	for pos < len(ctx.Input) {
		c := ctx.Input[pos]
		if c == '"' {
			// Closing quote is consumed as part of the match.
			return NewParseResultRange(Success, start, pos+1)
		}
		if c == '\\' {
			pos++
			if pos >= len(ctx.Input) {
				if ctx.IsPartial {
					return NewParseResultRange(NeedMoreInput, start, pos)
				}
				return NewParseResult(Fail, start)
			}
			switch ctx.Input[pos] {
			case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
				pos++
			case 'u':
				pos++
				// Exactly four hex digits must follow \u.
				for i := 0; i < 4; i++ {
					if pos >= len(ctx.Input) {
						if ctx.IsPartial {
							return NewParseResultRange(NeedMoreInput, start, pos)
						}
						return NewParseResult(Fail, start)
					}
					if !isHexDigit(ctx.Input[pos]) {
						return NewParseResult(Fail, start)
					}
					pos++
				}
			default:
				// Unknown escape character.
				return NewParseResult(Fail, start)
			}
		} else {
			_, size, status := parseUTF8Codepoint(ctx.Input, pos)
			if status == utf8Incomplete {
				if !ctx.IsPartial {
					return NewParseResult(Fail, start)
				}
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			if status == utf8Invalid {
				return NewParseResult(Fail, start)
			}
			pos += size
		}
	}
	// Input ended before the closing quote.
	if ctx.IsPartial {
		return NewParseResultRange(NeedMoreInput, start, pos)
	}
	return NewParseResult(Fail, start)
}
|
||||
|
||||
// parseJSONNumber matches a JSON number (integer part, optional fraction,
// optional exponent) starting at pos. In streaming mode a number that
// reaches the end of input — or whose next byte could extend it — reports
// NeedMoreInput rather than committing early.
func parseJSONNumber(ctx *ParseContext, start, pos int) ParseResult {
	p := pos
	// Optional leading minus.
	if p < len(ctx.Input) && ctx.Input[p] == '-' {
		p++
	}
	if p >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, p)
		}
		return NewParseResult(Fail, start)
	}
	// Integer part: a lone "0", or a nonzero digit followed by digits
	// (JSON forbids leading zeros).
	if ctx.Input[p] == '0' {
		p++
	} else if ctx.Input[p] >= '1' && ctx.Input[p] <= '9' {
		p++
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
	} else {
		return NewParseResult(Fail, start)
	}
	// Optional fraction: '.' must be followed by at least one digit.
	if p < len(ctx.Input) && ctx.Input[p] == '.' {
		p++
		digitStart := p
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
		if p == digitStart {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, p)
			}
			return NewParseResult(Fail, start)
		}
	}
	// Optional exponent: e/E, optional sign, then at least one digit.
	if p < len(ctx.Input) && (ctx.Input[p] == 'e' || ctx.Input[p] == 'E') {
		p++
		if p < len(ctx.Input) && (ctx.Input[p] == '+' || ctx.Input[p] == '-') {
			p++
		}
		digitStart := p
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
		if p == digitStart {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, p)
			}
			return NewParseResult(Fail, start)
		}
	}

	// In partial mode, check if the next character could continue the number.
	// This prevents premature commits (e.g. returning "3" when "3.14" is incoming).
	if ctx.IsPartial && p >= len(ctx.Input) {
		return NewParseResultRange(NeedMoreInput, start, p)
	}
	if ctx.IsPartial && p < len(ctx.Input) && isNumberContinuation(ctx.Input[p]) {
		return NewParseResultRange(NeedMoreInput, start, p)
	}

	return NewParseResultRange(Success, start, p)
}
|
||||
|
||||
// isNumberContinuation reports whether c could extend a JSON number
// (digit, decimal point, exponent marker, or sign).
func isNumberContinuation(c byte) bool {
	switch c {
	case '.', 'e', 'E', '+', '-':
		return true
	}
	return c >= '0' && c <= '9'
}
|
||||
|
||||
// parseJSONObject matches a complete JSON object; pos must point at the
// opening '{'. Keys must be strings and values are parsed recursively.
// Trailing commas are rejected implicitly (a ',' must be followed by
// another '"' key). End-of-input at any decision point yields
// NeedMoreInput in streaming mode, Fail otherwise.
func parseJSONObject(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip {
	pos = skipWS(ctx.Input, pos)
	if pos >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResult(Fail, start)
	}
	// Empty object.
	if ctx.Input[pos] == '}' {
		return NewParseResultRange(Success, start, pos+1)
	}
	for {
		pos = skipWS(ctx.Input, pos)
		// key
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] != '"' {
			return NewParseResult(Fail, start)
		}
		r := parseJSONString(ctx, start, pos)
		if r.Type != Success {
			return r
		}
		pos = r.End
		pos = skipWS(ctx.Input, pos)
		// colon
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] != ':' {
			return NewParseResult(Fail, start)
		}
		pos++
		pos = skipWS(ctx.Input, pos)
		// value
		vr := parseJSONValue(ctx, start, pos)
		if vr.Type != Success {
			return vr
		}
		pos = vr.End
		pos = skipWS(ctx.Input, pos)
		// '}' closes the object; ',' continues with the next member.
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] == '}' {
			return NewParseResultRange(Success, start, pos+1)
		}
		if ctx.Input[pos] != ',' {
			return NewParseResult(Fail, start)
		}
		pos++
	}
}
|
||||
|
||||
// parseJSONArray matches a complete JSON array; pos must point at the
// opening '['. Elements are parsed recursively; trailing commas are
// rejected implicitly (a ',' must be followed by another value).
func parseJSONArray(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip [
	pos = skipWS(ctx.Input, pos)
	if pos >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResult(Fail, start)
	}
	// Empty array.
	if ctx.Input[pos] == ']' {
		return NewParseResultRange(Success, start, pos+1)
	}
	for {
		pos = skipWS(ctx.Input, pos)
		vr := parseJSONValue(ctx, start, pos)
		if vr.Type != Success {
			return vr
		}
		pos = vr.End
		pos = skipWS(ctx.Input, pos)
		// ']' closes the array; ',' continues with the next element.
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] == ']' {
			return NewParseResultRange(Success, start, pos+1)
		}
		if ctx.Input[pos] != ',' {
			return NewParseResult(Fail, start)
		}
		pos++
	}
}
|
||||
|
||||
func parseJSONValue(ctx *ParseContext, start, pos int) ParseResult {
|
||||
pos = skipWS(ctx.Input, pos)
|
||||
if pos >= len(ctx.Input) {
|
||||
if ctx.IsPartial {
|
||||
return NewParseResultRange(NeedMoreInput, start, pos)
|
||||
}
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
switch ctx.Input[pos] {
|
||||
case '{':
|
||||
return parseJSONObject(ctx, start, pos)
|
||||
case '[':
|
||||
return parseJSONArray(ctx, start, pos)
|
||||
case '"':
|
||||
return parseJSONString(ctx, start, pos)
|
||||
case 't':
|
||||
return parseLiteralAt(ctx, start, pos, "true")
|
||||
case 'f':
|
||||
return parseLiteralAt(ctx, start, pos, "false")
|
||||
case 'n':
|
||||
return parseLiteralAt(ctx, start, pos, "null")
|
||||
default:
|
||||
if ctx.Input[pos] == '-' || (ctx.Input[pos] >= '0' && ctx.Input[pos] <= '9') {
|
||||
return parseJSONNumber(ctx, start, pos)
|
||||
}
|
||||
return NewParseResult(Fail, start)
|
||||
}
|
||||
}
|
||||
|
||||
func skipWS(input string, pos int) int {
|
||||
for pos < len(input) && isWhitespace(input[pos]) {
|
||||
pos++
|
||||
}
|
||||
return pos
|
||||
}
|
||||
777
pkg/functions/peg/parser_test.go
Normal file
777
pkg/functions/peg/parser_test.go
Normal file
@@ -0,0 +1,777 @@
|
||||
package peg_test
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/pkg/functions/peg"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func extractTags(ast *peg.AstArena, result *peg.ParseResult) map[string]string {
|
||||
tags := make(map[string]string)
|
||||
ast.VisitResult(result, func(node *peg.AstNode) {
|
||||
if node.Tag != "" {
|
||||
tags[node.Tag] = node.Text
|
||||
}
|
||||
})
|
||||
return tags
|
||||
}
|
||||
|
||||
var _ = Describe("PEG Parser", func() {
|
||||
Context("LiteralParser", func() {
|
||||
It("succeeds on exact match", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Literal("hello")
|
||||
})
|
||||
ctx := peg.NewParseContext("hello world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.Start).To(Equal(0))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("fails on mismatch", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Literal("hello")
|
||||
})
|
||||
ctx := peg.NewParseContext("world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
|
||||
It("returns NeedMoreInput in partial mode", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Literal("hello")
|
||||
})
|
||||
ctx := peg.NewParseContext("hel", true)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.NeedMoreInput))
|
||||
})
|
||||
|
||||
It("fails on partial input when not in partial mode", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Literal("hello")
|
||||
})
|
||||
ctx := peg.NewParseContext("hel", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("SequenceParser", func() {
|
||||
It("matches full sequence", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("hello"), b.Literal(" "), b.Literal("world"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(11))
|
||||
})
|
||||
|
||||
It("fails midway", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("hello"), b.Literal("X"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
|
||||
It("returns NeedMoreInput at boundary in partial mode", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("hello"), b.Literal(" world"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello wo", true)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.NeedMoreInput))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ChoiceParser", func() {
|
||||
It("matches first alternative", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Choice(b.Literal("hello"), b.Literal("world"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("matches second alternative", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Choice(b.Literal("hello"), b.Literal("world"))
|
||||
})
|
||||
ctx := peg.NewParseContext("world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("fails when all alternatives fail", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Choice(b.Literal("hello"), b.Literal("world"))
|
||||
})
|
||||
ctx := peg.NewParseContext("foo", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("RepetitionParser", func() {
|
||||
It("handles zero or more matches", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.ZeroOrMore(b.Literal("ab"))
|
||||
})
|
||||
|
||||
ctx := peg.NewParseContext("ababab", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(6))
|
||||
|
||||
ctx2 := peg.NewParseContext("xyz", false)
|
||||
r2 := arena.Parse(ctx2)
|
||||
Expect(r2.Type).To(Equal(peg.Success))
|
||||
Expect(r2.End).To(Equal(0))
|
||||
})
|
||||
|
||||
It("handles one or more matches", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.OneOrMore(b.Literal("ab"))
|
||||
})
|
||||
|
||||
ctx := peg.NewParseContext("ababab", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(6))
|
||||
|
||||
ctx2 := peg.NewParseContext("xyz", false)
|
||||
r2 := arena.Parse(ctx2)
|
||||
Expect(r2.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
|
||||
It("handles optional matches", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Optional(b.Literal("hello")), b.Literal("world"))
|
||||
})
|
||||
|
||||
ctx := peg.NewParseContext("helloworld", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(10))
|
||||
|
||||
ctx2 := peg.NewParseContext("world", false)
|
||||
r2 := arena.Parse(ctx2)
|
||||
Expect(r2.Type).To(Equal(peg.Success))
|
||||
Expect(r2.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("respects bounded repetition", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Repeat(b.Literal("a"), 2, 4)
|
||||
})
|
||||
|
||||
ctx := peg.NewParseContext("aaaaa", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(4))
|
||||
|
||||
ctx2 := peg.NewParseContext("a", false)
|
||||
r2 := arena.Parse(ctx2)
|
||||
Expect(r2.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Lookahead", func() {
|
||||
It("succeeds with positive lookahead", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Peek(b.Literal("hello")), b.Literal("hello"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("fails with positive lookahead mismatch", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Peek(b.Literal("world")), b.Literal("hello"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
|
||||
It("succeeds with negative lookahead", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Negate(b.Literal("world")), b.Literal("hello"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("fails with negative lookahead match", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Negate(b.Literal("hello")), b.Literal("hello"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("UntilParser", func() {
|
||||
It("consumes until single delimiter", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Until("<end>"), b.Literal("<end>"))
|
||||
})
|
||||
ctx := peg.NewParseContext("content<end>", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(12))
|
||||
})
|
||||
|
||||
It("consumes until first of multiple delimiters", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.UntilOneOf("<a>", "<b>")
|
||||
})
|
||||
ctx := peg.NewParseContext("content<b>more", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(7))
|
||||
})
|
||||
|
||||
It("consumes rest of input", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Rest()
|
||||
})
|
||||
ctx := peg.NewParseContext("everything", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(10))
|
||||
})
|
||||
|
||||
It("returns NeedMoreInput in partial mode", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Until("<end>")
|
||||
})
|
||||
ctx := peg.NewParseContext("content", true)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.NeedMoreInput))
|
||||
})
|
||||
})
|
||||
|
||||
Context("JSONParser", func() {
|
||||
It("parses objects", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`{"key": "value", "num": 42}`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(27))
|
||||
})
|
||||
|
||||
It("parses arrays", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`[1, "two", true, null]`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(22))
|
||||
})
|
||||
|
||||
It("parses strings with escapes", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`"hello \"world\""`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(17))
|
||||
})
|
||||
|
||||
It("parses numbers", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`-123.45e10`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(10))
|
||||
})
|
||||
|
||||
It("parses booleans", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`true`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(4))
|
||||
})
|
||||
|
||||
It("parses null", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
ctx := peg.NewParseContext(`null`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(4))
|
||||
})
|
||||
|
||||
It("parses nested structures", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.JSON()
|
||||
})
|
||||
input := `{"a": [1, {"b": true}], "c": null}`
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(len(input)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Tag extraction", func() {
|
||||
It("extracts basic tags", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Tag("greeting", b.Until(" ")),
|
||||
b.Literal(" "),
|
||||
b.Tag("name", b.Rest()),
|
||||
)
|
||||
})
|
||||
ctx := peg.NewParseContext("Hello World", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["greeting"]).To(Equal("Hello"))
|
||||
Expect(tags["name"]).To(Equal("World"))
|
||||
})
|
||||
|
||||
It("extracts structured tags", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Tag("header", b.Until("\n")),
|
||||
b.Literal("\n"),
|
||||
b.Tag("body", b.Rest()),
|
||||
)
|
||||
})
|
||||
ctx := peg.NewParseContext("Title\nBody content here", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["header"]).To(Equal("Title"))
|
||||
Expect(tags["body"]).To(Equal("Body content here"))
|
||||
})
|
||||
|
||||
It("overwrites duplicate tags", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Tag("item", b.Until(",")),
|
||||
b.Literal(","),
|
||||
b.Tag("item", b.Rest()),
|
||||
)
|
||||
})
|
||||
ctx := peg.NewParseContext("first,second", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["item"]).To(Equal("second"))
|
||||
})
|
||||
|
||||
It("returns empty map when no tags", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Rest()
|
||||
})
|
||||
ctx := peg.NewParseContext("Hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Rule and Ref", func() {
|
||||
It("handles named rules", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
word := b.Rule("word", b.Chars("[a-z]", 1, -1))
|
||||
return b.Seq(word, b.Literal(" "), word)
|
||||
})
|
||||
ctx := peg.NewParseContext("hello world", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(11))
|
||||
})
|
||||
|
||||
It("handles forward references", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
ref := b.Ref("greeting")
|
||||
b.Rule("greeting", b.Literal("hello"))
|
||||
return ref
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("AtomicParser", func() {
|
||||
It("suppresses partial AST nodes", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Atomic(b.Tag("test", b.Literal("hello world")))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", true)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.NeedMoreInput))
|
||||
Expect(r.Nodes).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Start and End parsers", func() {
|
||||
It("matches at start of input", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Start(), b.Literal("hello"))
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("matches at end of input", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("hello"), b.End())
|
||||
})
|
||||
ctx := peg.NewParseContext("hello", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
|
||||
ctx2 := peg.NewParseContext("hello world", false)
|
||||
r2 := arena.Parse(ctx2)
|
||||
Expect(r2.Type).To(Equal(peg.Fail))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Partial parsing", func() {
|
||||
It("extracts tags during partial parse", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Tag("prefix", b.Until(":")),
|
||||
b.Literal(":"),
|
||||
b.Tag("value", b.Rest()),
|
||||
)
|
||||
})
|
||||
ctx := peg.NewParseContext("key:val", true)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).NotTo(Equal(peg.Fail))
|
||||
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["prefix"]).To(Equal("key"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ParseAnywhere", func() {
|
||||
It("finds pattern in middle of input", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Choice(b.Literal("{"), b.Literal(":")),
|
||||
b.Space(),
|
||||
b.Literal("\""),
|
||||
b.Atomic(b.Literal("fun_name")),
|
||||
)
|
||||
})
|
||||
|
||||
input := `This is a very long jinja template string... <tool_call>{ "fun_name" : { "arg" : 1 }</tool_call>`
|
||||
found := false
|
||||
for i := 0; i < len(input); i++ {
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
r := arena.ParseFrom(ctx, i)
|
||||
if r.Type == peg.Success {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeTrue())
|
||||
})
|
||||
|
||||
It("fails when pattern is not found", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(
|
||||
b.Choice(b.Literal("{"), b.Literal(":")),
|
||||
b.Space(),
|
||||
b.Literal("\""),
|
||||
b.Atomic(b.Literal("fun_name")),
|
||||
)
|
||||
})
|
||||
|
||||
input := `This is a very long jinja template string... <tool_call><fun=fun_name><arg name=arg>1</arg></tool_call>`
|
||||
found := false
|
||||
for i := 0; i < len(input); i++ {
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
r := arena.ParseFrom(ctx, i)
|
||||
if r.Type == peg.Success {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("CharsParser", func() {
|
||||
It("matches lowercase letters", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Chars("[a-z]", 1, -1)
|
||||
})
|
||||
ctx := peg.NewParseContext("hello123", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("matches negated character class", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Chars("[^0-9]", 1, -1)
|
||||
})
|
||||
ctx := peg.NewParseContext("hello123", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Context("JSONStringParser", func() {
|
||||
It("parses basic strings", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
|
||||
})
|
||||
ctx := peg.NewParseContext(`"hello world"`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(13))
|
||||
})
|
||||
|
||||
It("parses strings with escapes", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
|
||||
})
|
||||
ctx := peg.NewParseContext(`"hello \"world\""`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(17))
|
||||
})
|
||||
|
||||
It("parses strings with unicode escapes", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
|
||||
})
|
||||
ctx := peg.NewParseContext(`"hello \u0041"`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(14))
|
||||
})
|
||||
})
|
||||
|
||||
Context("SpaceParser", func() {
|
||||
It("matches whitespace", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("a"), b.Space(), b.Literal("b"))
|
||||
})
|
||||
ctx := peg.NewParseContext("a b", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(5))
|
||||
})
|
||||
|
||||
It("matches zero whitespace", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("a"), b.Space(), b.Literal("b"))
|
||||
})
|
||||
ctx := peg.NewParseContext("ab", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(2))
|
||||
})
|
||||
})
|
||||
|
||||
Context("PythonDictStringParser", func() {
|
||||
It("parses basic single-quoted strings", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
|
||||
})
|
||||
ctx := peg.NewParseContext("'hello world'", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
Expect(r.End).To(Equal(13))
|
||||
})
|
||||
|
||||
It("handles escaped single quotes", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
|
||||
})
|
||||
ctx := peg.NewParseContext(`'it\'s fine'`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("handles double quotes inside single-quoted strings", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
|
||||
})
|
||||
ctx := peg.NewParseContext(`'He said "hi"'`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
})
|
||||
|
||||
Context("peg.ParseCharClassChar", func() {
|
||||
It("parses \\x hex escape", func() {
|
||||
r, n := peg.ParseCharClassChar(`\x41`, 0)
|
||||
Expect(r).To(Equal('A'))
|
||||
Expect(n).To(Equal(4))
|
||||
})
|
||||
|
||||
It("parses \\u unicode escape", func() {
|
||||
r, n := peg.ParseCharClassChar(`\u0041`, 0)
|
||||
Expect(r).To(Equal('A'))
|
||||
Expect(n).To(Equal(6))
|
||||
})
|
||||
|
||||
It("parses \\U unicode escape", func() {
|
||||
r, n := peg.ParseCharClassChar(`\U00000041`, 0)
|
||||
Expect(r).To(Equal('A'))
|
||||
Expect(n).To(Equal(10))
|
||||
})
|
||||
|
||||
It("falls back on invalid hex", func() {
|
||||
r, n := peg.ParseCharClassChar(`\xZZ`, 0)
|
||||
Expect(r).To(Equal('x'))
|
||||
Expect(n).To(Equal(2))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ParseAnywhere method", func() {
|
||||
It("finds pattern in middle of input", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Seq(b.Literal("needle"), b.Tag("after", b.Until(".")))
|
||||
})
|
||||
ctx := peg.NewParseContext("some hay needle found.", false)
|
||||
r := arena.ParseAnywhere(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["after"]).To(Equal(" found"))
|
||||
})
|
||||
|
||||
It("finds function tag with name", func() {
|
||||
haystack := "\n<tool_call>\n<function=foofoo>\n<parameter=first>\nXXXX\n</parameter>\n<parameter=second>\nYYYY\n</parameter>\n</function>\n</tool_call>\n"
|
||||
needle := "foofoo"
|
||||
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Tag("fun_marker", b.Choice(
|
||||
b.Seq(
|
||||
b.Tag("fun_pre", b.Seq(b.Literal("<"), b.UntilOneOf(">", needle))),
|
||||
b.Literal(needle),
|
||||
b.Tag("fun_post", b.Seq(
|
||||
b.Seq(b.Negate(b.Seq(b.Space(), b.Literal("<"))), b.Until(">"), b.Literal(">")),
|
||||
)),
|
||||
b.Space(),
|
||||
),
|
||||
b.Seq(
|
||||
b.Tag("fun_pre", b.Seq(b.Literal("["), b.UntilOneOf("]", needle))),
|
||||
b.Literal(needle),
|
||||
b.Tag("fun_post", b.Seq(
|
||||
b.Negate(b.Seq(b.Space(), b.Seq(b.Literal("["), b.Until("]"), b.Literal("]")))),
|
||||
b.Space(),
|
||||
)),
|
||||
),
|
||||
))
|
||||
})
|
||||
|
||||
ctx := peg.NewParseContext(haystack, false)
|
||||
r := arena.ParseAnywhere(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
tags := extractTags(&ctx.Ast, &r)
|
||||
Expect(tags["fun_pre"]).To(Equal("<function="))
|
||||
Expect(tags["fun_post"]).To(Equal(">"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("LazyRule", func() {
|
||||
It("handles recursive JSON-like structures", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
b.LazyRule("value", func() peg.ParserID {
|
||||
str := b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
|
||||
arr := b.Seq(
|
||||
b.Literal("["), b.Space(),
|
||||
b.Ref("value"),
|
||||
b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), b.Ref("value"))),
|
||||
b.Space(), b.Literal("]"),
|
||||
)
|
||||
return b.Choice(str, arr)
|
||||
})
|
||||
return b.Ref("value")
|
||||
})
|
||||
ctx := peg.NewParseContext(`["hello",["world","nested"]]`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("parses python dicts", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.PythonDict()
|
||||
})
|
||||
ctx := peg.NewParseContext(`{'key': 'value', 'num': 42}`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("parses nested python values", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.PythonValue()
|
||||
})
|
||||
ctx := peg.NewParseContext(`{'outer': {'inner': [1, 2, 'three']}}`, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("parses python booleans and None", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.PythonValue()
|
||||
})
|
||||
for _, input := range []string{"True", "False", "None"} {
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("Marker", func() {
|
||||
It("matches angle brackets", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Marker()
|
||||
})
|
||||
ctx := peg.NewParseContext("<tool_call>", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
|
||||
It("matches square brackets", func() {
|
||||
arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
|
||||
return b.Marker()
|
||||
})
|
||||
ctx := peg.NewParseContext("[TOOL]", false)
|
||||
r := arena.Parse(ctx)
|
||||
Expect(r.Type).To(Equal(peg.Success))
|
||||
})
|
||||
})
|
||||
})
|
||||
13
pkg/functions/peg/peg_suite_test.go
Normal file
13
pkg/functions/peg/peg_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package peg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPeg(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "PEG Parser test suite")
|
||||
}
|
||||
80
pkg/functions/peg/trie.go
Normal file
80
pkg/functions/peg/trie.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package peg
|
||||
|
||||
// trie is used for multi-delimiter matching in UntilParser.
type trie struct {
	nodes []trieNode // flat node storage; index 0 is the root
}

// trieNode is a single trie node, addressed by its index in trie.nodes.
type trieNode struct {
	children map[rune]int // per-rune edges to child node indices
	isWord   bool         // true when the path from the root here spells a full delimiter
}

// trieMatch classifies the outcome of matching input against the trie.
type trieMatch int

const (
	trieNoMatch       trieMatch = 0 // no delimiter starts at the checked position
	triePartialMatch  trieMatch = 1 // input ended while still matching a delimiter prefix
	trieCompleteMatch trieMatch = 2 // a complete delimiter was matched
)
|
||||
|
||||
func newTrie(words []string) *trie {
|
||||
t := &trie{}
|
||||
t.createNode() // root
|
||||
for _, w := range words {
|
||||
t.insert(w)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *trie) createNode() int {
|
||||
idx := len(t.nodes)
|
||||
t.nodes = append(t.nodes, trieNode{children: make(map[rune]int)})
|
||||
return idx
|
||||
}
|
||||
|
||||
func (t *trie) insert(word string) {
|
||||
current := 0
|
||||
for _, ch := range word {
|
||||
if next, ok := t.nodes[current].children[ch]; ok {
|
||||
current = next
|
||||
} else {
|
||||
child := t.createNode()
|
||||
t.nodes[current].children[ch] = child
|
||||
current = child
|
||||
}
|
||||
}
|
||||
t.nodes[current].isWord = true
|
||||
}
|
||||
|
||||
// checkAt checks if any delimiter starts at position pos in the input.
// It walks the trie codepoint by codepoint and returns:
//   - trieCompleteMatch as soon as a full delimiter is spelled (shortest match),
//   - trieNoMatch when the input diverges from every delimiter,
//   - triePartialMatch when input ends (or decoding stops) while still on a
//     valid delimiter prefix — the caller may need more input to decide.
func (t *trie) checkAt(input string, pos int) trieMatch {
	current := 0
	p := pos

	for p < len(input) {
		r, size, status := parseUTF8Codepoint(input, p)
		if status != utf8Success {
			// Incomplete or invalid UTF-8: stop walking and fall through to
			// the partial-prefix check below.
			break
		}

		next, ok := t.nodes[current].children[r]
		if !ok {
			return trieNoMatch
		}

		current = next
		p += size

		// Report as soon as any delimiter completes.
		if t.nodes[current].isWord {
			return trieCompleteMatch
		}
	}

	// Reached end of input while still in the trie
	if current != 0 {
		return triePartialMatch
	}

	return trieNoMatch
}
|
||||
175
pkg/functions/peg/types.go
Normal file
175
pkg/functions/peg/types.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package peg
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// ParserID is a unique identifier for a parser in the arena.
type ParserID int

// InvalidParserID marks the absence of a parser reference.
const InvalidParserID ParserID = -1

// AstID is a unique identifier for an AST node.
type AstID int

// InvalidAstID marks the absence of an AST node reference.
const InvalidAstID AstID = -1

// ParseResultType indicates the outcome of a parse attempt.
type ParseResultType int

const (
	Fail          ParseResultType = 0 // the input does not match
	Success       ParseResultType = 1 // the input matched
	NeedMoreInput ParseResultType = 2 // a prefix matched; more input may complete the parse
)
|
||||
|
||||
func (t ParseResultType) String() string {
|
||||
switch t {
|
||||
case Fail:
|
||||
return "fail"
|
||||
case Success:
|
||||
return "success"
|
||||
case NeedMoreInput:
|
||||
return "need_more_input"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ParseResult holds the result of a parse operation.
type ParseResult struct {
	Type  ParseResultType // outcome: Fail, Success, or NeedMoreInput
	Start int             // byte offset where the match attempt began
	End   int             // byte offset one past the last consumed byte (== Start for zero-width results)
	Nodes []AstID         // top-level AST nodes produced by the match, if any
}
|
||||
|
||||
func NewParseResult(typ ParseResultType, start int) ParseResult {
|
||||
return ParseResult{Type: typ, Start: start, End: start}
|
||||
}
|
||||
|
||||
func NewParseResultRange(typ ParseResultType, start, end int) ParseResult {
|
||||
return ParseResult{Type: typ, Start: start, End: end}
|
||||
}
|
||||
|
||||
func NewParseResultNodes(typ ParseResultType, start, end int, nodes []AstID) ParseResult {
|
||||
return ParseResult{Type: typ, Start: start, End: end, Nodes: nodes}
|
||||
}
|
||||
|
||||
// AstNode is a node in the parse AST.
type AstNode struct {
	ID        AstID   // this node's index in the owning AstArena
	Rule      string  // name of the rule that produced the node, if any
	Tag       string  // user-assigned tag (empty when untagged)
	Start     int     // byte offset of the matched span's start
	End       int     // byte offset one past the matched span's end
	Text      string  // matched input text
	Children  []AstID // child node IDs, in match order
	IsPartial bool    // true when produced from an incomplete (partial-input) match
}
|
||||
|
||||
// AstArena stores AST nodes.
type AstArena struct {
	nodes []AstNode // flat storage; an AstID is an index into this slice
}
|
||||
|
||||
func (a *AstArena) AddNode(rule, tag string, start, end int, text string, children []AstID, isPartial bool) AstID {
|
||||
id := AstID(len(a.nodes))
|
||||
a.nodes = append(a.nodes, AstNode{
|
||||
ID: id,
|
||||
Rule: rule,
|
||||
Tag: tag,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: text,
|
||||
Children: children,
|
||||
IsPartial: isPartial,
|
||||
})
|
||||
return id
|
||||
}
|
||||
|
||||
// Get returns a pointer to the node with the given ID. The pointer stays
// valid only until a later AddNode reallocates the backing slice.
func (a *AstArena) Get(id AstID) *AstNode {
	return &a.nodes[id]
}

// Size reports how many nodes the arena currently holds.
func (a *AstArena) Size() int {
	return len(a.nodes)
}

// Clear removes all nodes while retaining the backing storage for reuse.
func (a *AstArena) Clear() {
	a.nodes = a.nodes[:0]
}
|
||||
|
||||
// Visit traverses the AST tree rooted at the given node, calling fn for each node.
|
||||
func (a *AstArena) Visit(id AstID, fn func(*AstNode)) {
|
||||
if id == InvalidAstID {
|
||||
return
|
||||
}
|
||||
node := a.Get(id)
|
||||
fn(node)
|
||||
for _, child := range node.Children {
|
||||
a.Visit(child, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// VisitResult traverses all top-level nodes in a parse result.
|
||||
func (a *AstArena) VisitResult(result *ParseResult, fn func(*AstNode)) {
|
||||
for _, id := range result.Nodes {
|
||||
a.Visit(id, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseContext holds the state for a parse operation.
type ParseContext struct {
	Input     string   // full input being parsed
	IsPartial bool     // true when Input may be a prefix of a longer stream
	Debug     bool     // enables parser debugging behavior
	Ast       AstArena // arena receiving AST nodes produced during the parse
}
|
||||
|
||||
func NewParseContext(input string, isPartial bool) *ParseContext {
|
||||
return &ParseContext{
|
||||
Input: input,
|
||||
IsPartial: isPartial,
|
||||
}
|
||||
}
|
||||
|
||||
// utf8Status classifies the outcome of decoding a single UTF-8 codepoint.
type utf8Status int

const (
	utf8Success    utf8Status = 0 // a complete codepoint was decoded
	utf8Incomplete utf8Status = 1 // input ends in the middle of a sequence; more bytes needed
	utf8Invalid    utf8Status = 2 // the bytes at pos are not valid UTF-8
)

// parseUTF8Codepoint parses a single UTF-8 codepoint at position pos.
// Returns the codepoint, bytes consumed, and status.
func parseUTF8Codepoint(input string, pos int) (rune, int, utf8Status) {
	if pos >= len(input) {
		return 0, 0, utf8Incomplete
	}
	r, size := utf8.DecodeRuneInString(input[pos:])
	// DecodeRuneInString signals a decode error as (RuneError, 1).
	// A (RuneError, 3) result is a *literal* U+FFFD replacement character
	// that was validly encoded in the input and must be accepted like any
	// other codepoint — previously it was misreported as invalid.
	if r == utf8.RuneError && size <= 1 {
		// Distinguish a truncated multi-byte sequence (incomplete — more
		// bytes may arrive in a streaming parse) from genuinely malformed
		// bytes (invalid) by inspecting the lead byte.
		b := input[pos]
		var expectedLen int
		switch {
		case b&0x80 == 0:
			expectedLen = 1
		case b&0xE0 == 0xC0:
			expectedLen = 2
		case b&0xF0 == 0xE0:
			expectedLen = 3
		case b&0xF8 == 0xF0:
			expectedLen = 4
		default:
			// Not a legal UTF-8 lead byte (continuation byte or 0xFE/0xFF).
			return 0, 0, utf8Invalid
		}
		if pos+expectedLen > len(input) {
			return 0, 0, utf8Incomplete
		}
		// Enough bytes were present but they still failed to decode.
		return 0, 0, utf8Invalid
	}
	return r, size, utf8Success
}
|
||||
268
pkg/functions/peg/utils_test.go
Normal file
268
pkg/functions/peg/utils_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package peg_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions/peg"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Ginkgo specification for the peg package helpers: quote normalization,
// JSON string escaping, and the StandardJSONTools / StandardConstructedTools
// parser builders across several model-specific tool-call formats.
var _ = Describe("PEG Utils", func() {
	// NormalizeQuotesToJSON converts Python-style single-quoted dicts
	// into strict JSON, preserving embedded quotes and escapes.
	Context("peg.NormalizeQuotesToJSON", func() {
		It("converts basic single quotes to double quotes", func() {
			input := "{'key': 'value'}"
			expected := `{"key": "value"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		It("handles escaped single quotes", func() {
			input := `{'code': 'print(\'hello\')'}`
			expected := `{"code": "print('hello')"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		It("handles double quotes inside single-quoted strings", func() {
			input := `{'msg': 'He said "hi"'}`
			expected := `{"msg": "He said \"hi\""}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		It("handles nested backslash escapes", func() {
			input := `{'path': 'C:\\Users\\test'}`
			expected := `{"path": "C:\\Users\\test"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		It("handles newline escapes", func() {
			input := `{'text': 'line1\nline2'}`
			expected := `{"text": "line1\nline2"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		It("handles mixed quotes", func() {
			input := `{"already_double": 'single_value'}`
			expected := `{"already_double": "single_value"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})

		// Round-trip check: the normalized output must be parseable JSON,
		// not just string-equal to a fixture.
		It("handles embedded quotes complex case", func() {
			input := `{'filename': 'foo.cpp', 'oldString': 'def foo(arg = "14"):\n return arg + "bar"\n', 'newString': 'def foo(arg = "15"):\n pass\n'}`
			result := peg.NormalizeQuotesToJSON(input)

			var parsed map[string]string
			err := json.Unmarshal([]byte(result), &parsed)
			Expect(err).NotTo(HaveOccurred(), "result is not valid JSON: %s", result)

			Expect(parsed["filename"]).To(Equal("foo.cpp"))
			Expect(parsed["oldString"]).NotTo(BeEmpty())
		})
	})

	// EscapeJSONStringInner escapes a raw value for embedding inside a
	// JSON string literal (quotes and backslash sequences).
	Context("peg.EscapeJSONStringInner", func() {
		It("leaves basic strings unchanged", func() {
			Expect(peg.EscapeJSONStringInner("hello")).To(Equal("hello"))
		})

		It("escapes double quotes", func() {
			Expect(peg.EscapeJSONStringInner(`hello "world"`)).To(Equal(`hello \"world\"`))
		})

		It("escapes backslash-n sequences", func() {
			Expect(peg.EscapeJSONStringInner(`line1\nline2`)).To(Equal(`line1\\nline2`))
		})
	})

	// OpenAI format: {"id": ..., "name": ..., "arguments": {...}} inside
	// <tool_call>...</tool_call>, with the call ID taken from the "id" key.
	Context("StandardJSONTools OpenAI format", func() {
		It("parses OpenAI-style tool calls with call ID", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
					},
				},
			}

			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:    "<tool_call>",
					SectionEnd:      "</tool_call>",
					Tools:           tools,
					CallIDKey:       "id",
					ParametersOrder: []string{"id", "name", "arguments"},
				})
				return p.Seq(
					p.Content(p.Until("<tool_call>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})

			input := `Let me check the weather.<tool_call>{"id": "call_abc123", "name": "get_current_weather", "arguments": {"location": "NYC"}}</tool_call>`

			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)

			Expect(result.Type).To(Equal(peg.Success))

			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result

			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			Expect(msg.ToolCalls[0].ID).To(Equal("call_abc123"))
		})
	})

	// Cohere format: renamed keys (tool_name/parameters) and a generated
	// numeric call ID (tool_call_id) that must round-trip as a string.
	Context("StandardJSONTools Cohere format", func() {
		It("parses Cohere-style tool calls with custom keys", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
						"unit":     {Type: "string"},
					},
				},
			}

			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:    "<|START_ACTION|>[",
					SectionEnd:      "]<|END_ACTION|>",
					Tools:           tools,
					NameKey:         "tool_name",
					ArgsKey:         "parameters",
					GenCallIDKey:    "tool_call_id",
					ParametersOrder: []string{"tool_call_id", "tool_name", "parameters"},
				})
				return p.Seq(
					p.Content(p.Until("<|START_ACTION|>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})

			input := `Let me search for that.<|START_ACTION|>[{"tool_call_id": 0, "tool_name": "get_current_weather", "parameters": {"location": "NYC", "unit": "celsius"}}]<|END_ACTION|>`

			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)

			Expect(result.Type).To(Equal(peg.Success))

			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result

			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			Expect(msg.ToolCalls[0].ID).To(Equal("0"))
		})
	})

	// Function-as-key format: {"<func_name>": {"id": ..., "args": {...}}}.
	Context("StandardJSONTools function-as-key format", func() {
		It("parses function name as JSON key", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
						"unit":     {Type: "string"},
					},
				},
			}

			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:  "<tool_calls>[",
					SectionEnd:    "]</tool_calls>",
					Tools:         tools,
					ArgsKey:       "args",
					FunctionIsKey: true,
					CallIDKey:     "id",
				})
				return p.Seq(
					p.Content(p.Until("<tool_calls>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})

			input := `I'll call the weather function.<tool_calls>[{"get_current_weather": {"id": "call-0001", "args": {"location": "NYC", "unit": "celsius"}}}]</tool_calls>`

			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)

			Expect(result.Type).To(Equal(peg.Success))

			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result

			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			Expect(msg.ToolCalls[0].ID).To(Equal("call-0001"))
		})
	})

	// Tagged-parameter format (seed-style): parameter values contain raw
	// double quotes and newlines, which must survive into valid JSON args.
	Context("Tagged args with embedded quotes", func() {
		It("handles embedded double quotes in tagged parameters", func() {
			tools := []peg.ToolDef{
				{
					Name: "edit",
					Properties: map[string]peg.PropDef{
						"filename":  {Type: "string"},
						"oldString": {Type: "string"},
						"newString": {Type: "string"},
					},
				},
			}

			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardConstructedTools(
					map[string]string{
						"tool_call_start_marker": "<seed:tool_call>",
						"tool_call_end_marker":   "</seed:tool_call>",
						"function_opener":        "<function=",
						"function_name_suffix":   ">",
						"function_closer":        "</function>",
						"parameter_key_prefix":   "<parameter=",
						"parameter_key_suffix":   ">",
						"parameter_closer":       "</parameter>",
					},
					tools,
					false,
					true,
				)
				return p.Seq(toolCall, p.Space(), p.End())
			})

			input := "<seed:tool_call>\n" +
				"<function=edit>\n" +
				"<parameter=filename>\nfoo.cpp\n</parameter>\n" +
				"<parameter=oldString>def foo(arg = \"14\"):\n return arg + \"bar\"\n</parameter>\n" +
				"<parameter=newString>def foo(arg = \"15\"):\n pass\n</parameter>\n" +
				"</function>\n" +
				"</seed:tool_call>"

			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)

			Expect(result.Type).To(Equal(peg.Success))

			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result

			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("edit"))

			var parsed map[string]any
			err := json.Unmarshal([]byte(msg.ToolCalls[0].Arguments), &parsed)
			Expect(err).NotTo(HaveOccurred(), "arguments not valid JSON: %s", msg.ToolCalls[0].Arguments)
			Expect(parsed["filename"]).NotTo(BeNil())
		})
	})
})
|
||||
538
pkg/functions/peg_integration.go
Normal file
538
pkg/functions/peg_integration.go
Normal file
@@ -0,0 +1,538 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions/peg"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// PEGFormatType identifies the format type for PEG parsing.
type PEGFormatType int

const (
	// FormatJSONNative: tool call is a raw JSON object between section markers.
	FormatJSONNative PEGFormatType = iota
	FormatTagWithJSON   // <function=name>{"key": "val"}</function>
	FormatTagWithTagged // <function=name><param=key>value</param></function>
)
|
||||
|
||||
// ParseFunctionCallPEG attempts to parse tool calls using the PEG parser.
// Returns nil if no tool calls were found.
//
// Strategies are tried in priority order, returning on the first that yields
// at least one tool call:
//  1. auto-detected markers from the C++ backend (config.ToolFormatMarkers)
//  2. a named XML format preset (config.XMLFormatPreset)
//  3. a custom XML format (config.XMLFormat)
//  4. brute-force auto-detection across every known preset, grouped by type
func ParseFunctionCallPEG(llmresult string, config FunctionsConfig) []FuncCallResults {
	xlog.Debug("[PEG] starting PEG tool call parsing")

	// If auto-detected markers from the C++ backend are available, use them first
	if config.ToolFormatMarkers != nil {
		m := config.ToolFormatMarkers
		xlog.Debug("[PEG] using auto-detected markers from C++ backend",
			"format_type", m.FormatType,
			"section_start", m.SectionStart,
			"section_end", m.SectionEnd,
			"per_call_start", m.PerCallStart,
			"per_call_end", m.PerCallEnd,
			"func_name_prefix", m.FuncNamePrefix,
			"func_name_suffix", m.FuncNameSuffix,
			"func_close", m.FuncClose,
			"arg_name_prefix", m.ArgNamePrefix,
			"arg_name_suffix", m.ArgNameSuffix,
			"arg_value_prefix", m.ArgValuePrefix,
			"arg_value_suffix", m.ArgValueSuffix,
			"arg_separator", m.ArgSeparator,
			"name_field", m.NameField,
			"args_field", m.ArgsField,
			"id_field", m.IDField,
			"reasoning_start", m.ReasoningStart,
			"reasoning_end", m.ReasoningEnd,
		)
		arena := BuildPEGParserFromMarkers(config.ToolFormatMarkers)
		if arena != nil {
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] markers-based parser matched", "count", len(results))
				return results
			}
			// Fall through to the preset/custom/auto-detect strategies below.
			xlog.Debug("[PEG] markers-based parser found no tool calls")
		} else {
			xlog.Debug("[PEG] failed to build parser from markers")
		}
	}

	// If a specific XML format preset is set, use its PEG format
	if config.XMLFormatPreset != "" {
		xlog.Debug("[PEG] trying XML format preset", "preset", config.XMLFormatPreset)
		preset := GetXMLFormatPreset(config.XMLFormatPreset)
		if preset != nil {
			pegType := classifyXMLFormat(preset)
			xlog.Debug("[PEG] classified preset", "preset", config.XMLFormatPreset, "peg_type", pegTypeName(pegType))
			arena := BuildPEGParserFromFormat(preset, pegType)
			if arena != nil {
				results := parsePEG(arena, llmresult)
				if len(results) > 0 {
					xlog.Debug("[PEG] preset parser matched", "preset", config.XMLFormatPreset, "count", len(results))
					return results
				}
				xlog.Debug("[PEG] preset parser found no tool calls", "preset", config.XMLFormatPreset)
			}
		} else {
			xlog.Debug("[PEG] unknown preset name", "preset", config.XMLFormatPreset)
		}
	}

	// If a custom XML format is set, classify and try it
	if config.XMLFormat != nil {
		pegType := classifyXMLFormat(config.XMLFormat)
		xlog.Debug("[PEG] trying custom XML format", "peg_type", pegTypeName(pegType))
		arena := BuildPEGParserFromFormat(config.XMLFormat, pegType)
		if arena != nil {
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] custom format parser matched", "count", len(results))
				return results
			}
			xlog.Debug("[PEG] custom format parser found no tool calls")
		}
	}

	// Auto-detect: try all three format types.
	// Outer loop fixes the format type so JSON-native formats are tried
	// before tag-based ones regardless of preset registration order.
	xlog.Debug("[PEG] auto-detecting format across all presets")
	for _, pegType := range []PEGFormatType{FormatJSONNative, FormatTagWithJSON, FormatTagWithTagged} {
		for _, preset := range getAllXMLFormats() {
			classified := classifyXMLFormat(preset.format)
			if classified != pegType {
				continue
			}
			arena := BuildPEGParserFromFormat(preset.format, pegType)
			if arena == nil {
				continue
			}
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] auto-detect matched", "preset", preset.name, "peg_type", pegTypeName(pegType), "count", len(results))
				return results
			}
		}
	}

	xlog.Debug("[PEG] no tool calls found by any format")
	return nil
}
|
||||
|
||||
func pegTypeName(t PEGFormatType) string {
|
||||
switch t {
|
||||
case FormatJSONNative:
|
||||
return "json_native"
|
||||
case FormatTagWithJSON:
|
||||
return "tag_with_json"
|
||||
case FormatTagWithTagged:
|
||||
return "tag_with_tagged"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// classifyXMLFormat determines the PEG format type from an XML format config.
//
// Decision logic:
//   - RawArgVal explicitly false → arguments are JSON; the format is
//     JSON-native only when there is no tag opener and KeyStart is absent
//     or a bare double quote, otherwise tag-with-JSON.
//   - otherwise a non-empty KeyStart indicates tagged per-parameter markers
//     (tag-with-tagged); anything else falls back to tag-with-JSON.
func classifyXMLFormat(f *XMLToolCallFormat) PEGFormatType {
	// If there's an explicit function opener like "<function=", it's a tag-based format
	hasTagOpener := f.ToolStart != "" && f.ToolSep != ""

	// RawArgVal is a *bool: nil means "unspecified", so only an explicit
	// false selects the JSON-args branch here.
	if f.RawArgVal != nil && !*f.RawArgVal {
		// JSON-only args
		if hasTagOpener {
			return FormatTagWithJSON
		}
		if f.KeyStart == "" || f.KeyStart == "\"" {
			return FormatJSONNative
		}
		return FormatTagWithJSON
	}
	if f.KeyStart != "" {
		return FormatTagWithTagged
	}
	return FormatTagWithJSON
}
|
||||
|
||||
// BuildPEGParserFromFormat builds a PEG parser arena from an XML format config.
|
||||
func BuildPEGParserFromFormat(f *XMLToolCallFormat, pegType PEGFormatType) *peg.Arena {
|
||||
switch pegType {
|
||||
case FormatTagWithTagged, FormatTagWithJSON:
|
||||
return buildTaggedPEGParser(f)
|
||||
case FormatJSONNative:
|
||||
return buildJSONNativePEGParser(f)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildTaggedPEGParser builds a parser for tag-based tool-call formats
// (both tag-with-JSON and tag-with-tagged arguments).
//
// Two shapes are produced:
//   - with scope markers (ScopeStart set): delegate wrapping to
//     StandardConstructedTools;
//   - without scope markers (e.g. functionary): hand-build the tool rule
//     directly, choosing tagged-parameter or raw-JSON argument parsing
//     based on whether KeyStart is set.
func buildTaggedPEGParser(f *XMLToolCallFormat) *peg.Arena {
	markers := map[string]string{}

	funcOpener := f.ToolStart
	funcNameSuffix := f.ToolSep
	funcCloser := f.ToolEnd

	hasScope := f.ScopeStart != ""

	if hasScope {
		markers["tool_call_start_marker"] = f.ScopeStart
		markers["tool_call_end_marker"] = f.ScopeEnd
	}

	markers["function_opener"] = funcOpener
	markers["function_name_suffix"] = funcNameSuffix
	markers["function_closer"] = funcCloser

	// Always set parameter markers explicitly to avoid relying on defaults.
	// Formats without tagged params (e.g., functionary) need empty strings.
	markers["parameter_key_prefix"] = f.KeyStart
	markers["parameter_key_suffix"] = f.KeyValSep
	markers["parameter_closer"] = f.ValEnd

	// Determine what to use as the content delimiter: free-form content runs
	// until the first scope marker, or the tool opener when there is no scope.
	contentDelim := f.ScopeStart
	if contentDelim == "" {
		contentDelim = f.ToolStart
	}

	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		tools := []peg.ToolDef{} // empty = accept anything
		content := p.Content(p.Until(contentDelim))

		if hasScope {
			// With scope markers: use StandardConstructedTools which wraps in scope
			toolCall := p.StandardConstructedTools(markers, tools, true, false)
			return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
		}

		// No scope markers (e.g., functionary): build tool parser directly without scope wrapper
		hasTaggedParams := f.KeyStart != ""
		var args peg.ParserID
		if hasTaggedParams {
			// One <key-prefix>name<key-suffix>value<closer> rule per argument,
			// repeated zero or more times with optional whitespace between.
			paramKeyPrefix := f.KeyStart
			paramKeySuffix := f.KeyValSep
			paramCloser := f.ValEnd
			argRule := p.ToolArg(p.Seq(
				p.ToolArgOpen(p.Literal(paramKeyPrefix)),
				p.ToolArgName(p.Until(paramKeySuffix)),
				p.Literal(paramKeySuffix),
				p.ToolArgValue(p.Until(paramCloser)),
				p.ToolArgClose(p.Literal(paramCloser)),
			))
			args = p.ToolArgs(p.ZeroOrMore(p.Seq(argRule, p.Space())))
		} else {
			// JSON arguments: everything up to the function closer is the args blob.
			args = p.ToolArgs(p.Until(funcCloser))
		}

		toolParser := p.Tool(p.Seq(
			p.ToolOpen(p.Seq(
				p.Literal(funcOpener),
				p.ToolName(p.Until(funcNameSuffix)),
				p.Literal(funcNameSuffix),
			)),
			p.Space(),
			args,
			p.Space(),
			p.ToolClose(p.Literal(funcCloser)),
		))

		// Multiple consecutive tool calls are accepted (OneOrMore).
		toolCall := p.TriggerRule("tool-call", p.OneOrMore(p.Seq(toolParser, p.Space())))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
|
||||
|
||||
// buildJSONNativePEGParser builds a parser for JSON-native tool-call formats:
// a single JSON object between section markers, with "name" / "arguments"
// keys recognized specially and any other key skipped generically.
// Returns nil when no usable section markers can be derived from the format.
func buildJSONNativePEGParser(f *XMLToolCallFormat) *peg.Arena {
	sectionStart := f.ScopeStart
	sectionEnd := f.ScopeEnd

	// Fall back to the per-tool markers when no scope markers are configured.
	if sectionStart == "" && f.ToolStart != "" {
		sectionStart = f.ToolStart
	}
	if sectionEnd == "" && f.ToolEnd != "" {
		sectionEnd = f.ToolEnd
	}

	if sectionStart == "" || sectionEnd == "" {
		return nil
	}

	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		// For JSON native, tool call is { "name": ..., "arguments": ... }
		// Build a generic parser that accepts any JSON tool call
		toolCall := p.TriggerRule("tool-call", p.Seq(
			p.Literal(sectionStart), p.Space(),
			p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				// Key/value pairs in any order; trailing comma handling is
				// lenient (each pair may be followed by an optional comma).
				p.ZeroOrMore(p.Seq(
					p.Choice(
						p.Seq(
							p.Literal("\"name\""), p.Space(), p.Literal(":"), p.Space(),
							p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
						),
						p.Seq(
							p.Literal("\"arguments\""), p.Space(), p.Literal(":"), p.Space(),
							p.ToolArgs(p.JSON()),
						),
						// Catch-all: skip any other "key": value pair.
						p.Seq(
							p.Literal("\""), p.JSONString(), p.Literal("\""), p.Space(),
							p.Literal(":"), p.Space(), p.JSON(),
						),
					),
					p.Optional(p.Seq(p.Space(), p.Literal(","), p.Space())),
				)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)),
			p.Space(), p.Literal(sectionEnd),
		))

		content := p.Content(p.Until(sectionStart))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
|
||||
|
||||
// BuildPEGParserFromMarkers builds a PEG parser from auto-detected C++ autoparser markers.
|
||||
func BuildPEGParserFromMarkers(m *ToolFormatMarkers) *peg.Arena {
|
||||
switch m.FormatType {
|
||||
case "tag_with_json":
|
||||
return buildPEGFromMarkersTagJSON(m)
|
||||
case "tag_with_tagged":
|
||||
return buildPEGFromMarkersTagTagged(m)
|
||||
case "json_native":
|
||||
return buildPEGFromMarkersJSONNative(m)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildPEGFromMarkersTagJSON builds a parser for the "tag_with_json" marker
// format: a tagged function wrapper whose body is a raw JSON arguments blob.
// With scope markers present it delegates to StandardConstructedTools;
// without them it hand-builds the tool rule, including an optional call-ID
// section between the function name and the arguments.
func buildPEGFromMarkersTagJSON(m *ToolFormatMarkers) *peg.Arena {
	markers := map[string]string{}

	// Use section markers if available, otherwise fall back to per-call markers
	scopeStart, scopeEnd := effectiveScope(m)

	if scopeStart != "" {
		markers["tool_call_start_marker"] = scopeStart
		markers["tool_call_end_marker"] = scopeEnd
	}

	// Trailing whitespace is stripped so the PEG Space() parser can absorb it.
	markers["function_opener"] = strings.TrimRight(m.FuncNamePrefix, " \t\n")
	markers["function_name_suffix"] = strings.TrimRight(m.FuncNameSuffix, " \t\n")
	markers["function_closer"] = strings.TrimRight(m.FuncClose, " \t\n")
	// JSON args: no per-parameter markers.
	markers["parameter_key_prefix"] = ""
	markers["parameter_key_suffix"] = ""
	markers["parameter_closer"] = ""

	if m.CallIDPosition == "between_func_and_args" {
		markers["call_id_prefix"] = m.CallIDPrefix
		markers["call_id_suffix"] = m.CallIDSuffix
	}

	contentDelim := scopeStart
	if contentDelim == "" {
		contentDelim = strings.TrimRight(m.FuncNamePrefix, " \t\n")
	}

	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		tools := []peg.ToolDef{}
		content := p.Content(p.Until(contentDelim))

		if scopeStart != "" {
			toolCall := p.StandardConstructedTools(markers, tools, true, false)
			return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
		}

		// No scope: build tool parser directly
		// NOTE(review): these use the *untrimmed* markers while contentDelim
		// above uses the trimmed prefix — if FuncNamePrefix carries trailing
		// whitespace, Until(trimmed) and Literal(untrimmed) may disagree.
		// Confirm whether trimming (as in buildPEGFromMarkersTagTagged) is
		// intended here too.
		funcOpener := m.FuncNamePrefix
		funcNameSuffix := m.FuncNameSuffix
		funcCloser := m.FuncClose

		args := p.ToolArgs(p.Until(funcCloser))

		// Build call ID section if detected
		callIDSection := buildCallIDParser(p, m)

		toolParser := p.Tool(p.Seq(
			p.ToolOpen(p.Seq(
				p.Literal(funcOpener),
				p.ToolName(p.Until(funcNameSuffix)),
				p.Literal(funcNameSuffix),
			)),
			callIDSection,
			p.Space(),
			args,
			p.Space(),
			p.ToolClose(p.Literal(funcCloser)),
		))
		toolCall := p.TriggerRule("tool-call", p.OneOrMore(p.Seq(toolParser, p.Space())))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
|
||||
|
||||
// buildPEGFromMarkersTagTagged builds a parser for the "tag_with_tagged"
// marker format: tagged function wrapper with per-parameter key/value tags.
// Unlike the tag-with-JSON variant, this always delegates to
// StandardConstructedTools (scope markers are optional in the markers map).
func buildPEGFromMarkersTagTagged(m *ToolFormatMarkers) *peg.Arena {
	markers := map[string]string{}

	// Use section markers if available, otherwise fall back to per-call markers
	scopeStart, scopeEnd := effectiveScope(m)

	if scopeStart != "" {
		markers["tool_call_start_marker"] = scopeStart
		markers["tool_call_end_marker"] = scopeEnd
	}

	// Trim trailing whitespace from markers — the PEG Space() parser
	// handles whitespace between elements, so baked-in \n would cause mismatches.
	markers["function_opener"] = strings.TrimRight(m.FuncNamePrefix, " \t\n")
	markers["function_name_suffix"] = strings.TrimRight(m.FuncNameSuffix, " \t\n")
	markers["function_closer"] = strings.TrimRight(m.FuncClose, " \t\n")
	markers["parameter_key_prefix"] = strings.TrimRight(m.ArgNamePrefix, " \t\n")
	markers["parameter_key_suffix"] = strings.TrimRight(m.ArgNameSuffix, " \t\n")
	markers["parameter_closer"] = strings.TrimRight(m.ArgValueSuffix, " \t\n")

	if m.CallIDPosition == "between_func_and_args" {
		markers["call_id_prefix"] = m.CallIDPrefix
		markers["call_id_suffix"] = m.CallIDSuffix
	}

	// Free-form content runs until the scope (or, failing that, the opener).
	contentDelim := scopeStart
	if contentDelim == "" {
		contentDelim = strings.TrimRight(m.FuncNamePrefix, " \t\n")
	}

	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		tools := []peg.ToolDef{}
		content := p.Content(p.Until(contentDelim))
		toolCall := p.StandardConstructedTools(markers, tools, true, false)
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
|
||||
|
||||
// buildPEGFromMarkersJSONNative builds a parser for the "json_native" marker
// format: a JSON object between section markers, with field names for the
// tool name, arguments, and (optionally) call IDs taken from the detected
// markers. Returns nil when no section markers are available.
func buildPEGFromMarkersJSONNative(m *ToolFormatMarkers) *peg.Arena {
	sectionStart, sectionEnd := effectiveScope(m)

	if sectionStart == "" || sectionEnd == "" {
		return nil
	}

	// Default to the OpenAI-style keys when the markers leave them unset.
	nameKey := m.NameField
	if nameKey == "" {
		nameKey = "name"
	}
	argsKey := m.ArgsField
	if argsKey == "" {
		argsKey = "arguments"
	}

	idField := m.IDField
	genIDField := m.GenIDField

	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		// Build field matchers for known keys
		knownFields := []peg.ParserID{
			p.Seq(
				p.Literal("\""+nameKey+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
			),
			p.Seq(
				p.Literal("\""+argsKey+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.ToolArgs(p.JSON()),
			),
		}

		// Add ID field matching if detected
		if idField != "" {
			knownFields = append(knownFields, p.Seq(
				p.Literal("\""+idField+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
		}
		// The generated-ID field gets its own matcher unless it duplicates idField.
		if genIDField != "" && genIDField != idField {
			knownFields = append(knownFields, p.Seq(
				p.Literal("\""+genIDField+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
		}

		// Catch-all for unknown JSON fields
		knownFields = append(knownFields, p.Seq(
			p.Literal("\""), p.JSONString(), p.Literal("\""), p.Space(),
			p.Literal(":"), p.Space(), p.JSON(),
		))

		// Build a generic JSON tool call parser that accepts any tool
		toolCall := p.TriggerRule("tool-call", p.Seq(
			p.Literal(sectionStart), p.Space(),
			p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				p.ZeroOrMore(p.Seq(
					p.Choice(knownFields...),
					p.Optional(p.Seq(p.Space(), p.Literal(","), p.Space())),
				)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)),
			p.Space(), p.Literal(sectionEnd),
		))
		content := p.Content(p.Until(sectionStart))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
|
||||
|
||||
// effectiveScope returns the scope start/end markers to use.
|
||||
// Prefers section markers, falls back to per-call markers, stripping trailing
|
||||
// whitespace so the PEG Space() parser can handle it flexibly.
|
||||
func effectiveScope(m *ToolFormatMarkers) (string, string) {
|
||||
if m.SectionStart != "" {
|
||||
return strings.TrimRight(m.SectionStart, " \t\n"), strings.TrimRight(m.SectionEnd, " \t\n")
|
||||
}
|
||||
if m.PerCallStart != "" {
|
||||
return strings.TrimRight(m.PerCallStart, " \t\n"), strings.TrimRight(m.PerCallEnd, " \t\n")
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// buildCallIDParser creates a parser for call ID markers based on position.
|
||||
// Currently only BETWEEN_FUNC_AND_ARGS is supported (matching llama.cpp behavior).
|
||||
func buildCallIDParser(p *peg.ChatBuilder, m *ToolFormatMarkers) peg.ParserID {
|
||||
if m.CallIDPosition == "between_func_and_args" && m.CallIDPrefix != "" && m.CallIDSuffix != "" {
|
||||
return p.Optional(p.Seq(
|
||||
p.Literal(m.CallIDPrefix),
|
||||
p.ToolID(p.Until(m.CallIDSuffix)),
|
||||
p.Literal(m.CallIDSuffix),
|
||||
))
|
||||
}
|
||||
return p.Eps()
|
||||
}
|
||||
|
||||
// parsePEG runs the PEG parser and extracts tool call results.
|
||||
func parsePEG(arena *peg.Arena, input string) []FuncCallResults {
|
||||
ctx := peg.NewParseContext(input, false)
|
||||
result := arena.Parse(ctx)
|
||||
|
||||
if result.Type != peg.Success {
|
||||
inputPreview := input
|
||||
if len(inputPreview) > 200 {
|
||||
inputPreview = inputPreview[:200] + "..."
|
||||
}
|
||||
xlog.Debug("[PEG] parse did not succeed", "result_type", result.Type, "input_preview", inputPreview)
|
||||
return nil
|
||||
}
|
||||
|
||||
mapper := &peg.ChatPegMapper{}
|
||||
mapper.FromAST(&ctx.Ast, &result)
|
||||
msg := mapper.Result
|
||||
|
||||
xlog.Debug("[PEG] parse succeeded", "content_len", len(msg.Content), "reasoning_len", len(msg.ReasoningContent), "tool_calls", len(msg.ToolCalls))
|
||||
|
||||
if len(msg.ToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []FuncCallResults
|
||||
for _, tc := range msg.ToolCalls {
|
||||
xlog.Debug("[PEG] extracted tool call", "name", tc.Name, "id", tc.ID, "args_len", len(tc.Arguments))
|
||||
results = append(results, FuncCallResults{
|
||||
Name: tc.Name,
|
||||
Arguments: tc.Arguments,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
301
pkg/functions/peg_integration_test.go
Normal file
301
pkg/functions/peg_integration_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package functions_test
|
||||
|
||||
import (
|
||||
. "github.com/mudler/LocalAI/pkg/functions"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Integration specs for the PEG-based tool call parser: format presets,
// auto-detection, custom XML formats, the ParseFunctionCall flow, the
// DisablePEGParser switch, and marker-driven parser construction.
var _ = Describe("PEG Integration", func() {
	// Built-in format presets selected via XMLFormatPreset.
	Context("format presets", func() {
		It("parses functionary format", func() {
			// Functionary: inline <function=NAME>{json args}</function> after free text.
			input := `I'll help you with that.<function=get_weather>{"location": "NYC", "unit": "celsius"}</function>`

			config := FunctionsConfig{
				XMLFormatPreset: "functionary",
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})

		It("parses qwen3-coder format", func() {
			// Qwen3-coder: <tool_call> scope with tagged <parameter=...> values.
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})

		It("parses qwen3-coder format with preceding content", func() {
			// Free-form text before the tool call scope must not break parsing.
			input := "Let me think about this...\n<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})

		It("parses minimax-m2 format", func() {
			// MiniMax-M2: <minimax:tool_call> scope with <invoke name="..."> and
			// <parameter name="...">value</parameter> entries.
			input := "Here's the result.\n<minimax:tool_call>\n<invoke name=\"search\">\n<parameter name=\"query\">test query</parameter>\n</invoke>\n</minimax:tool_call>"

			config := FunctionsConfig{
				XMLFormatPreset: "minimax-m2",
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("search"))
			Expect(results[0].Arguments).To(ContainSubstring(`"query"`))
		})

		It("handles glm-4.5 format gracefully", func() {
			input := "<tool_call><arg_key>location</arg_key><arg_value>NYC</arg_value></tool_call>"

			config := FunctionsConfig{
				XMLFormatPreset: "glm-4.5",
			}

			results := ParseFunctionCallPEG(input, config)
			// GLM-4.5 uses tool_call as both scope and tool start with no function name separator,
			// so the PEG parser may not handle it perfectly.
			if len(results) > 0 {
				Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
			}
		})
	})

	// No preset configured: the parser should auto-detect the format.
	Context("auto-detect", func() {
		It("detects format without preset", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})
	})

	// User-supplied XMLToolCallFormat instead of a named preset.
	Context("custom XML format", func() {
		It("parses with custom format config", func() {
			input := "<tool_call>\n<function=edit>\n<parameter=filename>\ntest.py\n</parameter>\n<parameter=content>\nhello world\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{
				XMLFormat: &XMLToolCallFormat{
					ScopeStart:    "<tool_call>",
					ToolStart:     "<function=",
					ToolSep:       ">",
					KeyStart:      "<parameter=",
					KeyValSep:     ">",
					ValEnd:        "</parameter>",
					ToolEnd:       "</function>",
					ScopeEnd:      "</tool_call>",
					TrimRawArgVal: true,
				},
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("edit"))
			Expect(results[0].Arguments).To(ContainSubstring(`"filename"`))
		})
	})

	// Plain text with no tool call markers must yield no results.
	Context("no tool calls", func() {
		It("returns empty results for plain text", func() {
			input := "This is just a regular response with no tool calls."

			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}

			results := ParseFunctionCallPEG(input, config)
			Expect(results).To(BeEmpty())
		})
	})

	// PEG parsing reached through the top-level ParseFunctionCall entry point.
	Context("ParseFunctionCall integration", func() {
		It("finds tool calls via PEG in ParseFunctionCall flow", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{}

			results := ParseFunctionCall(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})

		It("finds functionary tool calls via ParseFunctionCall", func() {
			input := `Sure!<function=calculator>{"expression": "2+2"}</function>`

			config := FunctionsConfig{
				XMLFormatPreset: "functionary",
			}

			results := ParseFunctionCall(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("calculator"))
			Expect(results[0].Arguments).To(ContainSubstring(`"expression"`))
		})
	})

	// The DisablePEGParser flag only affects the ParseFunctionCall flow, not
	// direct calls to ParseFunctionCallPEG or the iterative ParseXML parser.
	Context("DisablePEGParser", func() {
		It("still works when called directly but skips PEG in ParseFunctionCall", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"

			config := FunctionsConfig{
				DisablePEGParser: true,
			}

			// ParseFunctionCallPEG should still work when called directly
			pegResults := ParseFunctionCallPEG(input, config)
			// May or may not find results depending on auto-detect
			_ = pegResults

			// ParseFunctionCall with PEG disabled should NOT find XML tool calls
			disabledResults := ParseFunctionCall(input, config)
			// May find via JSON extraction
			_ = disabledResults

			// ParseXML (iterative parser) should still find results
			xmlResults, err := ParseXML(input, nil)
			Expect(err).NotTo(HaveOccurred())
			Expect(xmlResults).NotTo(BeEmpty())
			Expect(xmlResults[0].Name).To(Equal("get_weather"))
		})
	})

	// Parsers built from ToolFormatMarkers (backend-supplied format description).
	Context("markers-based parsing", func() {
		It("parses tag_with_json format from markers", func() {
			// tag_with_json: function name in a tag, arguments as a raw JSON body.
			input := `Hello!<function=get_weather>{"location": "NYC"}</function>`

			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_json",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
			}

			arena := BuildPEGParserFromMarkers(markers)
			Expect(arena).NotTo(BeNil())

			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})

		It("parses tag_with_tagged format from markers", func() {
			// tag_with_tagged: arguments supplied as individual tagged key/value pairs.
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>NYC</parameter>\n</function>\n</tool_call>"

			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_tagged",
				SectionStart:   "<tool_call>",
				SectionEnd:     "</tool_call>",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
				ArgNamePrefix:  "<parameter=",
				ArgNameSuffix:  ">",
				ArgValueSuffix: "</parameter>",
			}

			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})

		It("parses json_native format from markers", func() {
			// json_native: the whole tool call is a JSON object inside the section.
			input := `Some content<tool_call>{"name": "get_weather", "arguments": {"location": "NYC"}}</tool_call>`

			markers := &ToolFormatMarkers{
				FormatType:   "json_native",
				SectionStart: "<tool_call>",
				SectionEnd:   "</tool_call>",
				NameField:    "name",
				ArgsField:    "arguments",
			}

			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})

		It("returns nil arena for unknown format type", func() {
			markers := &ToolFormatMarkers{
				FormatType: "unknown",
			}
			arena := BuildPEGParserFromMarkers(markers)
			Expect(arena).To(BeNil())
		})

		It("parses json_native format with ID field", func() {
			// IDField maps a JSON property onto the extracted tool call ID.
			input := `Some content<tool_call>{"name": "get_weather", "arguments": {"location": "NYC"}, "id": "call_123"}</tool_call>`

			markers := &ToolFormatMarkers{
				FormatType:   "json_native",
				SectionStart: "<tool_call>",
				SectionEnd:   "</tool_call>",
				NameField:    "name",
				ArgsField:    "arguments",
				IDField:      "id",
			}

			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
			Expect(results[0].ID).To(Equal("call_123"))
		})

		It("parses call ID between function name and arguments", func() {
			// CallIDPosition "between_func_and_args" with [ ] delimiters, e.g.
			// <function=NAME>[ID]{args}</function>.
			input := `<tool_call><function=get_weather>[call_abc]{"location": "NYC"}</function></tool_call>`

			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_json",
				SectionStart:   "<tool_call>",
				SectionEnd:     "</tool_call>",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
				CallIDPosition: "between_func_and_args",
				CallIDPrefix:   "[",
				CallIDSuffix:   "]",
			}

			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].ID).To(Equal("call_abc"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
	})
})
|
||||
Reference in New Issue
Block a user