feat(functions): add peg-based parsing and allow backends to return tool calls directly (#8838)

* feat(functions): add peg-based parsing

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: support returning toolcalls directly from backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: run PEG parsing only if the backend didn't send deltas

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-08 22:21:57 +01:00
committed by GitHub
parent b57a6e42f1
commit b2f81bfa2e
25 changed files with 6204 additions and 1090 deletions

View File

@@ -165,6 +165,22 @@ message PredictOptions {
map<string, string> Metadata = 52; // Generic per-request metadata (e.g., enable_thinking)
}
// ToolCallDelta represents an incremental tool call update from the C++ parser.
// Used for both streaming (partial diffs) and non-streaming (final tool calls).
// In streaming mode, several deltas may share the same index; each carries only
// the fields that changed in that chunk (unchanged fields are left empty).
message ToolCallDelta {
int32 index = 1; // tool call index (0-based); identifies which logical call this delta belongs to
string id = 2; // tool call ID (e.g., "call_abc123"); empty when this chunk carries no ID
string name = 3; // function name (set on first appearance, empty on subsequent chunks)
string arguments = 4; // arguments chunk (incremental JSON fragment in streaming, full string in non-streaming)
}
// ChatDelta represents incremental content/reasoning/tool_call updates parsed by the C++ backend.
// One ChatDelta corresponds to one parsed chunk in streaming mode, or to the single
// final parsed message in non-streaming mode (where all tool calls appear together).
message ChatDelta {
string content = 1; // content text delta (may be empty when only reasoning/tool calls changed)
string reasoning_content = 2; // reasoning/thinking text delta
repeated ToolCallDelta tool_calls = 3; // tool call deltas; see ToolCallDelta for merge semantics
}
// The response message containing the result
message Reply {
bytes message = 1;
@@ -174,6 +190,7 @@ message Reply {
double timing_token_generation = 5;
bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
repeated ChatDelta chat_deltas = 8; // Parsed chat deltas from C++ autoparser (streaming + non-streaming)
}
message GrammarTrigger {
@@ -425,7 +442,62 @@ message DetectResponse {
repeated Detection Detections = 1;
}
// ToolFormatMarkers describes the tool-call syntax auto-detected from a model's
// chat template (differential template analysis). Consumers use these markers to
// emit and parse tool calls in the model's native format. String fields that do
// not apply to the detected format_type are left empty.
message ToolFormatMarkers {
string format_type = 1; // "json_native", "tag_with_json", "tag_with_tagged"
// Tool section markers
string section_start = 2; // e.g., "<tool_call>", "[TOOL_CALLS]"
string section_end = 3; // e.g., "</tool_call>"
string per_call_start = 4; // e.g., "<|tool_call_begin|>"
string per_call_end = 5; // e.g., "<|tool_call_end|>"
// Function name markers (TAG_WITH_JSON / TAG_WITH_TAGGED)
string func_name_prefix = 6; // e.g., "<function="
string func_name_suffix = 7; // e.g., ">"
string func_close = 8; // e.g., "</function>"
// Argument markers (TAG_WITH_TAGGED)
string arg_name_prefix = 9; // e.g., "<param="
string arg_name_suffix = 10; // e.g., ">"
string arg_value_prefix = 11; // text emitted before each argument value (often empty)
string arg_value_suffix = 12; // e.g., "</param>"
string arg_separator = 13; // e.g., "\n"
// JSON format fields (JSON_NATIVE)
string name_field = 14; // e.g., "name"
string args_field = 15; // e.g., "arguments"
string id_field = 16; // e.g., "id"
bool fun_name_is_key = 17; // presumably: function name is used as the JSON object key rather than a field value — confirm against autoparser
bool tools_array_wrapped = 18; // whether tool calls are wrapped in a JSON array
bool uses_python_dicts = 19; // whether arguments use Python dict literals instead of strict JSON
// Reasoning markers
string reasoning_start = 20; // e.g., "<think>"
string reasoning_end = 21; // e.g., "</think>"
// Content markers
string content_start = 22; // marker preceding plain content, if the template uses one
string content_end = 23; // marker following plain content, if the template uses one
// Args wrapper markers
string args_start = 24; // e.g., "<args>"
string args_end = 25; // e.g., "</args>"
// JSON parameter ordering
string function_field = 26; // e.g., "function" (wrapper key in JSON)
repeated string parameter_order = 27; // expected ordering of JSON keys within a tool call, when the template is order-sensitive
// Generated ID field (alternative field name for generated IDs)
string gen_id_field = 28; // e.g., "call_id"
// Call ID markers (position and delimiters for tool call IDs)
string call_id_position = 29; // "none", "pre_func_name", "between_func_and_args", "post_args"
string call_id_prefix = 30; // e.g., "[CALL_ID]"
string call_id_suffix = 31; // e.g., ""
}
// ModelMetadataResponse reports capabilities derived from the model's chat
// template: thinking/reasoning support and auto-detected tool-call syntax.
message ModelMetadataResponse {
bool supports_thinking = 1; // true when the chat template supports an enable_thinking/reasoning mode
string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable)
ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis (unset when detection found none)
}

View File

@@ -17,6 +17,7 @@
#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include "common.h"
#include "chat-auto-parser.h"
#include <getopt.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
@@ -866,6 +867,56 @@ public:
return logprobs_json;
}
// Copy a sequence of incremental message diffs (streaming chunks) into
// reply.chat_deltas. Each diff yields one ChatDelta; a diff may carry a
// content chunk, a reasoning chunk, and/or a single tool-call delta
// (present when tool_call_index != std::string::npos).
static void populate_chat_deltas_from_diffs(backend::Reply & reply,
                                            const std::vector<common_chat_msg_diff> & diffs) {
    for (const auto & d : diffs) {
        auto * delta = reply.add_chat_deltas();

        if (!d.content_delta.empty()) {
            delta->set_content(d.content_delta);
        }
        if (!d.reasoning_content_delta.empty()) {
            delta->set_reasoning_content(d.reasoning_content_delta);
        }

        // std::string::npos is the sentinel for "no tool call in this diff".
        if (d.tool_call_index == std::string::npos) {
            continue;
        }
        auto * tc = delta->add_tool_calls();
        tc->set_index(static_cast<int32_t>(d.tool_call_index));
        const auto & call = d.tool_call_delta;
        if (!call.id.empty()) {
            tc->set_id(call.id);
        }
        if (!call.name.empty()) {
            tc->set_name(call.name);
        }
        if (!call.arguments.empty()) {
            tc->set_arguments(call.arguments);
        }
    }
}
// Copy the final parsed assistant message (non-streaming path) into
// reply.chat_deltas. Emits at most one ChatDelta carrying the complete
// content, the complete reasoning text, and every tool call; emits nothing
// when the message is entirely empty.
static void populate_chat_deltas_from_final(backend::Reply & reply,
                                            const common_chat_msg & msg) {
    const bool has_payload =
        !msg.content.empty() || !msg.reasoning_content.empty() || !msg.tool_calls.empty();
    if (!has_payload) {
        return;
    }

    auto * delta = reply.add_chat_deltas();
    if (!msg.content.empty()) {
        delta->set_content(msg.content);
    }
    if (!msg.reasoning_content.empty()) {
        delta->set_reasoning_content(msg.reasoning_content);
    }

    // All tool calls share this single ChatDelta; index identifies each call.
    int32_t idx = 0;
    for (const auto & call : msg.tool_calls) {
        auto * tc = delta->add_tool_calls();
        tc->set_index(idx++);
        tc->set_id(call.id);
        tc->set_name(call.name);
        tc->set_arguments(call.arguments);
    }
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
if (params_base.model.path.empty()) {
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
@@ -1484,127 +1535,76 @@ public:
return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
}
// Lambda to build a Reply from JSON + attach chat deltas from a result
auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply {
backend::Reply reply;
std::string completion_text = res_json.value("content", "");
reply.set_message(completion_text);
reply.set_tokens(res_json.value("tokens_predicted", 0));
reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0));
if (res_json.contains("timings")) {
reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0));
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0));
}
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
reply.set_logprobs(logprobs_json.dump());
}
return reply;
};
auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) {
// Try streaming partial result first
auto* partial = dynamic_cast<server_task_result_cmpl_partial*>(raw_result);
if (partial && !partial->oaicompat_msg_diffs.empty()) {
populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs);
return;
}
// Try final result
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(raw_result);
if (final_res && final_res->is_updated) {
populate_chat_deltas_from_diffs(reply, final_res->oaicompat_msg_diffs);
}
};
// Process first result
json first_res_json = first_result->to_json();
if (first_res_json.is_array()) {
for (const auto & res : first_res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res, first_result.get());
attach_chat_deltas(reply, first_result.get());
writer->Write(reply);
}
} else {
std::string completion_text = first_res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = first_res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = first_res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (first_res_json.contains("timings")) {
double timing_prompt_processing = first_res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = first_res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(first_res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(first_res_json, first_result.get());
attach_chat_deltas(reply, first_result.get());
writer->Write(reply);
}
// Process subsequent results
while (rd.has_next()) {
// Check if context is cancelled before processing result
if (context->IsCancelled()) {
break;
}
auto result = rd.next([&context]() { return context->IsCancelled(); });
if (result == nullptr) {
// connection is closed
break;
}
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
std::string completion_text = res.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res.contains("timings")) {
double timing_prompt_processing = res.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res, result.get());
attach_chat_deltas(reply, result.get());
writer->Write(reply);
}
} else {
std::string completion_text = res_json.value("content", "");
backend::Reply reply;
reply.set_message(completion_text);
int32_t tokens_predicted = res_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = res_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);
if (res_json.contains("timings")) {
double timing_prompt_processing = res_json.at("timings").value("prompt_ms", 0.0);
reply.set_timing_prompt_processing(timing_prompt_processing);
double timing_token_generation = res_json.at("timings").value("predicted_ms", 0.0);
reply.set_timing_token_generation(timing_token_generation);
}
// Extract and set logprobs if present
json logprobs_json = extract_logprobs_from_json(res_json);
if (!logprobs_json.empty() && !logprobs_json.is_null()) {
std::string logprobs_str = logprobs_json.dump();
reply.set_logprobs(logprobs_str);
}
auto reply = build_reply_from_json(res_json, result.get());
attach_chat_deltas(reply, result.get());
writer->Write(reply);
}
}
@@ -2264,7 +2264,8 @@ public:
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl;
if (all_results.results.size() == 1) {
// single result
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
auto* final_res = dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get());
GGML_ASSERT(final_res != nullptr);
json result_json = all_results.results[0]->to_json();
reply->set_message(result_json.value("content", ""));
@@ -2287,6 +2288,11 @@ public:
reply->set_logprobs(logprobs_str);
}
// Populate chat deltas from the autoparser's final parsed message
if (final_res->is_updated) {
populate_chat_deltas_from_final(*reply, final_res->oaicompat_msg);
}
} else {
// multiple results (multitask)
json arr = json::array();
@@ -2609,6 +2615,113 @@ public:
response->set_rendered_template(rendered_template);
// Run differential template analysis to detect tool format markers
if (params_base.use_jinja) {
try {
// Get template source and reconstruct a common_chat_template for analysis
std::string tmpl_src = common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get());
if (!tmpl_src.empty()) {
const auto * vocab = llama_model_get_vocab(ctx_server.impl->model);
std::string token_bos, token_eos;
if (vocab) {
auto bos_id = llama_vocab_bos(vocab);
auto eos_id = llama_vocab_eos(vocab);
if (bos_id != LLAMA_TOKEN_NULL) {
token_bos = common_token_to_piece(vocab, bos_id, true);
}
if (eos_id != LLAMA_TOKEN_NULL) {
token_eos = common_token_to_piece(vocab, eos_id, true);
}
}
common_chat_template tmpl(tmpl_src, token_bos, token_eos);
struct autoparser::autoparser ap;
ap.analyze_template(tmpl);
if (ap.analysis_complete && ap.tools.format.mode != autoparser::tool_format::NONE) {
auto * tf = response->mutable_tool_format();
// Format type
switch (ap.tools.format.mode) {
case autoparser::tool_format::JSON_NATIVE:
tf->set_format_type("json_native");
break;
case autoparser::tool_format::TAG_WITH_JSON:
tf->set_format_type("tag_with_json");
break;
case autoparser::tool_format::TAG_WITH_TAGGED:
tf->set_format_type("tag_with_tagged");
break;
default:
break;
}
// Tool section markers
tf->set_section_start(ap.tools.format.section_start);
tf->set_section_end(ap.tools.format.section_end);
tf->set_per_call_start(ap.tools.format.per_call_start);
tf->set_per_call_end(ap.tools.format.per_call_end);
// Function markers
tf->set_func_name_prefix(ap.tools.function.name_prefix);
tf->set_func_name_suffix(ap.tools.function.name_suffix);
tf->set_func_close(ap.tools.function.close);
// Argument markers
tf->set_arg_name_prefix(ap.tools.arguments.name_prefix);
tf->set_arg_name_suffix(ap.tools.arguments.name_suffix);
tf->set_arg_value_prefix(ap.tools.arguments.value_prefix);
tf->set_arg_value_suffix(ap.tools.arguments.value_suffix);
tf->set_arg_separator(ap.tools.arguments.separator);
tf->set_args_start(ap.tools.arguments.start);
tf->set_args_end(ap.tools.arguments.end);
// JSON format fields
tf->set_name_field(ap.tools.format.name_field);
tf->set_args_field(ap.tools.format.args_field);
tf->set_id_field(ap.tools.format.id_field);
tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key);
tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped);
tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts);
tf->set_function_field(ap.tools.format.function_field);
tf->set_gen_id_field(ap.tools.format.gen_id_field);
for (const auto & p : ap.tools.format.parameter_order) {
tf->add_parameter_order(p);
}
// Call ID markers
switch (ap.tools.call_id.pos) {
case autoparser::call_id_position::NONE:
tf->set_call_id_position("none");
break;
case autoparser::call_id_position::PRE_FUNC_NAME:
tf->set_call_id_position("pre_func_name");
break;
case autoparser::call_id_position::BETWEEN_FUNC_AND_ARGS:
tf->set_call_id_position("between_func_and_args");
break;
case autoparser::call_id_position::POST_ARGS:
tf->set_call_id_position("post_args");
break;
}
tf->set_call_id_prefix(ap.tools.call_id.prefix);
tf->set_call_id_suffix(ap.tools.call_id.suffix);
// Reasoning markers
tf->set_reasoning_start(ap.reasoning.start);
tf->set_reasoning_end(ap.reasoning.end);
// Content markers
tf->set_content_start(ap.content.start);
tf->set_content_end(ap.content.end);
}
}
} catch (const std::exception & e) {
SRV_WRN("ModelMetadata: failed to run autoparser analysis: %s\n", e.what());
}
}
return grpc::Status::OK;
}
};

View File

@@ -28,6 +28,7 @@ type LLMResponse struct {
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser
}
type TokenUsage struct {
@@ -142,6 +143,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
ss := ""
var logprobs *schema.Logprobs
var allChatDeltas []*proto.ChatDelta
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -153,6 +155,11 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
// Collect chat deltas from C++ autoparser
if len(reply.ChatDeltas) > 0 {
allChatDeltas = append(allChatDeltas, reply.ChatDeltas...)
}
// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
@@ -183,10 +190,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenCallback("", tokenUsage)
}
})
if len(allChatDeltas) > 0 {
xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas))
}
return LLMResponse{
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
ChatDeltas: allChatDeltas,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -218,10 +229,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}
}
if len(reply.ChatDeltas) > 0 {
xlog.Debug("[ChatDeltas] non-streaming Predict received deltas from C++ autoparser", "total_deltas", len(reply.ChatDeltas))
}
return LLMResponse{
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
ChatDeltas: reply.ChatDeltas,
}, err
}
}

View File

@@ -3,6 +3,7 @@ package config
import (
"context"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/reasoning"
@@ -118,5 +119,46 @@ func DetectThinkingSupportFromBackend(ctx context.Context, cfg *ModelConfig, bac
cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(true)
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", false)
}
// Extract tool format markers from autoparser analysis
if tf := metadata.GetToolFormat(); tf != nil && tf.FormatType != "" {
cfg.FunctionsConfig.ToolFormatMarkers = &functions.ToolFormatMarkers{
FormatType: tf.FormatType,
SectionStart: tf.SectionStart,
SectionEnd: tf.SectionEnd,
PerCallStart: tf.PerCallStart,
PerCallEnd: tf.PerCallEnd,
FuncNamePrefix: tf.FuncNamePrefix,
FuncNameSuffix: tf.FuncNameSuffix,
FuncClose: tf.FuncClose,
ArgNamePrefix: tf.ArgNamePrefix,
ArgNameSuffix: tf.ArgNameSuffix,
ArgValuePrefix: tf.ArgValuePrefix,
ArgValueSuffix: tf.ArgValueSuffix,
ArgSeparator: tf.ArgSeparator,
ArgsStart: tf.ArgsStart,
ArgsEnd: tf.ArgsEnd,
NameField: tf.NameField,
ArgsField: tf.ArgsField,
IDField: tf.IdField,
FunNameIsKey: tf.FunNameIsKey,
ToolsArrayWrapped: tf.ToolsArrayWrapped,
UsesPythonDicts: tf.UsesPythonDicts,
FunctionField: tf.FunctionField,
ParameterOrder: tf.ParameterOrder,
GenIDField: tf.GenIdField,
CallIDPosition: tf.CallIdPosition,
CallIDPrefix: tf.CallIdPrefix,
CallIDSuffix: tf.CallIdSuffix,
ReasoningStart: tf.ReasoningStart,
ReasoningEnd: tf.ReasoningEnd,
ContentStart: tf.ContentStart,
ContentEnd: tf.ContentEnd,
}
xlog.Debug("[gguf] DetectThinkingSupportFromBackend: tool format markers detected",
"format_type", tf.FormatType,
"section_start", tf.SectionStart,
"func_name_prefix", tf.FuncNamePrefix)
}
}
}

View File

@@ -141,8 +141,15 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
xlog.Warn("Anthropic: retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
}
// Check if the result contains tool calls
toolCalls := functions.ParseFunctionCall(result, cfg.FunctionsConfig)
// Try pre-parsed tool calls from C++ autoparser first, fall back to text parsing
var toolCalls []functions.FuncCallResults
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls))
toolCalls = deltaToolCalls
} else {
xlog.Debug("[ChatDeltas] Anthropic: no pre-parsed tool calls, falling back to Go-side text parsing")
toolCalls = functions.ParseFunctionCall(result, cfg.FunctionsConfig)
}
var contentBlocks []schema.AnthropicContentBlock
var stopReason string

View File

@@ -16,6 +16,7 @@ import (
reason "github.com/mudler/LocalAI/pkg/reasoning"
"github.com/mudler/LocalAI/core/templates"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
@@ -55,7 +56,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
lastEmittedReasoning := ""
lastEmittedCleanedContent := ""
_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
_, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
accumulatedContent += s
currentReasoning, cleanedContent := reason.ExtractReasoningWithConfig(accumulatedContent, thinkingStartToken, config.ReasoningConfig)
@@ -141,7 +142,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
result := ""
lastEmittedCount := 0
_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
_, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
result += s
// Try incremental XML parsing for streaming support using iterative parser
// This allows emitting partial tool calls as they're being generated
@@ -250,13 +251,25 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
if err != nil {
return err
}
// Prepend thinking token if needed, then extract reasoning before processing tool calls
reasoning, result := reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
// Try using pre-parsed tool calls from C++ autoparser (chat deltas)
var functionResults []functions.FuncCallResults
var reasoning string
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
xlog.Debug("Text content to return", "text", textContentToReturn)
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls))
functionResults = deltaToolCalls
// Use content/reasoning from deltas too
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
reasoning = functions.ReasoningFromChatDeltas(chatDeltas)
} else {
// Fallback: parse tool calls from raw text (no chat deltas from backend)
xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing")
reasoning, result = reason.ExtractReasoningWithConfig(result, thinkingStartToken, config.ReasoningConfig)
textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
result = functions.CleanupLLMResult(result, config.FunctionsConfig)
functionResults = functions.ParseFunctionCall(result, config.FunctionsConfig)
}
xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn)
noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0
switch {
@@ -308,6 +321,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
default:
for i, ss := range functionResults {
name, args := ss.Name, ss.Arguments
toolCallID := ss.ID
if toolCallID == "" {
toolCallID = id
}
initialMessage := schema.OpenAIResponse{
ID: id,
@@ -319,7 +336,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
ID: toolCallID,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
@@ -345,7 +362,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
ID: toolCallID,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
@@ -656,10 +673,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template)
// When shouldUseFn, the callback just stores the raw text — tool parsing
// is deferred to after ComputeChoices so we can check chat deltas first
// and avoid redundant Go-side parsing.
var cbRawResult, cbReasoning string
var emptyRetryNeeded bool
tokenCallback := func(s string, c *[]schema.Choice) {
// Prepend thinking token if needed, then extract reasoning from the response
reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig)
if !shouldUseFn {
@@ -672,102 +692,20 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return
}
textContentToReturn = functions.ParseTextContent(s, config.FunctionsConfig)
s = functions.CleanupLLMResult(s, config.FunctionsConfig)
results := functions.ParseFunctionCall(s, config.FunctionsConfig)
xlog.Debug("Text content to return", "text", textContentToReturn)
noActionsToRun := len(results) > 0 && results[0].Name == noActionName || len(results) == 0
switch {
case noActionsToRun:
if s == "" && textContentToReturn == "" {
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
emptyRetryNeeded = true
return
}
result, err := handleQuestion(config, results, s, predInput)
if err != nil {
xlog.Error("error handling question", "error", err)
emptyRetryNeeded = true
return
}
stopReason := FinishReasonStop
message := &schema.Message{Role: "assistant", Content: &result}
if reasoning != "" {
message.Reasoning = &reasoning
}
*c = append(*c, schema.Choice{
FinishReason: &stopReason,
Message: message})
default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{
FinishReason: &toolCallsReason,
Message: &schema.Message{
Role: "assistant",
},
}
if reasoning != "" {
toolChoice.Message.Reasoning = &reasoning
}
for _, ss := range results {
name, args := ss.Name, ss.Arguments
if len(input.Tools) > 0 {
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.Content = textContentToReturn
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// otherwise we return more choices directly (deprecated)
functionCallReason := FinishReasonFunctionCall
message := &schema.Message{
Role: "assistant",
Content: &textContentToReturn,
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
}
if reasoning != "" {
message.Reasoning = &reasoning
}
*c = append(*c, schema.Choice{
FinishReason: &functionCallReason,
Message: message,
})
}
}
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}
// Store raw text for deferred tool parsing
cbRawResult = s
cbReasoning = reasoning
}
// Echo properly supports context cancellation via c.Request().Context()
// No workaround needed!
const maxEmptyRetries = 5
var result []schema.Choice
var tokenUsage backend.TokenUsage
var err error
var chatDeltas []*pb.ChatDelta
for attempt := 0; attempt <= maxEmptyRetries; attempt++ {
emptyRetryNeeded = false
result, tokenUsage, err = ComputeChoices(
result, tokenUsage, chatDeltas, err = ComputeChoices(
input,
predInput,
config,
@@ -777,7 +715,111 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
tokenCallback,
nil,
)
if err != nil || !emptyRetryNeeded {
if err != nil {
break
}
// Tool parsing is deferred here (only when shouldUseFn)
if shouldUseFn {
var funcResults []functions.FuncCallResults
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls))
funcResults = deltaToolCalls
textContentToReturn = functions.ContentFromChatDeltas(chatDeltas)
cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas)
} else {
// Fallback: parse tool calls from raw text
xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing")
textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig)
cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig)
funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig)
}
noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0
switch {
case noActionsToRun:
if cbRawResult == "" && textContentToReturn == "" {
xlog.Warn("Backend returned empty content in tool-calling context, will retry")
emptyRetryNeeded = true
continue
}
qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput)
if qErr != nil {
xlog.Error("error handling question", "error", qErr)
emptyRetryNeeded = true
continue
}
stopReason := FinishReasonStop
message := &schema.Message{Role: "assistant", Content: &qResult}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &stopReason,
Message: message,
})
default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{
FinishReason: &toolCallsReason,
Message: &schema.Message{
Role: "assistant",
},
}
if cbReasoning != "" {
toolChoice.Message.Reasoning = &cbReasoning
}
for _, ss := range funcResults {
name, args := ss.Name, ss.Arguments
toolCallID := ss.ID
if toolCallID == "" {
toolCallID = id
}
if len(input.Tools) > 0 {
toolChoice.Message.Content = textContentToReturn
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: toolCallID,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
)
} else {
// Deprecated function_call format
functionCallReason := FinishReasonFunctionCall
message := &schema.Message{
Role: "assistant",
Content: &textContentToReturn,
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
}
if cbReasoning != "" {
message.Reasoning = &cbReasoning
}
result = append(result, schema.Choice{
FinishReason: &functionCallReason,
Message: message,
})
}
}
if len(input.Tools) > 0 {
result = append(result, toolChoice)
}
}
}
if !emptyRetryNeeded {
break
}
xlog.Warn("Retrying prediction due to empty backend response", "attempt", attempt+1, "maxRetries", maxEmptyRetries)
@@ -796,6 +838,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
Message: &schema.Message{Role: "assistant", Content: &empty},
})
}
usage := schema.OpenAIUsage{
PromptTokens: tokenUsage.Prompt,
CompletionTokens: tokenUsage.Completion,

View File

@@ -57,7 +57,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
responses <- resp
return true
}
_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
_, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
close(responses)
return err
}
@@ -216,7 +216,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
xlog.Debug("Template found, input modified", "input", i)
}
r, tokenUsage, err := ComputeChoices(
r, tokenUsage, _, err := ComputeChoices(
input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})

View File

@@ -58,7 +58,7 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
xlog.Debug("Template found, input modified", "input", i)
}
r, tokenUsage, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
r, tokenUsage, _, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
*c = append(*c, schema.Choice{Text: s})
}, nil)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
)
@@ -18,7 +19,7 @@ func ComputeChoices(
o *config.ApplicationConfig,
loader *model.ModelLoader,
cb func(string, *[]schema.Choice),
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) {
n := req.N // number of completions to return
result := []schema.Choice{}
@@ -84,15 +85,16 @@ func ComputeChoices(
predFunc, err := backend.ModelInference(
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata)
if err != nil {
return result, backend.TokenUsage{}, err
return result, backend.TokenUsage{}, nil, err
}
tokenUsage := backend.TokenUsage{}
var allChatDeltas []*pb.ChatDelta
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return result, backend.TokenUsage{}, err
return result, backend.TokenUsage{}, nil, err
}
tokenUsage.Prompt += prediction.Usage.Prompt
@@ -100,6 +102,11 @@ func ComputeChoices(
tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration
// Collect chat deltas from C++ autoparser
if len(prediction.ChatDeltas) > 0 {
allChatDeltas = append(allChatDeltas, prediction.ChatDeltas...)
}
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
@@ -111,5 +118,5 @@ func ComputeChoices(
//result = append(result, Choice{Text: prediction})
}
return result, tokenUsage, err
return result, tokenUsage, allChatDeltas, err
}

View File

@@ -826,9 +826,20 @@ func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, respon
var toolCalls []schema.ToolCall
if shouldUseFn {
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
var funcCallResults []functions.FuncCallResults
var textContent string
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
funcCallResults = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
} else {
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
}
noActionName := "answer"
if cfg.FunctionsConfig.NoActionFunctionName != "" {
@@ -1535,13 +1546,22 @@ func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt i
}
if shouldUseFn {
// Clean up the result (already extracted reasoning above)
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
xlog.Debug("Open Responses - Cleaned result", "cleanedResult", cleanedResult)
var funcCallResults []functions.FuncCallResults
var textContent string
funcCallResults := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
xlog.Debug("Open Responses - Parsed function calls", "count", len(funcCallResults), "textContent", textContent)
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls))
funcCallResults = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
} else {
xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing")
// Clean up the result (already extracted reasoning above)
cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig)
funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
}
xlog.Debug("[ChatDeltas] OpenResponses: final tool call decision", "count", len(funcCallResults), "textContent", textContent)
// Check for noAction function (model chose to respond without tool)
noActionName := "answer"
@@ -2128,11 +2148,20 @@ func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int6
}
}
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
xlog.Debug("Open Responses Stream - Cleaned result", "cleanedResult", cleanedResult)
var parsedToolCalls []functions.FuncCallResults
var textContent string
parsedToolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent := functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
// Try pre-parsed tool calls from C++ autoparser first
if deltaToolCalls := functions.ToolCallsFromChatDeltas(prediction.ChatDeltas); len(deltaToolCalls) > 0 {
xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls))
parsedToolCalls = deltaToolCalls
textContent = functions.ContentFromChatDeltas(prediction.ChatDeltas)
} else {
xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing")
cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig)
parsedToolCalls = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
}
// Handle noAction function (model chose to respond without tool)
noActionName := "answer"

View File

@@ -0,0 +1,107 @@
package functions
import (
"strings"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
)
// ToolCallsFromChatDeltas extracts tool calls from C++ autoparser chat deltas.
// Returns nil if no tool calls are present in the deltas.
func ToolCallsFromChatDeltas(deltas []*pb.ChatDelta) []FuncCallResults {
	if len(deltas) == 0 {
		xlog.Debug("[ChatDeltas] no chat deltas received from backend")
		return nil
	}
	// Tally what kinds of chunks arrived, purely for diagnostics.
	var contentChunks, reasoningChunks, toolCallChunks int
	for _, delta := range deltas {
		if delta.Content != "" {
			contentChunks++
		}
		if delta.ReasoningContent != "" {
			reasoningChunks++
		}
		toolCallChunks += len(delta.ToolCalls)
	}
	xlog.Debug("[ChatDeltas] received deltas from backend",
		"total_deltas", len(deltas),
		"content_chunks", contentChunks,
		"reasoning_chunks", reasoningChunks,
		"tool_call_chunks", toolCallChunks,
	)
	// Merge incremental chunks by tool-call index: name and ID are taken from
	// whichever chunk sets them, argument text is concatenated in delta order.
	type accumulated struct {
		Name      string
		Arguments string
		ID        string
	}
	merged := map[int32]*accumulated{}
	highest := int32(-1)
	for _, delta := range deltas {
		for _, chunk := range delta.ToolCalls {
			slot := merged[chunk.Index]
			if slot == nil {
				slot = &accumulated{}
				merged[chunk.Index] = slot
			}
			if chunk.Name != "" {
				slot.Name = chunk.Name
			}
			if chunk.Id != "" {
				slot.ID = chunk.Id
			}
			slot.Arguments += chunk.Arguments
			if chunk.Index > highest {
				highest = chunk.Index
			}
		}
	}
	if len(merged) == 0 {
		xlog.Debug("[ChatDeltas] deltas present but no tool calls found, falling back to text parsing")
		return nil
	}
	// Emit results in ascending index order, skipping any index gaps.
	results := make([]FuncCallResults, 0, len(merged))
	for idx := int32(0); idx <= highest; idx++ {
		slot, present := merged[idx]
		if !present {
			continue
		}
		xlog.Debug("[ChatDeltas] extracted tool call",
			"index", idx,
			"name", slot.Name,
			"id", slot.ID,
			"args_length", len(slot.Arguments),
		)
		results = append(results, FuncCallResults{
			Name:      slot.Name,
			Arguments: slot.Arguments,
			ID:        slot.ID,
		})
	}
	xlog.Debug("[ChatDeltas] using C++ autoparser tool calls, skipping Go-side parsing", "count", len(results))
	return results
}
// ContentFromChatDeltas extracts accumulated content text from chat deltas.
func ContentFromChatDeltas(deltas []*pb.ChatDelta) string {
	parts := make([]string, 0, len(deltas))
	for _, delta := range deltas {
		parts = append(parts, delta.Content)
	}
	return strings.Join(parts, "")
}
// ReasoningFromChatDeltas extracts accumulated reasoning text from chat deltas.
func ReasoningFromChatDeltas(deltas []*pb.ChatDelta) string {
	parts := make([]string, 0, len(deltas))
	for _, delta := range deltas {
		parts = append(parts, delta.ReasoningContent)
	}
	return strings.Join(parts, "")
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -378,12 +378,23 @@ roses are red
</parameter>
</function>`
results, err := ParseXML(input, nil)
Expect(err).NotTo(HaveOccurred())
// Use PEG parser with a custom format that has no scope and tagged params
config := FunctionsConfig{
XMLFormat: &XMLToolCallFormat{
ToolStart: "<function=",
ToolSep: ">",
ToolEnd: "</function>",
KeyStart: "<parameter=",
KeyValSep: ">",
ValEnd: "</parameter>",
TrimRawArgVal: true,
},
}
results := ParseFunctionCall(input, config)
Expect(results).To(HaveLen(1))
Expect(results[0].Name).To(Equal("add"))
// JSON parsing converts numeric strings to numbers (matching llama.cpp behavior)
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
Expect(results[0].Arguments).To(ContainSubstring(`"x"`))
Expect(results[0].Arguments).To(ContainSubstring(`"y"`))
})
It("should parse XML tool call with multiple parameters", func() {
@@ -667,19 +678,18 @@ functions.search:0<|tool_call_argument_begin|>{"query": "test", "limit": 10}<|to
})
It("should support partial parsing for streaming", func() {
// Partial XML that ends mid-tag should be detected as partial
// Partial XML that ends mid-tag should be detected
input := `<tool_call>
<function=test>
<parameter=key>
value
</parameter>`
partialResult, err := ParseXMLPartial(input, nil)
// ParseXMLIterative with isPartial=true handles streaming
results, err := ParseXMLIterative(input, nil, true)
Expect(err).NotTo(HaveOccurred())
Expect(partialResult).NotTo(BeNil())
// Should detect partial content
Expect(partialResult).NotTo(BeNil())
Expect(partialResult.IsPartial).To(BeTrue())
// Should return partial results (may have 0 complete tool calls since function is not closed)
_ = results
})
It("should parse JSON values correctly in all formats", func() {

136
pkg/functions/peg/arena.go Normal file
View File

@@ -0,0 +1,136 @@
package peg
import "fmt"
// Arena stores parser instances and provides the Parse entry point.
type Arena struct {
parsers []Parser
rules map[string]ParserID
root ParserID
}
func NewArena() *Arena {
return &Arena{
rules: make(map[string]ParserID),
root: InvalidParserID,
}
}
func (a *Arena) addParser(p Parser) ParserID {
id := ParserID(len(a.parsers))
a.parsers = append(a.parsers, p)
return id
}
func (a *Arena) Get(id ParserID) Parser {
return a.parsers[id]
}
func (a *Arena) Root() ParserID {
return a.root
}
func (a *Arena) SetRoot(id ParserID) {
a.root = id
}
func (a *Arena) GetRule(name string) ParserID {
id, ok := a.rules[name]
if !ok {
panic(fmt.Sprintf("Rule not found: %s", name))
}
return id
}
func (a *Arena) HasRule(name string) bool {
_, ok := a.rules[name]
return ok
}
// Parse parses from the root parser.
func (a *Arena) Parse(ctx *ParseContext) ParseResult {
	return a.ParseFrom(ctx, 0)
}

// ParseFrom parses from the root parser starting at position start.
// Panics if no root parser has been set.
func (a *Arena) ParseFrom(ctx *ParseContext, start int) ParseResult {
	if a.root == InvalidParserID {
		panic("No root parser set")
	}
	return a.ParseAt(a.root, ctx, start)
}

// ParseAt parses using a specific parser at a given position.
func (a *Arena) ParseAt(id ParserID, ctx *ParseContext, start int) ParseResult {
	return a.parsers[id].parse(a, ctx, start)
}
// ParseAnywhere tries parsing from every position in the input until it
// succeeds. On failure everywhere, the result of the attempt at the final
// position is returned. Panics if no root parser has been set.
func (a *Arena) ParseAnywhere(ctx *ParseContext) ParseResult {
	if a.root == InvalidParserID {
		panic("No root parser set")
	}
	n := len(ctx.Input)
	if n == 0 {
		return a.ParseAt(a.root, ctx, 0)
	}
	for pos := 0; ; pos++ {
		result := a.ParseAt(a.root, ctx, pos)
		// Stop on the first success, or hand back the last attempt's
		// result once every start position has been tried.
		if result.Type == Success || pos == n-1 {
			return result
		}
	}
}
// resolveRefs walks all parsers and replaces refs with resolved rule IDs.
//
// RefParser nodes are forward references created by Ref/LazyRule/Rule; once
// every rule has been registered, each child/children slot that points at a
// RefParser is rewritten to the concrete rule's ID so the parse loop never
// has to chase references at runtime. The root is resolved the same way.
func (a *Arena) resolveRefs() {
	// Range over values only: the index is not needed (the original code
	// carried an unused `i` kept alive with `_ = i`).
	for _, p := range a.parsers {
		switch pt := p.(type) {
		case *SequenceParser:
			for j, child := range pt.Children {
				pt.Children[j] = a.resolveRef(child)
			}
		case *ChoiceParser:
			for j, child := range pt.Children {
				pt.Children[j] = a.resolveRef(child)
			}
		case *RepetitionParser:
			pt.Child = a.resolveRef(pt.Child)
		case *AndParser:
			pt.Child = a.resolveRef(pt.Child)
		case *NotParser:
			pt.Child = a.resolveRef(pt.Child)
		case *RuleParser:
			pt.Child = a.resolveRef(pt.Child)
		case *TagParser:
			pt.Child = a.resolveRef(pt.Child)
		case *AtomicParser:
			pt.Child = a.resolveRef(pt.Child)
		case *SchemaParser:
			pt.Child = a.resolveRef(pt.Child)
		case *EpsilonParser, *StartParser, *EndParser, *LiteralParser,
			*AnyParser, *SpaceParser, *CharsParser, *JSONStringParser,
			*PythonDictStringParser, *UntilParser, *RefParser, *JSONParser,
			*jsonNumberParser:
			// Leaf parsers — no children to resolve.
		}
	}
	if a.root != InvalidParserID {
		a.root = a.resolveRef(a.root)
	}
}
// resolveRef maps a RefParser ID to the ID of the rule it names; any other
// parser ID passes through unchanged. Panics (via GetRule) on a dangling ref.
func (a *Arena) resolveRef(id ParserID) ParserID {
	ref, isRef := a.parsers[id].(*RefParser)
	if !isRef {
		return id
	}
	return a.GetRule(ref.Name)
}

View File

@@ -0,0 +1,435 @@
package peg
import "regexp"
// invalidRuleCharsRe matches runs of characters that are not legal in rule
// names; they are replaced with "-" when rules are registered.
var invalidRuleCharsRe = regexp.MustCompile(`[^a-zA-Z0-9-]+`)

// Builder provides a fluent API for constructing parsers.
// Every combinator allocates a parser node in the internal arena and returns
// its ParserID; Build() resolves references and hands the arena over.
type Builder struct {
	arena Arena
}

// NewBuilder returns an empty Builder with no rules and no root parser set.
func NewBuilder() *Builder {
	return &Builder{
		arena: Arena{
			rules: make(map[string]ParserID),
			root:  InvalidParserID,
		},
	}
}

// add stores p in the arena and returns its newly assigned ID.
func (b *Builder) add(p Parser) ParserID {
	return b.add(p)
}

// Eps matches nothing, always succeeds.
func (b *Builder) Eps() ParserID {
	return b.add(&EpsilonParser{})
}

// Start matches start of input.
func (b *Builder) Start() ParserID {
	return b.add(&StartParser{})
}

// End matches end of input.
func (b *Builder) End() ParserID {
	return b.add(&EndParser{})
}

// Literal matches an exact string.
func (b *Builder) Literal(s string) ParserID {
	return b.add(&LiteralParser{Literal: s})
}
// Seq matches a sequence of parsers in order.
// Nested sequences are spliced inline, so Seq(Seq(a, b), c) builds the same
// node as Seq(a, b, c).
func (b *Builder) Seq(children ...ParserID) ParserID {
	flattened := b.flattenChildren(children, func(p Parser) ([]ParserID, bool) {
		seq, ok := p.(*SequenceParser)
		if !ok {
			return nil, false
		}
		return seq.Children, true
	})
	return b.add(&SequenceParser{Children: flattened})
}

// Choice tries alternatives until one succeeds.
// Nested choices are spliced inline, so Choice(Choice(a, b), c) builds the
// same node as Choice(a, b, c).
func (b *Builder) Choice(children ...ParserID) ParserID {
	flattened := b.flattenChildren(children, func(p Parser) ([]ParserID, bool) {
		ch, ok := p.(*ChoiceParser)
		if !ok {
			return nil, false
		}
		return ch.Children, true
	})
	return b.add(&ChoiceParser{Children: flattened})
}

// flattenChildren expands any child whose parser matches unwrap into its
// grandchildren, keeping combinator trees one level deep. Shared by Seq and
// Choice, which previously duplicated this loop.
func (b *Builder) flattenChildren(children []ParserID, unwrap func(Parser) ([]ParserID, bool)) []ParserID {
	var flattened []ParserID
	for _, id := range children {
		if grandchildren, ok := unwrap(b.arena.parsers[id]); ok {
			flattened = append(flattened, grandchildren...)
		} else {
			flattened = append(flattened, id)
		}
	}
	return flattened
}
// Optional matches zero or one occurrence.
func (b *Builder) Optional(child ParserID) ParserID {
	return b.Repeat(child, 0, 1)
}

// ZeroOrMore matches zero or more occurrences.
func (b *Builder) ZeroOrMore(child ParserID) ParserID {
	return b.Repeat(child, 0, -1)
}

// OneOrMore matches one or more occurrences.
func (b *Builder) OneOrMore(child ParserID) ParserID {
	return b.Repeat(child, 1, -1)
}

// Repeat matches between min and max times. Use -1 for unbounded max.
// Optional/ZeroOrMore/OneOrMore are shorthands for common bounds.
func (b *Builder) Repeat(child ParserID, min, max int) ParserID {
	return b.add(&RepetitionParser{Child: child, MinCount: min, MaxCount: max})
}

// Peek is a positive lookahead — succeeds if child succeeds, consumes nothing.
func (b *Builder) Peek(child ParserID) ParserID {
	return b.add(&AndParser{Child: child})
}

// Negate is a negative lookahead — succeeds if child fails, consumes nothing.
func (b *Builder) Negate(child ParserID) ParserID {
	return b.add(&NotParser{Child: child})
}

// Any matches a single UTF-8 codepoint.
func (b *Builder) Any() ParserID {
	return b.add(&AnyParser{})
}

// Space matches zero or more whitespace characters.
func (b *Builder) Space() ParserID {
	return b.add(&SpaceParser{})
}

// Chars matches characters from a character class expression like "[a-z]",
// repeated between min and max times. The class string is parsed into rune
// ranges (see parseCharClasses); a leading '^' negates the class.
func (b *Builder) Chars(classes string, min, max int) ParserID {
	ranges, negated := parseCharClasses(classes)
	return b.add(&CharsParser{
		Pattern: classes,
		Ranges: ranges,
		Negated: negated,
		MinCount: min,
		MaxCount: max,
	})
}

// Until matches all characters until a delimiter is found (not consumed).
func (b *Builder) Until(delimiter string) ParserID {
	return b.add(&UntilParser{Delimiters: []string{delimiter}})
}

// UntilOneOf matches until any of the delimiters is found.
func (b *Builder) UntilOneOf(delimiters ...string) ParserID {
	return b.add(&UntilParser{Delimiters: delimiters})
}

// Rest matches everything to end of input (an UntilParser with no delimiters).
func (b *Builder) Rest() ParserID {
	return b.add(&UntilParser{Delimiters: nil})
}

// JSONString matches JSON string content (without surrounding quotes).
func (b *Builder) JSONString() ParserID {
	return b.add(&JSONStringParser{})
}

// JSON matches a complete JSON value.
func (b *Builder) JSON() ParserID {
	return b.add(&JSONParser{})
}

// JSONNumber matches a JSON number.
func (b *Builder) JSONNumber() ParserID {
	// We implement this as a dedicated parser entry that delegates to parseJSONNumber
	return b.add(&jsonNumberParser{})
}

// PythonDictString matches single-quoted string content (without quotes).
func (b *Builder) PythonDictString() ParserID {
	return b.add(&PythonDictStringParser{})
}
// DoubleQuotedString matches a double-quoted string: "content" + space.
// Registered as a named rule so repeated calls share one parser node.
func (b *Builder) DoubleQuotedString() ParserID {
	return b.LazyRule("dq-string", func() ParserID {
		return b.Seq(b.Literal(`"`), b.JSONString(), b.Literal(`"`), b.Space())
	})
}

// SingleQuotedString matches a single-quoted string: 'content' + space
func (b *Builder) SingleQuotedString() ParserID {
	return b.LazyRule("sq-string", func() ParserID {
		return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"), b.Space())
	})
}

// FlexibleString matches either a double or single-quoted string.
func (b *Builder) FlexibleString() ParserID {
	return b.LazyRule("flexible-string", func() ParserID {
		return b.Choice(b.DoubleQuotedString(), b.SingleQuotedString())
	})
}

// Marker matches <...> or [...] delimited text.
func (b *Builder) Marker() ParserID {
	return b.Choice(
		b.Seq(b.Literal("<"), b.Until(">"), b.Literal(">")),
		b.Seq(b.Literal("["), b.Until("]"), b.Literal("]")),
	)
}

// PythonValue matches a Python-style value (dict, array, string, number, bool, None).
// Defined via LazyRule because dict/array values recurse back into PythonValue.
func (b *Builder) PythonValue() ParserID {
	return b.LazyRule("python-value", func() ParserID {
		return b.Choice(
			b.PythonDict(), b.PythonArray(), b.PythonString(),
			b.JSONNumber(), b.PythonBool(), b.PythonNull(),
		)
	})
}

// PythonString matches a Python string (double or single-quoted).
func (b *Builder) PythonString() ParserID {
	return b.LazyRule("python-string", func() ParserID {
		return b.Choice(b.DoubleQuotedString(), b.SingleQuotedString())
	})
}

// PythonBool matches True or False.
func (b *Builder) PythonBool() ParserID {
	return b.LazyRule("python-bool", func() ParserID {
		return b.Seq(b.Choice(b.Literal("True"), b.Literal("False")), b.Space())
	})
}

// PythonNull matches None.
func (b *Builder) PythonNull() ParserID {
	return b.LazyRule("python-none", func() ParserID {
		return b.Seq(b.Literal("None"), b.Space())
	})
}

// PythonDict matches a Python dictionary {key: value, ...}.
// Members are string keys followed by ':' and a PythonValue; a trailing
// comma is not accepted.
func (b *Builder) PythonDict() ParserID {
	return b.LazyRule("python-dict", func() ParserID {
		member := b.Seq(b.PythonString(), b.Space(), b.Literal(":"), b.Space(), b.PythonValue())
		return b.Seq(
			b.Literal("{"), b.Space(),
			b.Optional(b.Seq(member, b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), member)))),
			b.Space(), b.Literal("}"), b.Space(),
		)
	})
}

// PythonArray matches a Python array [value, ...].
func (b *Builder) PythonArray() ParserID {
	return b.LazyRule("python-array", func() ParserID {
		return b.Seq(
			b.Literal("["), b.Space(),
			b.Optional(b.Seq(b.PythonValue(), b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), b.PythonValue())))),
			b.Space(), b.Literal("]"), b.Space(),
		)
	})
}
// LazyRule creates a named rule with deferred construction to support recursion.
// If the rule already exists, returns a ref to it. Otherwise, creates a placeholder,
// builds the child, and replaces the placeholder.
func (b *Builder) LazyRule(name string, builderFn func() ParserID) ParserID {
	// Rule names are sanitized to [a-zA-Z0-9-] before lookup/registration.
	cleanName := invalidRuleCharsRe.ReplaceAllString(name, "-")
	if _, exists := b.arena.rules[cleanName]; exists {
		return b.add(&RefParser{Name: cleanName})
	}
	// Create placeholder rule to allow recursive references: if builderFn
	// re-enters LazyRule with the same name it hits the early return above.
	placeholderChild := b.add(&AnyParser{})
	ruleID := b.add(&RuleParser{Name: cleanName, Child: placeholderChild})
	b.arena.rules[cleanName] = ruleID
	// Build the actual parser
	child := builderFn()
	// Update the rule with the real child. The placeholder AnyParser remains
	// in the arena but is no longer referenced by this rule.
	b.arena.parsers[ruleID] = &RuleParser{Name: cleanName, Child: child}
	return b.add(&RefParser{Name: cleanName})
}
// Rule creates a named rule and returns a ref to it.
func (b *Builder) Rule(name string, child ParserID) ParserID {
	return b.registerRule(name, child, false)
}

// TriggerRule creates a named rule marked as a trigger (for lazy grammar generation).
func (b *Builder) TriggerRule(name string, child ParserID) ParserID {
	return b.registerRule(name, child, true)
}

// registerRule sanitizes the name, registers child as a named rule with the
// given trigger flag, and returns a forward reference to it. Shared by Rule
// and TriggerRule, which previously duplicated this logic.
func (b *Builder) registerRule(name string, child ParserID, trigger bool) ParserID {
	cleanName := invalidRuleCharsRe.ReplaceAllString(name, "-")
	ruleID := b.add(&RuleParser{Name: cleanName, Child: child, Trigger: trigger})
	b.arena.rules[cleanName] = ruleID
	return b.add(&RefParser{Name: cleanName})
}

// Ref creates a forward reference to a named rule.
// The reference is resolved to the concrete rule ID during Build().
func (b *Builder) Ref(name string) ParserID {
	return b.add(&RefParser{Name: name})
}
// Atomic creates a parser that suppresses partial AST nodes.
func (b *Builder) Atomic(child ParserID) ParserID {
	return b.add(&AtomicParser{Child: child})
}

// Tag creates a semantic tag in the AST.
// The tag string labels the matched span for downstream consumers.
func (b *Builder) Tag(tag string, child ParserID) ParserID {
	return b.add(&TagParser{Child: child, Tag: tag})
}

// Schema wraps a parser with schema metadata (pass-through at parse time).
func (b *Builder) Schema(child ParserID, name string) ParserID {
	return b.add(&SchemaParser{Child: child, Name: name})
}

// SetRoot sets the root parser used as the entry point by the built arena.
func (b *Builder) SetRoot(id ParserID) {
	b.arena.root = id
}
// Build resolves references and returns the arena.
// The builder is reset afterwards, so it can be reused for a fresh grammar.
func (b *Builder) Build() *Arena {
	b.arena.resolveRefs()
	built := b.arena
	// Reset to a pristine state, mirroring NewBuilder.
	b.arena = Arena{
		rules: make(map[string]ParserID),
		root:  InvalidParserID,
	}
	return &built
}
// parseCharClasses parses a character class expression and returns ranges and
// negation. Accepts forms like "[a-z0-9]" or "[^x]"; the surrounding brackets
// and the leading '^' are optional. A lone character yields a single-rune
// range (Start == End).
func parseCharClasses(classes string) ([]CharRange, bool) {
	body := classes
	// Strip the optional surrounding brackets.
	if body != "" && body[0] == '[' {
		body = body[1:]
	}
	if n := len(body); n > 0 && body[n-1] == ']' {
		body = body[:n-1]
	}
	// A leading '^' negates the class.
	negated := body != "" && body[0] == '^'
	if negated {
		body = body[1:]
	}
	var ranges []CharRange
	for pos := 0; pos < len(body); {
		lo, consumed := ParseCharClassChar(body, pos)
		pos += consumed
		hi := lo
		// "a-z" style span: a '-' with at least one character after it.
		if pos+1 < len(body) && body[pos] == '-' {
			var hiLen int
			hi, hiLen = ParseCharClassChar(body, pos+1)
			pos += 1 + hiLen
		}
		ranges = append(ranges, CharRange{Start: lo, End: hi})
	}
	return ranges, negated
}
// ParseCharClassChar decodes one (possibly escaped) character of a character
// class starting at pos, returning the rune and the number of bytes consumed.
// Supported escapes: \n \t \r \\ \] \[ plus hex escapes \xHH, \uHHHH and
// \UHHHHHHHH; any other escaped byte is taken literally. A backslash at the
// very end of the input is returned as a literal backslash.
func ParseCharClassChar(content string, pos int) (rune, int) {
	if content[pos] != '\\' || pos+1 >= len(content) {
		// Not an escape (note: reads a single byte, not a full codepoint).
		return rune(content[pos]), 1
	}
	esc := content[pos+1]
	switch esc {
	case 'n':
		return '\n', 2
	case 't':
		return '\t', 2
	case 'r':
		return '\r', 2
	case '\\', ']', '[':
		return rune(esc), 2
	case 'x':
		if r, n := parseHexEscape(content, pos+2, 2); n > 0 {
			return r, 2 + n
		}
		// Malformed hex digits: the escape letter itself is taken literally.
		return 'x', 2
	case 'u':
		if r, n := parseHexEscape(content, pos+2, 4); n > 0 {
			return r, 2 + n
		}
		return 'u', 2
	case 'U':
		if r, n := parseHexEscape(content, pos+2, 8); n > 0 {
			return r, 2 + n
		}
		return 'U', 2
	default:
		return rune(esc), 2
	}
}

// parseHexEscape reads exactly count hex digits at pos and returns the decoded
// rune and count; returns (0, 0) if the input is too short or a digit is invalid.
func parseHexEscape(s string, pos, count int) (rune, int) {
	if pos+count > len(s) {
		return 0, 0
	}
	var value rune
	for _, c := range []byte(s[pos : pos+count]) {
		var digit rune
		switch {
		case '0' <= c && c <= '9':
			digit = rune(c - '0')
		case 'a' <= c && c <= 'f':
			digit = rune(c-'a') + 10
		case 'A' <= c && c <= 'F':
			digit = rune(c-'A') + 10
		default:
			return 0, 0
		}
		value = value<<4 | digit
	}
	return value, count
}
// jsonNumberParser is a dedicated parser for JSON numbers used by JSONNumber().
type jsonNumberParser struct{}

// parse accepts inputs whose next byte is '-' or a digit and delegates the
// full number grammar to parseJSONNumber; anything else fails. At end of
// input it reports NeedMoreInput when parsing partial (streaming) text.
func (p *jsonNumberParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	if start >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, start)
		}
		return NewParseResult(Fail, start)
	}
	c := ctx.Input[start]
	if c != '-' && (c < '0' || c > '9') {
		return NewParseResult(Fail, start)
	}
	return parseJSONNumber(ctx, start, start)
}
// BuildPegParser is a helper that creates a parser using a builder function.
// fn receives a fresh Builder and returns the root parser ID.
func BuildPegParser(fn func(b *Builder) ParserID) *Arena {
	builder := NewBuilder()
	builder.SetRoot(fn(builder))
	return builder.Build()
}

954
pkg/functions/peg/chat.go Normal file
View File

@@ -0,0 +1,954 @@
package peg
import "encoding/json"
// Tag constants matching llama.cpp.
// These label AST spans produced by the chat grammars so downstream code can
// route matched text into content, reasoning, or tool-call fields.
const (
	TagReasoningBlock = "reasoning-block"
	TagReasoning      = "reasoning"
	TagContent        = "content"
	TagTool           = "tool"
	TagToolOpen       = "tool-open"
	TagToolClose      = "tool-close"
	TagToolID         = "tool-id"
	TagToolName       = "tool-name"
	TagToolArgs       = "tool-args"
	TagToolArg        = "tool-arg"
	TagToolArgOpen    = "tool-arg-open"
	TagToolArgClose   = "tool-arg-close"
	TagToolArgName    = "tool-arg-name"
	TagToolArgValue   = "tool-arg-value"
	TagToolArgStrVal  = "tool-arg-string-value"
)

// ChatBuilder extends Builder with chat-specific tag helpers.
type ChatBuilder struct {
	*Builder
}

// NewChatBuilder returns a ChatBuilder wrapping a fresh Builder.
func NewChatBuilder() *ChatBuilder {
	return &ChatBuilder{Builder: NewBuilder()}
}
// Semantic tag wrappers: each helper tags its child with the corresponding
// Tag* constant. The marker-like pieces (open/close delimiters, IDs, names)
// are additionally wrapped in Atomic, which suppresses partial AST nodes.
func (cb *ChatBuilder) ReasoningBlock(child ParserID) ParserID {
	return cb.Tag(TagReasoningBlock, child)
}

func (cb *ChatBuilder) Reasoning(child ParserID) ParserID {
	return cb.Tag(TagReasoning, child)
}

func (cb *ChatBuilder) Content(child ParserID) ParserID {
	return cb.Tag(TagContent, child)
}

func (cb *ChatBuilder) Tool(child ParserID) ParserID {
	return cb.Tag(TagTool, child)
}

func (cb *ChatBuilder) ToolOpen(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolOpen, child))
}

func (cb *ChatBuilder) ToolClose(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolClose, child))
}

func (cb *ChatBuilder) ToolID(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolID, child))
}

func (cb *ChatBuilder) ToolName(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolName, child))
}

func (cb *ChatBuilder) ToolArgs(child ParserID) ParserID {
	return cb.Tag(TagToolArgs, child)
}

func (cb *ChatBuilder) ToolArg(child ParserID) ParserID {
	return cb.Tag(TagToolArg, child)
}

func (cb *ChatBuilder) ToolArgOpen(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgOpen, child))
}

func (cb *ChatBuilder) ToolArgClose(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgClose, child))
}

func (cb *ChatBuilder) ToolArgName(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgName, child))
}

func (cb *ChatBuilder) ToolArgValue(child ParserID) ParserID {
	return cb.Tag(TagToolArgValue, child)
}

func (cb *ChatBuilder) ToolArgStringValue(child ParserID) ParserID {
	return cb.Tag(TagToolArgStrVal, child)
}

// ToolArgJSONValue tags a JSON argument value; Atomic because JSON values
// should not surface as partial AST nodes while streaming.
func (cb *ChatBuilder) ToolArgJSONValue(child ParserID) ParserID {
	return cb.Atomic(cb.Tag(TagToolArgValue, child))
}
// TagWithSafeContent creates content parsing that avoids a marker string.
// It repeatedly tries p first; when p does not match, a chunk of plain content
// is consumed instead. With a non-empty marker, a content chunk must not start
// with the marker: it consumes one character and then everything up to the
// next marker occurrence, so p gets another chance exactly at the marker.
func (cb *ChatBuilder) TagWithSafeContent(tagName, marker string, p ParserID) ParserID {
	if marker == "" {
		// No marker to avoid: consume content one character at a time.
		return cb.ZeroOrMore(cb.Choice(p,
			cb.Rule(tagName, cb.Content(cb.Any())),
		))
	}
	contentChunk := cb.Rule(tagName,
		cb.Content(cb.Seq(
			cb.Negate(cb.Literal(marker)), // chunk must not begin with the marker
			cb.Any(),                      // consume at least one character
			cb.Until(marker),              // then run up to (not including) the marker
		)),
	)
	return cb.ZeroOrMore(cb.Choice(p, contentChunk))
}
// ToolDef holds a tool definition used to build parsers.
type ToolDef struct {
	// Name is the tool/function name the grammar will accept.
	Name string
	// Properties maps argument names to their definitions.
	Properties map[string]PropDef
}

// PropDef holds a property definition for tool arguments.
type PropDef struct {
	// Type of the property value (semantics depend on the consuming builder).
	Type string
}
// StandardConstructedTools builds XML/tagged-style tool parsers.
//
// The wire format is customized through the markers map (keys such as
// "tool_call_start_marker", "function_opener", "parameter_key_prefix",
// "call_id_prefix", ...); any missing key falls back to the
// <tool_call>/<function=...>/<param=...> defaults resolved below. An empty
// "parameter_key_prefix" switches argument parsing from tagged
// <param=key>value</param> pairs to a raw JSON blob. When tools is empty a
// generic parser accepting any function name is produced; otherwise one
// alternative per tool is built. parallelToolCalls allows more than one
// call inside a single section; forceToolCalls makes the section mandatory
// instead of optional.
func (cb *ChatBuilder) StandardConstructedTools(
	markers map[string]string,
	tools []ToolDef,
	parallelToolCalls bool,
	forceToolCalls bool,
) ParserID {
	// getMarker resolves a marker by key, falling back to defaultVal.
	getMarker := func(key, defaultVal string) string {
		if v, ok := markers[key]; ok {
			return v
		}
		return defaultVal
	}
	sectionStart := getMarker("tool_call_start_marker", "<tool_call>")
	sectionEnd := getMarker("tool_call_end_marker", "</tool_call>")
	funcOpener := getMarker("function_opener", "<function=")
	funcNameSuffix := getMarker("function_name_suffix", ">")
	funcCloser := getMarker("function_closer", "</function>")
	paramKeyPrefix := getMarker("parameter_key_prefix", "<param=")
	paramKeySuffix := getMarker("parameter_key_suffix", ">")
	paramCloser := getMarker("parameter_closer", "</param>")
	callIDPrefix := getMarker("call_id_prefix", "")
	callIDSuffix := getMarker("call_id_suffix", "")
	// Empty prefix selects JSON-style arguments instead of tagged params.
	hasTaggedParams := paramKeyPrefix != ""
	var toolChoices []ParserID
	if len(tools) == 0 {
		// Generic parser: accept any function name
		var args ParserID
		if hasTaggedParams {
			// Tagged parameters: <param=key>value</param>
			argRule := cb.ToolArg(cb.Seq(
				cb.ToolArgOpen(cb.Literal(paramKeyPrefix)),
				cb.ToolArgName(cb.Until(paramKeySuffix)),
				cb.Literal(paramKeySuffix),
				cb.ToolArgValue(cb.Until(paramCloser)),
				cb.ToolArgClose(cb.Literal(paramCloser)),
			))
			args = cb.ToolArgs(cb.ZeroOrMore(cb.Seq(argRule, cb.Space())))
		} else {
			// JSON arguments: {"key": "val"}
			args = cb.ToolArgs(cb.Until(funcCloser))
		}
		// Build optional call ID section (between function name and args)
		callIDSection := cb.Eps()
		if callIDPrefix != "" && callIDSuffix != "" {
			callIDSection = cb.Optional(cb.Seq(
				cb.Literal(callIDPrefix),
				cb.ToolID(cb.Until(callIDSuffix)),
				cb.Literal(callIDSuffix),
			))
		}
		toolParser := cb.Tool(cb.Seq(
			cb.ToolOpen(cb.Seq(
				cb.Literal(funcOpener),
				cb.ToolName(cb.Until(funcNameSuffix)),
				cb.Literal(funcNameSuffix),
			)),
			callIDSection,
			cb.Space(),
			args,
			cb.Space(),
			cb.ToolClose(cb.Literal(funcCloser)),
		))
		toolChoices = append(toolChoices, cb.Rule("tool-generic", toolParser))
	} else {
		for _, tool := range tools {
			// Build argument parsers
			args := cb.Eps()
			if hasTaggedParams && len(tool.Properties) > 0 {
				var argParsers []ParserID
				for propName := range tool.Properties {
					// Accept bare, double-quoted and single-quoted key forms.
					argNameParser := cb.Choice(
						cb.Literal(propName),
						cb.Literal("\""+propName+"\""),
						cb.Literal("'"+propName+"'"),
					)
					argRule := cb.ToolArg(cb.Seq(
						cb.ToolArgOpen(cb.Literal(paramKeyPrefix)),
						cb.ToolArgName(argNameParser),
						cb.Literal(paramKeySuffix),
						cb.ToolArgValue(cb.Until(paramCloser)),
						cb.ToolArgClose(cb.Literal(paramCloser)),
					))
					argParsers = append(argParsers, argRule)
				}
				argChoice := cb.Choice(argParsers...)
				args = cb.ZeroOrMore(cb.Seq(argChoice, cb.Space()))
			} else if !hasTaggedParams {
				// JSON arguments
				args = cb.Until(funcCloser)
			}
			// Build optional call ID section
			toolCallIDSection := cb.Eps()
			if callIDPrefix != "" && callIDSuffix != "" {
				toolCallIDSection = cb.Optional(cb.Seq(
					cb.Literal(callIDPrefix),
					cb.ToolID(cb.Until(callIDSuffix)),
					cb.Literal(callIDSuffix),
				))
			}
			// Build function parser
			toolParser := cb.Tool(cb.Seq(
				cb.ToolOpen(cb.Seq(
					cb.Literal(funcOpener),
					cb.ToolName(cb.Literal(tool.Name)),
					cb.Literal(funcNameSuffix),
				)),
				toolCallIDSection,
				cb.Space(),
				cb.ToolArgs(args),
				cb.Space(),
				cb.ToolClose(cb.Literal(funcCloser)),
			))
			toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, toolParser))
		}
	}
	toolChoice := cb.Choice(toolChoices...)
	var section ParserID
	if parallelToolCalls {
		section = cb.TriggerRule("tool-call", cb.Seq(
			cb.Literal(sectionStart), cb.Space(),
			cb.OneOrMore(cb.Seq(toolChoice, cb.Space())),
			cb.Literal(sectionEnd),
		))
	} else {
		section = cb.TriggerRule("tool-call", cb.Seq(
			cb.Literal(sectionStart), cb.Space(),
			toolChoice, cb.Space(),
			cb.Literal(sectionEnd),
		))
	}
	if forceToolCalls {
		return section
	}
	return cb.Optional(section)
}
// StandardJSONToolsOpts holds options for building JSON tool call parsers.
type StandardJSONToolsOpts struct {
	// SectionStart / SectionEnd delimit the tool-call section in the output.
	SectionStart string
	SectionEnd   string
	// Tools lists the callable tools; an empty list yields an Eps parser.
	Tools []ToolDef
	// ParallelCalls allows several comma-separated calls in one section.
	ParallelCalls bool
	// ForceToolCalls makes the section mandatory instead of optional.
	ForceToolCalls bool
	// NameKey / ArgsKey override the JSON keys holding the function name and
	// arguments (defaults "name" / "arguments"); dot notation such as
	// "function.name" selects the nested-key layout.
	NameKey string
	ArgsKey string
	// ArrayWrapped wraps the call list in "[" ... "]".
	ArrayWrapped bool
	// FunctionIsKey selects the {"tool_name": {...}} layout where the
	// function name itself is the JSON key.
	FunctionIsKey bool
	// CallIDKey / GenCallIDKey name optional call-ID fields to capture.
	CallIDKey    string
	GenCallIDKey string
	// ParametersOrder, when non-empty, fixes the expected field order for
	// the flat-key layout.
	ParametersOrder []string
}
// StandardJSONTools builds JSON-format tool call parsers.
//
// The returned parser matches a section delimited by opts.SectionStart and
// opts.SectionEnd containing a single tool-call object — or several, comma
// separated, when opts.ParallelCalls is set, optionally wrapped in [...]
// when opts.ArrayWrapped is set. Unless opts.ForceToolCalls is set, the
// entire section is optional. With no configured tools an Eps parser is
// returned.
func (cb *ChatBuilder) StandardJSONTools(opts StandardJSONToolsOpts) ParserID {
	if len(opts.Tools) == 0 {
		return cb.Eps()
	}
	// Fall back to the conventional OpenAI key names when unset.
	orDefault := func(v, def string) string {
		if v == "" {
			return def
		}
		return v
	}
	nameKey := orDefault(opts.NameKey, "name")
	argsKey := orDefault(opts.ArgsKey, "arguments")
	var singleCall ParserID
	if opts.FunctionIsKey {
		// {"tool_name": {...args...}} layout.
		singleCall = cb.buildJSONToolsFunctionIsKey(opts.Tools, opts.ArgsKey, argsKey, opts.CallIDKey, opts.GenCallIDKey)
	} else if parseKeySpec(nameKey).prefix != "" || parseKeySpec(argsKey).prefix != "" {
		// Dot-notation keys: name/args live inside a nested object.
		singleCall = cb.buildJSONToolsNestedKeys(opts.Tools, nameKey, argsKey, opts.CallIDKey, opts.GenCallIDKey)
	} else {
		// Flat {"name": ..., "arguments": ...} layout.
		singleCall = cb.buildJSONToolsFlatKeys(opts.Tools, nameKey, argsKey, opts.CallIDKey, opts.GenCallIDKey, opts.ParametersOrder)
	}
	calls := singleCall
	if opts.ParallelCalls {
		calls = cb.Seq(
			singleCall,
			cb.ZeroOrMore(cb.Seq(cb.Space(), cb.Literal(","), cb.Space(), singleCall)),
		)
	}
	if opts.ArrayWrapped {
		calls = cb.Seq(cb.Literal("["), cb.Space(), calls, cb.Space(), cb.Literal("]"))
	}
	section := cb.TriggerRule("tool-call", cb.Seq(
		cb.Literal(opts.SectionStart), cb.Space(),
		calls, cb.Space(),
		cb.Literal(opts.SectionEnd),
	))
	if !opts.ForceToolCalls {
		section = cb.Optional(section)
	}
	return section
}
// buildJSONToolsFunctionIsKey builds per-tool parsers for the layout where
// the function name itself is the JSON key:
//
//	{"tool_name": {...args...}}           (argsKey == "")
//	{"tool_name": {"arguments": {...}}}   (argsKey != "")
//
// Optional callIDKey / genCallIDKey fields may precede the arguments inside
// the inner object. effectiveArgsKey is argsKey with the default applied.
func (cb *ChatBuilder) buildJSONToolsFunctionIsKey(
	tools []ToolDef,
	argsKey, effectiveArgsKey, callIDKey, genCallIDKey string,
) ParserID {
	var toolChoices []ParserID
	for _, tool := range tools {
		var innerFields []ParserID
		if callIDKey != "" {
			// Optional "<callIDKey>": "<id>" field with optional trailing comma.
			idParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\""),
			))
			innerFields = append(innerFields, cb.Optional(cb.Seq(idParser, cb.Space(), cb.Optional(cb.Seq(cb.Literal(","), cb.Space())))))
		}
		if genCallIDKey != "" {
			// Generated call IDs may be quoted strings or bare numbers.
			genIDParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			innerFields = append(innerFields, cb.Optional(cb.Seq(genIDParser, cb.Space(), cb.Optional(cb.Seq(cb.Literal(","), cb.Space())))))
		}
		// Arguments
		var argsParser ParserID
		if argsKey == "" {
			// The inner object is the arguments object itself.
			argsParser = cb.ToolArgs(cb.JSON())
		} else {
			argsParser = cb.Seq(
				cb.Literal("\""+effectiveArgsKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.ToolArgs(cb.JSON()),
			)
		}
		innerFields = append(innerFields, argsParser)
		// Build inner object
		var innerObject ParserID
		if argsKey == "" && len(innerFields) == 1 {
			// Bare arguments object: no extra wrapping braces.
			innerObject = innerFields[0]
		} else {
			innerObject = cb.Literal("{")
			for i, f := range innerFields {
				innerObject = cb.Seq(innerObject, cb.Space(), f)
				if i < len(innerFields)-1 {
					innerObject = cb.Seq(innerObject, cb.Space())
				}
			}
			innerObject = cb.Seq(innerObject, cb.Space(), cb.Literal("}"))
		}
		toolParser := cb.Tool(cb.Seq(
			cb.ToolOpen(cb.Literal("{")), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
			cb.Space(), cb.Literal(":"), cb.Space(),
			innerObject,
			cb.Space(), cb.ToolClose(cb.Literal("}")),
		))
		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, toolParser))
	}
	return cb.Choice(toolChoices...)
}
// keySpec represents a dot-notation key split into prefix and field.
type keySpec struct {
	// prefix is the part before the first '.', or "" for a flat key.
	prefix string
	// field is the part after the first '.', or the whole key when flat.
	field string
}
// parseKeySpec splits a dot-notation key at its first '.' into a prefix
// and field; a key without a dot yields an empty prefix.
func parseKeySpec(key string) keySpec {
	// '.' is ASCII, so a byte scan is safe on UTF-8 input.
	for idx := 0; idx < len(key); idx++ {
		if key[idx] == '.' {
			return keySpec{prefix: key[:idx], field: key[idx+1:]}
		}
	}
	return keySpec{field: key}
}
// buildJSONToolsNestedKeys builds per-tool parsers for dot-notation keys,
// e.g. NameKey "function.name" / ArgsKey "function.arguments" matching
// {"<prefix>": {"name": ..., "arguments": {...}}}. Call-ID fields are only
// matched when given without a dot prefix (top level of the outer object).
func (cb *ChatBuilder) buildJSONToolsNestedKeys(
	tools []ToolDef,
	effectiveNameKey, effectiveArgsKey, callIDKey, genCallIDKey string,
) ParserID {
	var toolChoices []ParserID
	nameSpec := parseKeySpec(effectiveNameKey)
	argsSpec := parseKeySpec(effectiveArgsKey)
	// The nested object lives under whichever key carries a prefix.
	nestedPrefix := nameSpec.prefix
	if nestedPrefix == "" {
		nestedPrefix = argsSpec.prefix
	}
	// A key without a prefix is used verbatim inside the nested object.
	nestedNameField := nameSpec.field
	if nameSpec.prefix == "" {
		nestedNameField = effectiveNameKey
	}
	nestedArgsField := argsSpec.field
	if argsSpec.prefix == "" {
		nestedArgsField = effectiveArgsKey
	}
	for _, tool := range tools {
		nestedName := cb.Seq(
			cb.Literal("\""+nestedNameField+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
		)
		nestedArgs := cb.Seq(
			cb.Literal("\""+nestedArgsField+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			cb.ToolArgs(cb.JSON()),
		)
		nestedObject := cb.Seq(
			cb.Literal("{"), cb.Space(),
			nestedName, cb.Space(), cb.Literal(","), cb.Space(),
			nestedArgs,
			cb.Space(), cb.Literal("}"),
		)
		toolParserBody := cb.ToolOpen(cb.Literal("{"))
		if callIDKey != "" {
			// Only a flat (non-dotted) call-ID key is supported here.
			idSpec := parseKeySpec(callIDKey)
			if idSpec.prefix == "" {
				idParser := cb.Atomic(cb.Seq(
					cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
					cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\""),
				))
				toolParserBody = cb.Seq(toolParserBody, cb.Space(),
					cb.Optional(cb.Seq(idParser, cb.Space(), cb.Literal(","), cb.Space())))
			}
		}
		if genCallIDKey != "" {
			genIDSpec := parseKeySpec(genCallIDKey)
			if genIDSpec.prefix == "" {
				// Generated call IDs may be quoted strings or bare numbers.
				genIDParser := cb.Atomic(cb.Seq(
					cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
					cb.Choice(
						cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
						cb.ToolID(cb.JSONNumber()),
					),
				))
				toolParserBody = cb.Seq(toolParserBody, cb.Space(),
					cb.Optional(cb.Seq(genIDParser, cb.Space(), cb.Literal(","), cb.Space())))
			}
		}
		nestedField := cb.Seq(
			cb.Literal("\""+nestedPrefix+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
			nestedObject,
		)
		toolParserBody = cb.Seq(toolParserBody, cb.Space(), nestedField, cb.Space(), cb.ToolClose(cb.Literal("}")))
		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, cb.Tool(toolParserBody)))
	}
	return cb.Choice(toolChoices...)
}
// buildJSONToolsFlatKeys builds per-tool parsers for the flat layout
// {"name": "...", "arguments": {...}}, optionally with call-ID fields, and
// honors parametersOrder when the model emits fields in a fixed order.
func (cb *ChatBuilder) buildJSONToolsFlatKeys(
	tools []ToolDef,
	effectiveNameKey, effectiveArgsKey, callIDKey, genCallIDKey string,
	parametersOrder []string,
) ParserID {
	var toolChoices []ParserID
	nameKeyParser := cb.Literal("\"" + effectiveNameKey + "\"")
	argsKeyParser := cb.Literal("\"" + effectiveArgsKey + "\"")
	for _, tool := range tools {
		toolNameP := cb.Seq(
			nameKeyParser, cb.Space(), cb.Literal(":"), cb.Space(),
			cb.Literal("\""), cb.ToolName(cb.Literal(tool.Name)), cb.Literal("\""),
		)
		toolArgsP := cb.Seq(
			argsKeyParser, cb.Space(), cb.Literal(":"), cb.Space(),
			cb.ToolArgs(cb.JSON()),
		)
		// Track each field parser with its JSON key so it can be reordered.
		pairs := []parserPair{
			{toolNameP, effectiveNameKey},
			{toolArgsP, effectiveArgsKey},
		}
		if callIDKey != "" {
			// Call IDs may be quoted strings or bare numbers.
			idParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+callIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			pairs = append(pairs, parserPair{cb.Optional(idParser), callIDKey})
		}
		if genCallIDKey != "" {
			genIDParser := cb.Atomic(cb.Seq(
				cb.Literal("\""+genCallIDKey+"\""), cb.Space(), cb.Literal(":"), cb.Space(),
				cb.Choice(
					cb.Seq(cb.Literal("\""), cb.ToolID(cb.JSONString()), cb.Literal("\"")),
					cb.ToolID(cb.JSONNumber()),
				),
			))
			pairs = append(pairs, parserPair{cb.Optional(genIDParser), genCallIDKey})
		}
		// Sort by parameters_order if provided
		if len(parametersOrder) > 0 {
			sortPairsByOrder(pairs, parametersOrder)
		}
		// Join the fields into "{ f1 , f2 , ... }".
		orderedBody := cb.ToolOpen(cb.Literal("{"))
		for i, p := range pairs {
			orderedBody = cb.Seq(orderedBody, cb.Space(), p.parser)
			if i < len(pairs)-1 {
				orderedBody = cb.Seq(orderedBody, cb.Space(), cb.Literal(","), cb.Space())
			}
		}
		orderedBody = cb.Seq(orderedBody, cb.Space(), cb.ToolClose(cb.Literal("}")))
		toolChoices = append(toolChoices, cb.Rule("tool-"+tool.Name, cb.Tool(orderedBody)))
	}
	return cb.Choice(toolChoices...)
}
// parserPair couples a field parser with the JSON key it matches, so a set
// of field parsers can be reordered by sortPairsByOrder.
type parserPair struct {
	parser ParserID
	key    string
}
// sortPairsByOrder stably sorts pairs in place so their keys follow order;
// keys absent from order keep their relative position after all known ones.
func sortPairsByOrder(pairs []parserPair, order []string) {
	// Precompute each key's rank (first occurrence wins on duplicates).
	rank := make(map[string]int, len(order))
	for i, key := range order {
		if _, seen := rank[key]; !seen {
			rank[key] = i
		}
	}
	rankOf := func(key string) int {
		if r, ok := rank[key]; ok {
			return r
		}
		return len(order)
	}
	// Insertion sort: stable and cheap for the handful of fields involved.
	for i := 1; i < len(pairs); i++ {
		for j := i; j > 0 && rankOf(pairs[j].key) < rankOf(pairs[j-1].key); j-- {
			pairs[j], pairs[j-1] = pairs[j-1], pairs[j]
		}
	}
}
// BuildChatPegParser is a convenience function to build a chat parser: it
// runs fn against a fresh ChatBuilder, installs the parser fn returns as
// the root, and builds the arena.
func BuildChatPegParser(fn func(cb *ChatBuilder) ParserID) *Arena {
	builder := NewChatBuilder()
	builder.SetRoot(fn(builder))
	return builder.Build()
}
// ToolCall represents a parsed tool call.
type ToolCall struct {
	// Name is the function name.
	Name string
	// Arguments is the JSON-encoded argument object; during streaming it
	// may be an incomplete prefix that is extended incrementally.
	Arguments string
	// ID is the call identifier, when the wire format carries one.
	ID string
}
// ChatMsg represents a parsed chat message.
type ChatMsg struct {
	// Content is the plain assistant text.
	Content string
	// ReasoningContent is the extracted thinking/reasoning text.
	ReasoningContent string
	// ToolCalls lists the tool invocations found in the message.
	ToolCalls []ToolCall
}
// ChatPegMapper maps AST nodes to a ChatMsg.
type ChatPegMapper struct {
	// Result accumulates the mapped message.
	Result ChatMsg
	// pendingToolCall is a call opened via TagToolOpen but not yet appended
	// to Result.ToolCalls.
	pendingToolCall *ToolCall
	// currentTool points at the call currently receiving name/args updates
	// (either pendingToolCall or an entry already in Result.ToolCalls).
	currentTool *ToolCall
	// argCount counts tagged arguments emitted for the current call, used
	// to place commas between JSON entries.
	argCount int
	// closingQuotePend records that an opened JSON string value still needs
	// its closing quote.
	closingQuotePend bool
	// argsBuffer buffers argument JSON seen before the tool name is known.
	argsBuffer string
}
// argsTarget returns the string the argument JSON should be appended to:
// the current tool's Arguments once it has a name, otherwise the temporary
// buffer that is adopted when the name arrives.
func (m *ChatPegMapper) argsTarget() *string {
	if m.currentTool == nil || m.currentTool.Name == "" {
		return &m.argsBuffer
	}
	return &m.currentTool.Arguments
}
// FromAST populates the ChatMsg from parse results.
//
// It visits every AST node produced by the parse in order, folding each one
// into the mapper state via mapNode, and finally flushes a tool call that
// was opened and named but never appended to the result.
func (m *ChatPegMapper) FromAST(ast *AstArena, result *ParseResult) {
	ast.VisitResult(result, func(node *AstNode) {
		m.mapNode(node)
	})
	// Flush pending tool call
	if m.pendingToolCall != nil && m.pendingToolCall.Name != "" {
		// Arguments buffered before the name was known take precedence.
		if m.argsBuffer != "" {
			m.pendingToolCall.Arguments = m.argsBuffer
		}
		// Terminate a string value that never saw its closing quote.
		if m.closingQuotePend && m.pendingToolCall.Arguments != "" {
			m.pendingToolCall.Arguments += "\""
		}
		m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
		m.pendingToolCall = nil
	}
}
// mapNode folds a single AST node into the mapper state.
//
// Content and reasoning nodes append to the message text. Tool-related tags
// drive a small state machine that incrementally rebuilds a JSON arguments
// object: TagToolOpen starts a call, TagToolName promotes it into
// Result.ToolCalls (so streaming consumers see it early), the TagToolArg*
// tags assemble {"key":"value",...} piece by piece, and TagToolClose
// finalizes quoting and brace balance.
func (m *ChatPegMapper) mapNode(node *AstNode) {
	switch node.Tag {
	case TagReasoning:
		m.Result.ReasoningContent += node.Text
	case TagContent:
		m.Result.Content += node.Text
	case TagToolOpen:
		// Begin a fresh call; it stays "pending" until a name is seen.
		tc := ToolCall{}
		m.pendingToolCall = &tc
		m.currentTool = m.pendingToolCall
		m.argCount = 0
		m.argsBuffer = ""
		m.closingQuotePend = false
	case TagToolID:
		if m.currentTool != nil {
			text := trimTrailingSpace(node.Text)
			// Strip surrounding quotes if the ID was captured quoted.
			if len(text) >= 2 && text[0] == '"' && text[len(text)-1] == '"' {
				text = text[1 : len(text)-1]
			}
			m.currentTool.ID = text
		}
	case TagToolName:
		if m.currentTool != nil {
			m.currentTool.Name = trimTrailingSpace(node.Text)
			// Adopt arguments buffered before the name appeared; otherwise
			// open an empty JSON object.
			if m.argsBuffer != "" {
				m.currentTool.Arguments = m.argsBuffer
				m.argsBuffer = ""
			} else if m.currentTool.Arguments == "" {
				m.currentTool.Arguments = "{"
			}
			// Add tool call to results for streaming
			if m.pendingToolCall != nil {
				m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
				m.pendingToolCall = nil
				// Keep mutating the slice element, not the stale copy.
				m.currentTool = &m.Result.ToolCalls[len(m.Result.ToolCalls)-1]
			}
		}
	case TagToolArgs:
		if m.currentTool != nil {
			text := trimTrailingSpace(node.Text)
			// Only accept whole JSON objects here; fragments are assembled
			// through the fine-grained TagToolArg* cases instead.
			if len(text) > 0 && text[0] == '{' {
				*m.argsTarget() = text
			}
		}
	case TagToolArgOpen:
		m.closingQuotePend = false
	case TagToolArgName:
		if m.currentTool != nil {
			// Emit `,"key":` — the comma only between entries.
			argEntry := ""
			if m.argCount > 0 {
				argEntry = ","
			}
			trimmed := trimSpace(node.Text)
			escapedKey := escapeJSONString(trimmed)
			argEntry += escapedKey + ":"
			m.argCount++
			target := m.argsTarget()
			if *target == "" {
				*target = "{"
			}
			*target += argEntry
		}
	case TagToolArgStrVal:
		if m.currentTool != nil {
			content := trimOneSpace(node.Text)
			var valueToAdd string
			if content == "" {
				// Empty chunk: just open the string value.
				valueToAdd = "\""
				m.closingQuotePend = true
			} else {
				// Open the string on the first chunk, then stream escaped
				// text; the closing quote is added at TagToolArgClose.
				if !m.closingQuotePend {
					valueToAdd = "\""
					m.closingQuotePend = true
				}
				valueToAdd += EscapeJSONStringInner(content)
			}
			*m.argsTarget() += valueToAdd
		}
	case TagToolArgValue:
		if m.currentTool != nil {
			content := trimOneSpace(node.Text)
			var valueToAdd string
			if content != "" {
				isPotentialContainer := content[0] == '[' || content[0] == '{'
				if isPotentialContainer {
					content = NormalizeQuotesToJSON(content)
				}
				// Try to parse as JSON
				var parsed json.RawMessage
				if err := json.Unmarshal([]byte(content), &parsed); err == nil {
					// Check if it's a string
					var s string
					if err2 := json.Unmarshal(parsed, &s); err2 == nil {
						// It's a string — strip closing quote for monotonic streaming
						escaped, _ := json.Marshal(s)
						str := string(escaped)
						if len(str) > 0 && str[len(str)-1] == '"' {
							str = str[:len(str)-1]
						}
						valueToAdd = str
						m.closingQuotePend = true
					} else {
						// Non-string: use raw content
						valueToAdd = content
					}
				} else {
					if node.IsPartial && isPotentialContainer {
						// Incomplete array/object mid-stream: pass through.
						valueToAdd = content
					} else {
						// Not valid JSON: treat it as a string value.
						if !m.closingQuotePend {
							valueToAdd = "\""
							m.closingQuotePend = true
						}
						valueToAdd += EscapeJSONStringInner(content)
					}
				}
			}
			*m.argsTarget() += valueToAdd
		}
	case TagToolArgClose:
		if m.currentTool != nil {
			if m.closingQuotePend {
				*m.argsTarget() += "\""
				m.closingQuotePend = false
			}
		}
	case TagToolClose:
		if m.currentTool != nil {
			// Flush buffer if tool name was never seen
			if m.currentTool.Name == "" && m.argsBuffer != "" {
				m.currentTool.Arguments = m.argsBuffer
				m.argsBuffer = ""
			}
			if m.closingQuotePend {
				m.currentTool.Arguments += "\""
				m.closingQuotePend = false
			}
			// Close unclosed braces
			for depth := jsonBraceDepth(m.currentTool.Arguments); depth > 0; depth-- {
				m.currentTool.Arguments += "}"
			}
			// Add if pending and named
			if m.pendingToolCall != nil {
				if m.currentTool.Name != "" {
					m.Result.ToolCalls = append(m.Result.ToolCalls, *m.pendingToolCall)
				}
				m.pendingToolCall = nil
			}
		}
	}
}
// NormalizeQuotesToJSON converts Python-style single-quoted strings to JSON
// double-quoted strings: single quotes delimiting a string become double
// quotes, double quotes inside a single-quoted string are escaped, and an
// escaped \' inside a single-quoted string becomes a plain apostrophe.
// Text inside double-quoted strings is passed through untouched.
func NormalizeQuotesToJSON(input string) string {
	out := make([]byte, 0, len(input)+16)
	singleOpen := false // inside a single-quoted string
	doubleOpen := false // inside a double-quoted string
	for pos := 0; pos < len(input); pos++ {
		ch := input[pos]
		// Escape sequences: handled only when a following byte exists.
		if ch == '\\' && pos+1 < len(input) {
			nxt := input[pos+1]
			switch {
			case singleOpen:
				switch nxt {
				case '\'':
					// \' inside 'string' → a bare apostrophe in JSON.
					out = append(out, '\'')
				case '"':
					// Keep the quote escaped in the converted string.
					out = append(out, '\\', '"')
				default:
					out = append(out, ch, nxt)
				}
				pos++
			case doubleOpen:
				// Pass escape sequences through verbatim.
				out = append(out, ch, nxt)
				pos++
			default:
				// Backslash outside any string: copy just the backslash.
				out = append(out, ch)
			}
			continue
		}
		switch ch {
		case '"':
			if singleOpen {
				// A double quote inside a single-quoted string must be
				// escaped once that string becomes double-quoted.
				out = append(out, '\\', '"')
			} else {
				doubleOpen = !doubleOpen
				out = append(out, ch)
			}
		case '\'':
			switch {
			case doubleOpen:
				// Apostrophe inside "string": keep as-is.
				out = append(out, ch)
			case singleOpen:
				// Closing delimiter → closing double quote.
				singleOpen = false
				out = append(out, '"')
			default:
				// Opening delimiter → opening double quote.
				singleOpen = true
				out = append(out, '"')
			}
		default:
			out = append(out, ch)
		}
	}
	return string(out)
}
// EscapeJSONStringInner JSON-escapes a string and returns the inner content
// (without the surrounding quotes). On a marshalling error the input is
// returned unchanged.
func EscapeJSONStringInner(s string) string {
	raw, err := json.Marshal(s)
	if err != nil {
		return s
	}
	quoted := string(raw)
	// Marshal of a string yields "..."; strip the delimiting quotes.
	if n := len(quoted); n >= 2 && quoted[0] == '"' && quoted[n-1] == '"' {
		return quoted[1 : n-1]
	}
	return quoted
}
// escapeJSONString returns s as a quoted JSON string literal, falling back
// to naive quoting should marshalling ever fail.
func escapeJSONString(s string) string {
	if raw, err := json.Marshal(s); err == nil {
		return string(raw)
	}
	return "\"" + s + "\""
}
// jsonBraceDepth returns the nesting balance of '{' versus '}' in s,
// ignoring braces that appear inside double-quoted strings (with
// backslash-escape handling). A positive result means unclosed objects.
func jsonBraceDepth(s string) int {
	var (
		depth    int
		inStr    bool
		skipNext bool
	)
	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch {
		case skipNext:
			// Previous byte was a backslash inside a string.
			skipNext = false
		case ch == '\\' && inStr:
			skipNext = true
		case ch == '"':
			inStr = !inStr
		case !inStr && ch == '{':
			depth++
		case !inStr && ch == '}':
			depth--
		}
	}
	return depth
}
// trimTrailingSpace returns s with all trailing whitespace removed.
func trimTrailingSpace(s string) string {
	n := len(s)
	for n > 0 {
		if !isWhitespace(s[n-1]) {
			break
		}
		n--
	}
	return s[:n]
}
// trimLeadingSpace returns s with leading whitespace removed, stripping at
// most max characters; a negative max means no limit.
func trimLeadingSpace(s string, max int) string {
	removed := 0
	idx := 0
	for idx < len(s) && isWhitespace(s[idx]) {
		if max >= 0 && removed >= max {
			break
		}
		idx++
		removed++
	}
	return s[idx:]
}
func trimSpace(s string) string {
s = trimLeadingSpace(s, 1)
return trimTrailingSpace(s)
}
// trimOneSpace strips at most one whitespace character from each end of s.
func trimOneSpace(s string) string {
	trimmed := trimLeadingSpace(s, 1)
	if n := len(trimmed); n > 0 && isWhitespace(trimmed[n-1]) {
		return trimmed[:n-1]
	}
	return trimmed
}

View File

@@ -0,0 +1,910 @@
package peg_test
import (
"github.com/mudler/LocalAI/pkg/functions/peg"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// createTools returns the fixed set of tool definitions shared by the chat
// parser examples: a current-weather lookup, a forecast, and a
// knowledge-base search.
func createTools() []peg.ToolDef {
	weather := peg.ToolDef{
		Name: "get_current_weather",
		Properties: map[string]peg.PropDef{
			"location": {Type: "string"},
			"unit":     {Type: "string"},
		},
	}
	forecast := peg.ToolDef{
		Name: "get_forecast",
		Properties: map[string]peg.PropDef{
			"location": {Type: "string"},
			"unit":     {Type: "string"},
			"days":     {Type: "integer"},
		},
	}
	search := peg.ToolDef{
		Name: "search_knowledge_base",
		Properties: map[string]peg.PropDef{
			"query":       {Type: "string"},
			"max_results": {Type: "integer"},
			"category":    {Type: "string"},
		},
	}
	return []peg.ToolDef{weather, forecast, search}
}
// simpleTokenize splits input into pseudo-tokens for incremental-parsing
// tests: each boundary rune starts a new token, so delimiters stay attached
// to the text that follows them.
func simpleTokenize(input string) []string {
	isBoundary := func(r rune) bool {
		switch r {
		case ' ', '\n', '\t', '{', '}', ',', '[', '"', ']', '.', '<', '>', '=', '/':
			return true
		}
		return false
	}
	var tokens []string
	current := ""
	for _, r := range input {
		// Flush the accumulated token before starting at a boundary rune.
		if isBoundary(r) && current != "" {
			tokens = append(tokens, current)
			current = ""
		}
		current += string(r)
	}
	if current != "" {
		tokens = append(tokens, current)
	}
	return tokens
}
var _ = Describe("Chat PEG Parser", func() {
Context("ExampleNative", func() {
type testCase struct {
name string
tools []peg.ToolDef
reasoningFormat string
parallelCalls bool
forcedOpen bool
forceToolCalls bool
input string
expectReasoning string
expectContent string
expectToolCalls []peg.ToolCall
}
buildParser := func(tc testCase) *peg.Arena {
return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
reasoningInContent := tc.reasoningFormat == "none"
var reasoning peg.ParserID
if tc.forcedOpen {
reasoning = p.Seq(
p.Reasoning(p.Until("</think>")),
p.Literal("</think>"),
p.Space(),
)
} else {
reasoning = p.Optional(p.Seq(
p.Literal("<think>"),
p.Reasoning(p.Until("</think>")),
p.Literal("</think>"),
p.Space(),
))
}
if len(tc.tools) > 0 {
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
SectionStart: "<tool_call>[",
SectionEnd: "]</tool_call>",
Tools: tc.tools,
ParallelCalls: tc.parallelCalls,
ForceToolCalls: tc.forceToolCalls,
})
var parts []peg.ParserID
if reasoningInContent {
parts = append(parts, p.Eps())
} else {
parts = append(parts, reasoning)
}
parts = append(parts,
p.Content(p.Until("<tool_call>")),
p.Optional(p.Seq(p.Space(), toolCall)),
p.Space(),
p.End(),
)
return p.Seq(parts...)
}
var parts []peg.ParserID
if reasoningInContent {
parts = append(parts, p.Eps())
} else {
parts = append(parts, reasoning)
}
parts = append(parts, p.Content(p.Rest()), p.End())
return p.Seq(parts...)
})
}
DescribeTable("native format cases",
func(tc testCase) {
parser := buildParser(tc)
ctx := peg.NewParseContext(tc.input, false)
result := parser.Parse(ctx)
Expect(result.Type).To(Equal(peg.Success))
mapper := &peg.ChatPegMapper{}
mapper.FromAST(&ctx.Ast, &result)
msg := mapper.Result
Expect(msg.Content).To(Equal(tc.expectContent))
Expect(msg.ReasoningContent).To(Equal(tc.expectReasoning))
Expect(msg.ToolCalls).To(HaveLen(len(tc.expectToolCalls)))
for i := 0; i < len(tc.expectToolCalls) && i < len(msg.ToolCalls); i++ {
Expect(msg.ToolCalls[i].Name).To(Equal(tc.expectToolCalls[i].Name))
Expect(msg.ToolCalls[i].Arguments).To(Equal(tc.expectToolCalls[i].Arguments))
}
},
Entry("content with thinking", testCase{
reasoningFormat: "auto",
input: "<think>The user said hello, I must say hello back</think>\nHello",
expectReasoning: "The user said hello, I must say hello back",
expectContent: "Hello",
}),
Entry("content without thinking", testCase{
reasoningFormat: "auto",
input: "Hello",
expectContent: "Hello",
}),
Entry("content with reasoning_format = none", testCase{
reasoningFormat: "none",
forcedOpen: true,
input: "<think>The user said hello, I must say hello back</think>\nHello",
expectContent: "<think>The user said hello, I must say hello back</think>\nHello",
}),
Entry("content with forced_open", testCase{
reasoningFormat: "auto",
forcedOpen: true,
input: "The user said hello, I must say hello back</think>\nHello",
expectReasoning: "The user said hello, I must say hello back",
expectContent: "Hello",
}),
Entry("content with forced_open and reasoning_format = none", testCase{
reasoningFormat: "none",
forcedOpen: true,
input: "The user said hello, I must say hello back</think>\nHello",
expectContent: "The user said hello, I must say hello back</think>\nHello",
}),
Entry("single tool call", testCase{
tools: createTools(),
reasoningFormat: "auto",
forcedOpen: true,
input: "I must get the weather in New York</think>\n" +
"<tool_call>[" +
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
"]</tool_call>",
expectReasoning: "I must get the weather in New York",
expectToolCalls: []peg.ToolCall{
{
Name: "get_current_weather",
Arguments: `{"location": "New York City, NY", "unit": "fahrenheit"}`,
},
},
}),
Entry("parallel tool calls", testCase{
tools: createTools(),
reasoningFormat: "auto",
parallelCalls: true,
forcedOpen: true,
input: "I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you." +
"<tool_call>[" +
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
", " +
`{"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}` +
", " +
`{"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}}` +
", " +
`{"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}}` +
"]</tool_call>",
expectReasoning: "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
expectContent: "Let me search that for you.",
expectToolCalls: []peg.ToolCall{
{Name: "get_current_weather", Arguments: `{"location": "New York City, NY", "unit": "fahrenheit"}`},
{Name: "get_current_weather", Arguments: `{"location": "San Francisco, CA", "unit": "fahrenheit"}`},
{Name: "get_forecast", Arguments: `{"location": "New York City, NY", "unit": "fahrenheit", "days": 3}`},
{Name: "get_forecast", Arguments: `{"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}`},
},
}),
Entry("JSON schema response format", testCase{
tools: createTools(),
reasoningFormat: "auto",
forcedOpen: true,
forceToolCalls: false,
input: "Thinking about the answer</think>\n" +
`<tool_call>[{"name": "get_current_weather", "arguments": {"location": "NYC", "unit": "celsius"}}]</tool_call>`,
expectReasoning: "Thinking about the answer",
expectToolCalls: []peg.ToolCall{
{Name: "get_current_weather", Arguments: `{"location": "NYC", "unit": "celsius"}`},
},
}),
)
})
Context("ExampleQwen3Coder", func() {
It("parses tool calls with tagged parameters", func() {
tools := createTools()
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
content := p.Rule("content", p.Content(p.Until("<tool_call>")))
var toolParsers []peg.ParserID
for _, tool := range tools {
var argChoices []peg.ParserID
for propName, prop := range tool.Properties {
var argValueParser peg.ParserID
if prop.Type == "string" {
argValueParser = p.ToolArgStringValue(p.UntilOneOf("</parameter>\n<parameter=", "</parameter>\n</function>"))
} else {
argValueParser = p.ToolArgJSONValue(p.JSON())
}
arg := p.ToolArg(p.Seq(
p.ToolArgOpen(p.Literal("<parameter="+propName+">")),
argValueParser,
p.ToolArgClose(p.Seq(
p.Literal("</parameter>\n"),
p.Peek(p.Choice(p.Literal("<parameter="), p.Literal("</function>"))),
)),
))
argChoices = append(argChoices, arg)
}
argChoice := p.Choice(argChoices...)
args := p.ZeroOrMore(argChoice)
toolParser := p.Rule("tool-"+tool.Name, p.Seq(
p.ToolOpen(p.Seq(
p.Literal("<function="),
p.ToolName(p.Literal(tool.Name)),
p.Literal(">\n"),
)),
args,
p.ToolClose(p.Literal("</function>")),
))
toolParsers = append(toolParsers, toolParser)
}
toolCall := p.TriggerRule("tool-call", p.Seq(
p.Literal("<tool_call>"), p.Space(),
p.Choice(toolParsers...), p.Space(),
p.Literal("</tool_call>"),
))
return p.Seq(content, p.ZeroOrMore(p.Seq(p.Space(), toolCall)), p.End())
})
input := "Let me search the knowledge base for cat pictures." +
"<tool_call>\n" +
"<function=search_knowledge_base>\n" +
"<parameter=query>cat pictures</parameter>\n" +
"<parameter=category>general</parameter>\n" +
"</function>\n" +
"</tool_call>"
ctx := peg.NewParseContext(input, false)
result := parser.Parse(ctx)
Expect(result.Type).To(Equal(peg.Success))
mapper := &peg.ChatPegMapper{}
mapper.FromAST(&ctx.Ast, &result)
msg := mapper.Result
Expect(msg.Content).To(Equal("Let me search the knowledge base for cat pictures."))
Expect(msg.ToolCalls).To(HaveLen(1))
Expect(msg.ToolCalls[0].Name).To(Equal("search_knowledge_base"))
Expect(msg.ToolCalls[0].Arguments).NotTo(BeEmpty())
})
})
Context("ExampleQwen3NonCoder", func() {
It("parses JSON tool calls", func() {
tools := createTools()
parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
SectionStart: "<tool_call>",
SectionEnd: "</tool_call>",
Tools: tools,
ParallelCalls: true,
})
return p.Seq(
p.Content(p.Until("<tool_call>")),
p.Optional(p.Seq(p.Space(), toolCall)),
p.End(),
)
})
input := "I need to get the weather.\n" +
"<tool_call>" +
`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
"</tool_call>"
ctx := peg.NewParseContext(input, false)
result := parser.Parse(ctx)
Expect(result.Type).To(Equal(peg.Success))
mapper := &peg.ChatPegMapper{}
mapper.FromAST(&ctx.Ast, &result)
msg := mapper.Result
Expect(msg.Content).To(Equal("I need to get the weather.\n"))
Expect(msg.ReasoningContent).To(BeEmpty())
Expect(msg.ToolCalls).To(HaveLen(1))
Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
Expect(msg.ToolCalls[0].Arguments).To(Equal(`{"location": "New York City, NY", "unit": "fahrenheit"}`))
})
})
Context("Command7", func() {
var parser *peg.Arena
BeforeEach(func() {
parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
thinking := p.ReasoningBlock(p.Seq(
p.Literal("<|START_THINKING|>"), p.Space(),
p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
p.Literal("<|END_THINKING|>"),
))
response := p.Seq(
p.Literal("<|START_RESPONSE|>"), p.Space(),
p.Content(p.Until("<|END_RESPONSE|>")), p.Space(),
p.Literal("<|END_RESPONSE|>"),
)
toolCallID := p.Atomic(p.Seq(
p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
))
toolCallName := p.Atomic(p.Seq(
p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
))
toolCallArgs := p.Seq(
p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
p.ToolArgs(p.JSON()),
)
toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
p.ToolOpen(p.Literal("{")), p.Space(),
toolCallFields,
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
p.Space(), p.ToolClose(p.Literal("}")),
)))
toolCalls := p.Rule("tool-calls", p.Seq(
p.Literal("<|START_ACTION|>"), p.Space(),
p.Literal("["), p.Space(),
toolCall,
p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
p.Space(), p.Literal("]"), p.Space(),
p.Literal("<|END_ACTION|>"),
))
return p.Seq(
p.Optional(p.Seq(thinking, p.Space())),
p.Choice(toolCalls, response),
p.End(),
)
})
})
It("parses tool call with reasoning", func() {
input := "<|START_THINKING|>I need to plan a trip to Japan.\n<|END_THINKING|>" +
"<|START_ACTION|>[" +
`{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14}}` +
"]<|END_ACTION|>"
ctx := peg.NewParseContext(input, false)
result := parser.Parse(ctx)
Expect(result.Type).To(Equal(peg.Success))
mapper := &peg.ChatPegMapper{}
mapper.FromAST(&ctx.Ast, &result)
msg := mapper.Result
Expect(msg.ReasoningContent).To(Equal("I need to plan a trip to Japan.\n"))
Expect(msg.ToolCalls).To(HaveLen(1))
Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
})
It("parses content-only response", func() {
input := "<|START_RESPONSE|>Hello, how can I help you?<|END_RESPONSE|>"
ctx := peg.NewParseContext(input, false)
result := parser.Parse(ctx)
Expect(result.Type).To(Equal(peg.Success))
mapper := &peg.ChatPegMapper{}
mapper.FromAST(&ctx.Ast, &result)
msg := mapper.Result
Expect(msg.Content).To(Equal("Hello, how can I help you?"))
Expect(msg.ToolCalls).To(BeEmpty())
})
})
// PrefixToolNames: one tool name is a strict prefix of another
// ("special_function" vs "special_function_with_opt"). Verifies the parser
// resolves the correct name and never commits to the short prefix while the
// long name is still possible during incremental parsing.
Context("PrefixToolNames", func() {
	var parser *peg.Arena
	BeforeEach(func() {
		tools := []peg.ToolDef{
			{Name: "special_function", Properties: map[string]peg.PropDef{"arg1": {Type: "string"}}},
			{Name: "special_function_with_opt", Properties: map[string]peg.PropDef{"arg1": {Type: "string"}}},
		}
		parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			toolCall := p.StandardConstructedTools(
				map[string]string{},
				tools,
				true,
				false,
			)
			return p.Seq(
				p.Content(p.Until("<tool_call>")),
				p.Optional(p.Seq(p.Space(), toolCall)),
				p.End(),
			)
		})
	})
	It("parses long tool name", func() {
		input := "Let me call the function.<tool_call><function=special_function_with_opt><param=arg1>42</param></function></tool_call>"
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		Expect(mapper.Result.ToolCalls).To(HaveLen(1))
		Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function_with_opt"))
	})
	It("parses short tool name", func() {
		input := "Let me call the function.<tool_call><function=special_function><param=arg1>42</param></function></tool_call>"
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		Expect(mapper.Result.ToolCalls).To(HaveLen(1))
		Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function"))
	})
	It("never prematurely matches during incremental parsing", func() {
		input := "Let me call the function." +
			"<tool_call>" +
			"<function=special_function_with_opt>" +
			"<param=arg1>42</param>" +
			"</function>" +
			"</tool_call>"
		tokens := simpleTokenize(input)
		var accumulated string
		for i, tok := range tokens {
			accumulated += tok
			// Every chunk except the last is parsed in partial mode.
			isPartial := i < len(tokens)-1
			ctx := peg.NewParseContext(accumulated, isPartial)
			result := parser.Parse(ctx)
			Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			for _, tc := range mapper.Result.ToolCalls {
				// The short name must never surface mid-stream.
				Expect(tc.Name).NotTo(Equal("special_function"),
					"premature tool name match at token %d, input: %s", i, accumulated)
			}
		}
		// Final complete parse still yields the long name.
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		Expect(mapper.Result.ToolCalls).To(HaveLen(1))
		Expect(mapper.Result.ToolCalls[0].Name).To(Equal("special_function_with_opt"))
	})
})
// IncrementalParsing: feeds input token-by-token and asserts the parser never
// fails mid-stream and never loses previously reported tool calls.
Context("IncrementalParsing", func() {
	It("handles qwen3 coder format incrementally", func() {
		tools := createTools()
		parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			content := p.Rule("content", p.Content(p.Until("<tool_call>")))
			// Build a hand-rolled <function=NAME>/<parameter=...> grammar per tool.
			var toolParsers []peg.ParserID
			for _, tool := range tools {
				var argChoices []peg.ParserID
				for propName, prop := range tool.Properties {
					var argValueParser peg.ParserID
					if prop.Type == "string" {
						argValueParser = p.ToolArgStringValue(p.UntilOneOf("</parameter>\n<parameter=", "</parameter>\n</function>"))
					} else {
						argValueParser = p.ToolArgJSONValue(p.JSON())
					}
					arg := p.ToolArg(p.Seq(
						p.ToolArgOpen(p.Literal("<parameter="+propName+">")),
						argValueParser,
						p.ToolArgClose(p.Seq(
							p.Literal("</parameter>\n"),
							// Lookahead: the closer must be followed by another parameter or the function end.
							p.Peek(p.Choice(p.Literal("<parameter="), p.Literal("</function>"))),
						)),
					))
					argChoices = append(argChoices, arg)
				}
				argChoice := p.Choice(argChoices...)
				args := p.ZeroOrMore(argChoice)
				toolParser := p.Rule("tool-"+tool.Name, p.Seq(
					p.ToolOpen(p.Seq(p.Literal("<function="), p.ToolName(p.Literal(tool.Name)), p.Literal(">\n"))),
					args,
					p.ToolClose(p.Literal("</function>")),
				))
				toolParsers = append(toolParsers, toolParser)
			}
			toolCall := p.TriggerRule("tool-call", p.Seq(
				p.Literal("<tool_call>"), p.Space(),
				p.Choice(toolParsers...), p.Space(),
				p.Literal("</tool_call>"),
			))
			return p.Seq(content, p.ZeroOrMore(p.Seq(p.Space(), toolCall)), p.End())
		})
		input := "Let me search the knowledge base for cat pictures." +
			"<tool_call>\n" +
			"<function=search_knowledge_base>\n" +
			"<parameter=query>cat pictures</parameter>\n" +
			"<parameter=category>general</parameter>\n" +
			"</function>\n" +
			"</tool_call>"
		tokens := simpleTokenize(input)
		var accumulated string
		var prevToolCalls int
		for i, tok := range tokens {
			accumulated += tok
			isPartial := i < len(tokens)-1
			ctx := peg.NewParseContext(accumulated, isPartial)
			result := parser.Parse(ctx)
			Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			// Monotonicity: a later snapshot must never report fewer tool calls.
			Expect(len(mapper.Result.ToolCalls)).To(BeNumerically(">=", prevToolCalls),
				"tool call count decreased at token %d", i)
			prevToolCalls = len(mapper.Result.ToolCalls)
		}
	})
	It("handles qwen3 non-coder format incrementally", func() {
		tools := createTools()
		parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
				SectionStart:  "<tool_call>",
				SectionEnd:    "</tool_call>",
				Tools:         tools,
				ParallelCalls: true,
			})
			return p.Seq(
				p.Content(p.Until("<tool_call>")),
				p.Optional(p.Seq(p.Space(), toolCall)),
				p.End(),
			)
		})
		input := "I need to get the weather.\n" +
			"<tool_call>" +
			`{"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}}` +
			"</tool_call>"
		tokens := simpleTokenize(input)
		var accumulated string
		for i, tok := range tokens {
			accumulated += tok
			isPartial := i < len(tokens)-1
			ctx := peg.NewParseContext(accumulated, isPartial)
			result := parser.Parse(ctx)
			Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, input: %s", i, accumulated)
		}
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		Expect(mapper.Result.ToolCalls).To(HaveLen(1))
		Expect(mapper.Result.ToolCalls[0].Name).To(Equal("get_current_weather"))
	})
})
// Command7-style output: an optional <|START_THINKING|> reasoning block
// followed by a JSON action list between <|START_ACTION|>/<|END_ACTION|>.
Context("Command7 complex input", func() {
	It("parses complex reasoning and tool calls", func() {
		parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			thinking := p.ReasoningBlock(p.Seq(
				p.Literal("<|START_THINKING|>"), p.Space(),
				p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
				p.Literal("<|END_THINKING|>"),
			))
			// The three object fields may appear in any order; each is atomic.
			toolCallID := p.Atomic(p.Seq(
				p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
			toolCallName := p.Atomic(p.Seq(
				p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
			))
			toolCallArgs := p.Seq(
				p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
				p.ToolArgs(p.JSON()),
			)
			toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
			toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				toolCallFields,
				p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)))
			toolCalls := p.Rule("tool-calls", p.Seq(
				p.Literal("<|START_ACTION|>"), p.Space(),
				p.Literal("["), p.Space(),
				toolCall,
				p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
				p.Space(), p.Literal("]"), p.Space(),
				p.Literal("<|END_ACTION|>"),
			))
			return p.Seq(
				p.Optional(p.Seq(thinking, p.Space())),
				toolCalls,
				p.End(),
			)
		})
		reasoning := "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " +
			"budget of $4000 for a two-week stay, we need to:\n\n" +
			"1. Identify key historical sites and modern attractions in Japan.\n" +
			"2. Find affordable accommodation options that provide a balance between comfort and cost.\n" +
			"3. Determine the best modes of transportation for getting around Japan.\n" +
			"4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " +
			"overspending.\n" +
			"5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " +
			"to attractions."
		input := "<|START_THINKING|>" + reasoning + "<|END_THINKING|>" +
			`<|START_ACTION|>[{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14, "budget": 4000, "interests": ["historical sites", "modern attractions"], "accommodation_preferences": "affordable", "transportation_preferences": "efficient", "meal_preferences": "local cuisine"}}]<|END_ACTION|>`
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		msg := mapper.Result
		Expect(msg.ReasoningContent).To(Equal(reasoning))
		Expect(msg.ToolCalls).To(HaveLen(1))
		Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
		Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
		Expect(msg.ToolCalls[0].Arguments).To(ContainSubstring(`"interests"`))
		Expect(msg.ToolCalls[0].Arguments).To(ContainSubstring(`"historical sites"`))
	})
})
// ForceToolCalls: with ForceToolCalls the grammar requires at least one tool
// call — plain text without one must fail the whole parse.
Context("ForceToolCalls", func() {
	var parser *peg.Arena
	BeforeEach(func() {
		tools := createTools()
		parser = peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
				SectionStart:   "<tool_call>[",
				SectionEnd:     "]</tool_call>",
				Tools:          tools,
				ForceToolCalls: true,
			})
			return p.Seq(
				p.Content(p.Until("<tool_call>")),
				p.Space(),
				toolCall,
				p.Space(),
				p.End(),
			)
		})
	})
	It("succeeds with tool call present", func() {
		input := "Let me check." +
			`<tool_call>[{"name": "get_current_weather", "arguments": {"location": "NYC", "unit": "celsius"}}]</tool_call>`
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		Expect(mapper.Result.ToolCalls).To(HaveLen(1))
		Expect(mapper.Result.ToolCalls[0].Name).To(Equal("get_current_weather"))
	})
	It("fails without tool call", func() {
		input := "Just a response without any tool calls."
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Fail))
	})
})
// NestedKeysJSONTools: dotted key paths ("function.name"/"function.arguments")
// address fields nested inside a sub-object of the tool-call JSON.
Context("NestedKeysJSONTools", func() {
	It("parses nested function.name and function.arguments keys", func() {
		tools := []peg.ToolDef{
			{
				Name: "get_current_weather",
				Properties: map[string]peg.PropDef{
					"location": {Type: "string"},
				},
			},
		}
		parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
				SectionStart: "<tool_call>",
				SectionEnd:   "</tool_call>",
				Tools:        tools,
				NameKey:      "function.name",
				ArgsKey:      "function.arguments",
				CallIDKey:    "id",
			})
			return p.Seq(
				p.Content(p.Until("<tool_call>")),
				p.Optional(p.Seq(p.Space(), toolCall)),
				p.End(),
			)
		})
		input := `Let me check.<tool_call>{"id": "call_123", "function": {"name": "get_current_weather", "arguments": {"location": "NYC"}}}</tool_call>`
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		msg := mapper.Result
		Expect(msg.ToolCalls).To(HaveLen(1))
		Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
		Expect(msg.ToolCalls[0].ID).To(Equal("call_123"))
	})
})
// Command7 incremental: same grammar as the complex-input test, fed
// token-by-token; tool-call count must be monotonically non-decreasing and
// the final complete parse must match the one-shot result.
Context("Command7 incremental", func() {
	It("handles incremental parsing without regressions", func() {
		parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
			thinking := p.ReasoningBlock(p.Seq(
				p.Literal("<|START_THINKING|>"), p.Space(),
				p.Reasoning(p.Until("<|END_THINKING|>")), p.Space(),
				p.Literal("<|END_THINKING|>"),
			))
			toolCallID := p.Atomic(p.Seq(
				p.Literal("\"tool_call_id\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
			toolCallName := p.Atomic(p.Seq(
				p.Literal("\"tool_name\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
			))
			toolCallArgs := p.Seq(
				p.Literal("\"parameters\""), p.Space(), p.Literal(":"), p.Space(),
				p.ToolArgs(p.JSON()),
			)
			toolCallFields := p.Rule("tool-call-fields", p.Choice(toolCallID, toolCallName, toolCallArgs))
			toolCall := p.Rule("tool-call-single", p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				toolCallFields,
				p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCallFields)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)))
			toolCalls := p.Rule("tool-calls", p.Seq(
				p.Literal("<|START_ACTION|>"), p.Space(),
				p.Literal("["), p.Space(),
				toolCall,
				p.ZeroOrMore(p.Seq(p.Literal(","), p.Space(), toolCall)),
				p.Space(), p.Literal("]"), p.Space(),
				p.Literal("<|END_ACTION|>"),
			))
			return p.Seq(
				p.Optional(p.Seq(thinking, p.Space())),
				toolCalls,
				p.End(),
			)
		})
		reasoning := "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " +
			"budget of $4000 for a two-week stay, we need to:\n\n" +
			"1. Identify key historical sites and modern attractions in Japan.\n" +
			"2. Find affordable accommodation options.\n" +
			"3. Determine the best modes of transportation.\n" +
			"4. Create a day-by-day itinerary.\n" +
			"5. Provide a detailed cost breakdown."
		input := "<|START_THINKING|>" + reasoning + "<|END_THINKING|>" +
			`<|START_ACTION|>[{"tool_call_id": "call_0", "tool_name": "plan_trip", "parameters": {"destination": "Japan", "duration": 14, "budget": 4000, "interests": ["historical sites", "modern attractions"]}}]<|END_ACTION|>`
		tokens := simpleTokenize(input)
		var accumulated string
		var prevToolCalls int
		for i, tok := range tokens {
			accumulated += tok
			isPartial := i < len(tokens)-1
			ctx := peg.NewParseContext(accumulated, isPartial)
			result := parser.Parse(ctx)
			Expect(result.Type).NotTo(Equal(peg.Fail), "parse failed at token %d, accumulated length=%d", i, len(accumulated))
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			// Monotonicity: reported tool calls never regress mid-stream.
			Expect(len(mapper.Result.ToolCalls)).To(BeNumerically(">=", prevToolCalls),
				"tool call count decreased at token %d", i)
			prevToolCalls = len(mapper.Result.ToolCalls)
		}
		ctx := peg.NewParseContext(input, false)
		result := parser.Parse(ctx)
		Expect(result.Type).To(Equal(peg.Success))
		mapper := &peg.ChatPegMapper{}
		mapper.FromAST(&ctx.Ast, &result)
		msg := mapper.Result
		Expect(msg.ReasoningContent).To(Equal(reasoning))
		Expect(msg.ToolCalls).To(HaveLen(1))
		Expect(msg.ToolCalls[0].Name).To(Equal("plan_trip"))
		Expect(msg.ToolCalls[0].ID).To(Equal("call_0"))
	})
})
})

831
pkg/functions/peg/parser.go Normal file
View File

@@ -0,0 +1,831 @@
package peg
// Parser is the interface all parser types implement.
// parse attempts a match at byte offset start of ctx.Input and returns a
// ParseResult describing the outcome (Success, Fail, or NeedMoreInput) and
// the consumed range plus any AST nodes produced.
type Parser interface {
	parse(arena *Arena, ctx *ParseContext, start int) ParseResult
}
// EpsilonParser always succeeds, consumes nothing.
type EpsilonParser struct{}

// parse unconditionally succeeds at start without advancing.
func (p *EpsilonParser) parse(_ *Arena, _ *ParseContext, start int) ParseResult {
	return NewParseResult(Success, start)
}
// StartParser matches start of input.
type StartParser struct{}

// parse succeeds (consuming nothing) only at byte offset zero.
func (p *StartParser) parse(_ *Arena, _ *ParseContext, start int) ParseResult {
	if start != 0 {
		return NewParseResult(Fail, start)
	}
	return NewParseResult(Success, start)
}
// EndParser matches end of input.
type EndParser struct{}

// parse succeeds (consuming nothing) only when no input remains at start.
func (p *EndParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	if start < len(ctx.Input) {
		return NewParseResult(Fail, start)
	}
	return NewParseResult(Success, start)
}
// LiteralParser matches an exact string.
type LiteralParser struct {
	Literal string
}

// parse compares the literal byte-for-byte against the input at start.
// If the input runs out mid-literal it reports NeedMoreInput in partial
// mode (the literal may still complete) and Fail otherwise; any byte
// mismatch fails immediately.
func (p *LiteralParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	end := start
	for _, want := range []byte(p.Literal) {
		if end >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, end)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[end] != want {
			return NewParseResult(Fail, start)
		}
		end++
	}
	return NewParseResultRange(Success, start, end)
}
// SequenceParser matches children in order.
type SequenceParser struct {
	Children []ParserID
}

// parse runs each child where the previous one stopped, accumulating AST
// nodes from successful children. The first Fail aborts the sequence — but
// in partial mode a Fail whose End reaches the end of the current input is
// converted to NeedMoreInput, since the child might still match once more
// bytes arrive. A NeedMoreInput from a child likewise suspends the sequence,
// keeping the nodes gathered so far.
func (p *SequenceParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	var nodes []AstID
	for _, childID := range p.Children {
		result := arena.ParseAt(childID, ctx, pos)
		if result.Type == Fail {
			if ctx.IsPartial && result.End >= len(ctx.Input) {
				return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
			}
			return NewParseResultRange(Fail, start, result.End)
		}
		nodes = append(nodes, result.Nodes...)
		if result.Type == NeedMoreInput {
			return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
		}
		pos = result.End
	}
	return NewParseResultNodes(Success, start, pos, nodes)
}
// ChoiceParser tries each alternative until one succeeds.
type ChoiceParser struct {
	Children []ParserID
}

// parse implements ordered PEG choice: the first alternative that does not
// fail (Success or NeedMoreInput) wins; Fail only when all alternatives fail.
func (p *ChoiceParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	for _, alt := range p.Children {
		if r := arena.ParseAt(alt, ctx, start); r.Type != Fail {
			return r
		}
	}
	return NewParseResult(Fail, start)
}
// RepetitionParser matches min to max repetitions.
type RepetitionParser struct {
	Child    ParserID
	MinCount int
	MaxCount int // -1 for unbounded
}

// parse greedily matches the child until MaxCount is reached, the input is
// exhausted, or the child fails; zero-width matches stop the loop to avoid
// infinite iteration. A NeedMoreInput from the child suspends the repetition
// with the nodes gathered so far. When fewer than MinCount matches were made,
// the result is Fail — or NeedMoreInput if the input ran out in partial mode,
// since further repetitions may still arrive.
func (p *RepetitionParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	matchCount := 0
	var nodes []AstID
	for p.MaxCount == -1 || matchCount < p.MaxCount {
		// Never attempt a match past the end of input.
		if pos >= len(ctx.Input) {
			break
		}
		result := arena.ParseAt(p.Child, ctx, pos)
		if result.Type == Success {
			// Prevent infinite loop on empty matches
			if result.End == pos {
				break
			}
			nodes = append(nodes, result.Nodes...)
			pos = result.End
			matchCount++
			continue
		}
		if result.Type == NeedMoreInput {
			nodes = append(nodes, result.Nodes...)
			return NewParseResultNodes(NeedMoreInput, start, result.End, nodes)
		}
		// Child failed
		break
	}
	if p.MinCount > 0 && matchCount < p.MinCount {
		if pos >= len(ctx.Input) && ctx.IsPartial {
			return NewParseResultNodes(NeedMoreInput, start, pos, nodes)
		}
		return NewParseResultRange(Fail, start, pos)
	}
	return NewParseResultNodes(Success, start, pos, nodes)
}
// AndParser is a positive lookahead — succeeds if child succeeds, consumes nothing.
type AndParser struct {
	Child ParserID
}

// parse runs the child and re-anchors its outcome at start, discarding the
// consumed range and any AST nodes (pure lookahead; NeedMoreInput propagates).
func (p *AndParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	child := arena.ParseAt(p.Child, ctx, start)
	return NewParseResult(child.Type, start)
}
// NotParser is a negative lookahead — succeeds if child fails, consumes nothing.
type NotParser struct {
	Child ParserID
}

// parse inverts the child's outcome without consuming input. NeedMoreInput
// is passed through unchanged: the child's fate is still undecided.
func (p *NotParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	switch arena.ParseAt(p.Child, ctx, start).Type {
	case Success:
		return NewParseResult(Fail, start)
	case NeedMoreInput:
		return NewParseResult(NeedMoreInput, start)
	default:
		return NewParseResult(Success, start)
	}
}
// AnyParser matches any single UTF-8 codepoint.
type AnyParser struct{}

// parse decodes one codepoint at start. A truncated codepoint waits for more
// input in partial mode and fails otherwise; malformed UTF-8 always fails.
func (p *AnyParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	_, size, status := parseUTF8Codepoint(ctx.Input, start)
	switch status {
	case utf8Incomplete:
		if ctx.IsPartial {
			return NewParseResult(NeedMoreInput, start)
		}
		return NewParseResult(Fail, start)
	case utf8Invalid:
		return NewParseResult(Fail, start)
	}
	return NewParseResultRange(Success, start, start+size)
}
// SpaceParser matches zero or more whitespace characters.
type SpaceParser struct{}

// parse greedily consumes ASCII whitespace (space, tab, CR, LF, VT, FF) and
// always succeeds, possibly with a zero-width match.
func (p *SpaceParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	end := start
scan:
	for end < len(ctx.Input) {
		switch ctx.Input[end] {
		case ' ', '\t', '\n', '\r', '\v', '\f':
			end++
		default:
			break scan
		}
	}
	return NewParseResultRange(Success, start, end)
}
// CharRange represents a range of Unicode codepoints.
type CharRange struct {
	Start rune // inclusive lower bound
	End   rune // inclusive upper bound
}

// Contains reports whether cp lies in the inclusive range [Start, End].
func (r CharRange) Contains(cp rune) bool {
	return cp >= r.Start && cp <= r.End
}
// CharsParser matches characters from a character class.
type CharsParser struct {
	Pattern  string
	Ranges   []CharRange
	Negated  bool
	MinCount int
	MaxCount int // -1 for unbounded
}

// parse consumes codepoints that fall inside Ranges (or outside, when
// Negated), between MinCount and MaxCount times. Incomplete or invalid UTF-8
// ends the run: if the minimum was already met the match succeeds there;
// otherwise incomplete bytes wait in partial mode, and anything else fails.
func (p *CharsParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	matchCount := 0
	for p.MaxCount == -1 || matchCount < p.MaxCount {
		r, size, status := parseUTF8Codepoint(ctx.Input, pos)
		if status == utf8Incomplete {
			if matchCount >= p.MinCount {
				return NewParseResultRange(Success, start, pos)
			}
			if !ctx.IsPartial {
				return NewParseResult(Fail, start)
			}
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		if status == utf8Invalid {
			if matchCount >= p.MinCount {
				return NewParseResultRange(Success, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		// Membership test against the class, then apply negation.
		matches := false
		for _, cr := range p.Ranges {
			if cr.Contains(r) {
				matches = true
				break
			}
		}
		if p.Negated {
			matches = !matches
		}
		if matches {
			pos += size
			matchCount++
		} else {
			break
		}
	}
	if matchCount < p.MinCount {
		if pos >= len(ctx.Input) && ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResultRange(Fail, start, pos)
	}
	return NewParseResultRange(Success, start, pos)
}
// JSONStringParser matches JSON string content (without quotes).
type JSONStringParser struct{}

// parse scans up to — but not including — the terminating '"', validating
// backslash escapes and UTF-8 along the way. Running out of input before the
// closing quote yields NeedMoreInput in partial mode, Fail otherwise.
func (p *JSONStringParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	for pos < len(ctx.Input) {
		c := ctx.Input[pos]
		if c == '"' {
			// Stop before the quote: the surrounding grammar owns the delimiters.
			return NewParseResultRange(Success, start, pos)
		}
		if c == '\\' {
			result := handleEscapeSequence(ctx, start, pos)
			if result.Type != Success {
				return result
			}
			pos = result.End
		} else {
			_, size, status := parseUTF8Codepoint(ctx.Input, pos)
			if status == utf8Incomplete {
				if !ctx.IsPartial {
					return NewParseResult(Fail, start)
				}
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			if status == utf8Invalid {
				return NewParseResult(Fail, start)
			}
			pos += size
		}
	}
	if !ctx.IsPartial {
		return NewParseResultRange(Fail, start, pos)
	}
	return NewParseResultRange(NeedMoreInput, start, pos)
}
// PythonDictStringParser matches single-quoted string content (without quotes).
// Like JSONStringParser but terminates on single quote instead of double quote.
type PythonDictStringParser struct{}

// parse scans up to — but not including — the terminating '\”, validating
// backslash escapes and UTF-8; incomplete input waits in partial mode.
func (p *PythonDictStringParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	pos := start
	for pos < len(ctx.Input) {
		c := ctx.Input[pos]
		if c == '\'' {
			// Stop before the quote: the caller consumes the delimiters.
			return NewParseResultRange(Success, start, pos)
		}
		if c == '\\' {
			result := handleEscapeSequence(ctx, start, pos)
			if result.Type != Success {
				return result
			}
			pos = result.End
		} else {
			_, size, status := parseUTF8Codepoint(ctx.Input, pos)
			if status == utf8Incomplete {
				if !ctx.IsPartial {
					return NewParseResult(Fail, start)
				}
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			if status == utf8Invalid {
				return NewParseResult(Fail, start)
			}
			pos += size
		}
	}
	if !ctx.IsPartial {
		return NewParseResultRange(Fail, start, pos)
	}
	return NewParseResultRange(NeedMoreInput, start, pos)
}
// handleEscapeSequence validates one backslash escape starting at pos (the
// '\' itself). Simple escapes consume two bytes; \u delegates to
// handleUnicodeEscape. An escape cut off by the end of input waits in
// partial mode; an unknown escape character fails.
func handleEscapeSequence(ctx *ParseContext, start int, pos int) ParseResult {
	next := pos + 1 // position of the character following '\'
	if next >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, next)
		}
		return NewParseResult(Fail, start)
	}
	c := ctx.Input[next]
	if c == 'u' {
		return handleUnicodeEscape(ctx, start, next)
	}
	switch c {
	case '"', '\'', '\\', '/', 'b', 'f', 'n', 'r', 't':
		return NewParseResultRange(Success, start, next+1)
	}
	return NewParseResult(Fail, start)
}
// handleUnicodeEscape validates the four hex digits of a \uXXXX escape; pos
// points at the 'u'. Digits cut off by end of input wait in partial mode;
// a non-hex character fails.
func handleUnicodeEscape(ctx *ParseContext, start int, pos int) ParseResult {
	digit := pos + 1 // first hex digit after 'u'
	end := digit + 4
	for ; digit < end; digit++ {
		if digit >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, digit)
			}
			return NewParseResult(Fail, start)
		}
		if !isHexDigit(ctx.Input[digit]) {
			return NewParseResult(Fail, start)
		}
	}
	return NewParseResultRange(Success, start, end)
}
// isHexDigit reports whether c is an ASCII hexadecimal digit (0-9, a-f, A-F).
func isHexDigit(c byte) bool {
	switch {
	case c >= '0' && c <= '9':
		return true
	case c >= 'a' && c <= 'f':
		return true
	case c >= 'A' && c <= 'F':
		return true
	}
	return false
}
// UntilParser matches everything until one of the delimiters is found.
type UntilParser struct {
	Delimiters []string
}

// parse consumes codepoints until a delimiter begins at the current
// position, succeeding with the text before it; a trie over all delimiters
// is checked at every byte offset.
// NOTE(review): trieCompleteMatch and triePartialMatch are handled by two
// separate branches returning the same result — presumably a partial
// delimiter prefix is also treated as a stop point so content is never
// consumed into a delimiter that may complete with more input; confirm the
// branches are intentionally identical.
func (p *UntilParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	matcher := newTrie(p.Delimiters)
	pos := start
	lastValidPos := start
	for pos < len(ctx.Input) {
		_, size, status := parseUTF8Codepoint(ctx.Input, pos)
		if status == utf8Incomplete {
			// Truncated codepoint: wait (partial) or fail, never consume it.
			if !ctx.IsPartial {
				return NewParseResult(Fail, start)
			}
			return NewParseResultRange(NeedMoreInput, start, lastValidPos)
		}
		if status == utf8Invalid {
			return NewParseResult(Fail, start)
		}
		match := matcher.checkAt(ctx.Input, pos)
		if match == trieCompleteMatch {
			return NewParseResultRange(Success, start, pos)
		}
		if match == triePartialMatch {
			return NewParseResultRange(Success, start, pos)
		}
		pos += size
		lastValidPos = pos
	}
	if lastValidPos == len(ctx.Input) && ctx.IsPartial {
		return NewParseResultRange(NeedMoreInput, start, lastValidPos)
	}
	return NewParseResultRange(Success, start, lastValidPos)
}
// RuleParser creates an AST node with a rule name.
type RuleParser struct {
	Name    string
	Child   ParserID
	Trigger bool
}

// parse wraps the child's match in a single named AST node that carries the
// matched text. Both Success and NeedMoreInput produce a node (the latter
// flagged as partial); Fail is propagated untouched.
func (p *RuleParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	result := arena.ParseAt(p.Child, ctx, start)
	if result.Type != Fail {
		text := ""
		if result.Start < len(ctx.Input) {
			// Clamp: a NeedMoreInput result may report End past the input.
			end := result.End
			if end > len(ctx.Input) {
				end = len(ctx.Input)
			}
			text = ctx.Input[result.Start:end]
		}
		nodeID := ctx.Ast.AddNode(
			p.Name, "", result.Start, result.End, text,
			result.Nodes, result.Type == NeedMoreInput,
		)
		return NewParseResultNodes(result.Type, result.Start, result.End, []AstID{nodeID})
	}
	return result
}
// RefParser references a named rule (resolved during Build).
type RefParser struct {
	Name string
}

// parse looks the rule up by name in the arena and delegates to it.
func (p *RefParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	return arena.ParseAt(arena.GetRule(p.Name), ctx, start)
}
// AtomicParser suppresses partial AST nodes.
type AtomicParser struct {
	Child ParserID
}

// parse delegates to the child; on NeedMoreInput it strips the accumulated
// AST nodes so incomplete matches never leak partial structure downstream.
func (p *AtomicParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	r := arena.ParseAt(p.Child, ctx, start)
	if r.Type != NeedMoreInput {
		return r
	}
	r.Nodes = nil
	return r
}
// TagParser creates an AST node with a semantic tag.
type TagParser struct {
	Child ParserID
	Tag   string
}

// parse mirrors RuleParser but labels the node with a semantic tag instead
// of a rule name. Success and NeedMoreInput both produce a node (the latter
// flagged as partial); Fail passes through unchanged.
func (p *TagParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	result := arena.ParseAt(p.Child, ctx, start)
	if result.Type != Fail {
		text := ""
		if result.Start < len(ctx.Input) {
			// Clamp: a NeedMoreInput result may report End past the input.
			end := result.End
			if end > len(ctx.Input) {
				end = len(ctx.Input)
			}
			text = ctx.Input[result.Start:end]
		}
		nodeID := ctx.Ast.AddNode(
			"", p.Tag, result.Start, result.End, text,
			result.Nodes, result.Type == NeedMoreInput,
		)
		return NewParseResultNodes(result.Type, result.Start, result.End, []AstID{nodeID})
	}
	return result
}
// SchemaParser wraps a parser with schema metadata (pass-through at parse time).
type SchemaParser struct {
	Child ParserID
	Name  string
}

// parse is a pure pass-through; the schema name is only consulted elsewhere
// (not during parsing).
func (p *SchemaParser) parse(arena *Arena, ctx *ParseContext, start int) ParseResult {
	return arena.ParseAt(p.Child, ctx, start)
}
// JSONParser matches a complete JSON value (object, array, string, number, bool, null).
type JSONParser struct {
	// arena is not read by parse; presumably retained for builder wiring —
	// TODO confirm it is actually needed.
	arena *Arena
}

// parse delegates to the hand-rolled recursive JSON scanner below.
func (p *JSONParser) parse(_ *Arena, ctx *ParseContext, start int) ParseResult {
	return parseJSONValue(ctx, start, start)
}
// isWhitespace reports whether c is JSON insignificant whitespace
// (space, tab, newline, carriage return).
func isWhitespace(c byte) bool {
	switch c {
	case ' ', '\t', '\n', '\r':
		return true
	}
	return false
}
// parseLiteralAt matches lit byte-for-byte at pos, reporting the result
// anchored at start. Input that runs out mid-literal yields NeedMoreInput in
// partial mode and Fail otherwise; any mismatch fails.
func parseLiteralAt(ctx *ParseContext, start, pos int, lit string) ParseResult {
	for i, want := range []byte(lit) {
		at := pos + i
		if at >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, at)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[at] != want {
			return NewParseResult(Fail, start)
		}
	}
	return NewParseResultRange(Success, start, pos+len(lit))
}
// parseJSONString scans a complete double-quoted JSON string; pos must point
// at the opening '"'. The returned range includes both quotes. Escape
// sequences and UTF-8 are validated; input ending before the closing quote
// yields NeedMoreInput in partial mode and Fail otherwise.
func parseJSONString(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip opening "
	for pos < len(ctx.Input) {
		c := ctx.Input[pos]
		if c == '"' {
			return NewParseResultRange(Success, start, pos+1)
		}
		if c == '\\' {
			pos++
			if pos >= len(ctx.Input) {
				if ctx.IsPartial {
					return NewParseResultRange(NeedMoreInput, start, pos)
				}
				return NewParseResult(Fail, start)
			}
			switch ctx.Input[pos] {
			case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
				pos++
			case 'u':
				pos++
				// \uXXXX requires exactly four hex digits.
				for i := 0; i < 4; i++ {
					if pos >= len(ctx.Input) {
						if ctx.IsPartial {
							return NewParseResultRange(NeedMoreInput, start, pos)
						}
						return NewParseResult(Fail, start)
					}
					if !isHexDigit(ctx.Input[pos]) {
						return NewParseResult(Fail, start)
					}
					pos++
				}
			default:
				// Unknown escape character.
				return NewParseResult(Fail, start)
			}
		} else {
			_, size, status := parseUTF8Codepoint(ctx.Input, pos)
			if status == utf8Incomplete {
				if !ctx.IsPartial {
					return NewParseResult(Fail, start)
				}
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			if status == utf8Invalid {
				return NewParseResult(Fail, start)
			}
			pos += size
		}
	}
	if ctx.IsPartial {
		return NewParseResultRange(NeedMoreInput, start, pos)
	}
	return NewParseResult(Fail, start)
}
// parseJSONNumber scans a JSON number (optional '-', integer part, optional
// fraction and exponent) starting at pos, anchored at start. In partial mode
// a number that reaches the end of input — or whose next byte could extend
// it — reports NeedMoreInput rather than committing early.
func parseJSONNumber(ctx *ParseContext, start, pos int) ParseResult {
	p := pos
	// Optional leading minus.
	if p < len(ctx.Input) && ctx.Input[p] == '-' {
		p++
	}
	if p >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, p)
		}
		return NewParseResult(Fail, start)
	}
	// Integer part: either a lone '0' or [1-9][0-9]* (no leading zeros).
	if ctx.Input[p] == '0' {
		p++
	} else if ctx.Input[p] >= '1' && ctx.Input[p] <= '9' {
		p++
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
	} else {
		return NewParseResult(Fail, start)
	}
	// Optional fraction: '.' must be followed by at least one digit.
	if p < len(ctx.Input) && ctx.Input[p] == '.' {
		p++
		digitStart := p
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
		if p == digitStart {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, p)
			}
			return NewParseResult(Fail, start)
		}
	}
	// Optional exponent: e/E, optional sign, at least one digit.
	if p < len(ctx.Input) && (ctx.Input[p] == 'e' || ctx.Input[p] == 'E') {
		p++
		if p < len(ctx.Input) && (ctx.Input[p] == '+' || ctx.Input[p] == '-') {
			p++
		}
		digitStart := p
		for p < len(ctx.Input) && ctx.Input[p] >= '0' && ctx.Input[p] <= '9' {
			p++
		}
		if p == digitStart {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, p)
			}
			return NewParseResult(Fail, start)
		}
	}
	// In partial mode, check if the next character could continue the number.
	// This prevents premature commits (e.g. returning "3" when "3.14" is incoming).
	if ctx.IsPartial && p >= len(ctx.Input) {
		return NewParseResultRange(NeedMoreInput, start, p)
	}
	if ctx.IsPartial && p < len(ctx.Input) && isNumberContinuation(ctx.Input[p]) {
		return NewParseResultRange(NeedMoreInput, start, p)
	}
	return NewParseResultRange(Success, start, p)
}
// isNumberContinuation reports whether c could extend a JSON number
// (digit, decimal point, exponent marker, or sign).
func isNumberContinuation(c byte) bool {
	switch {
	case c >= '0' && c <= '9':
		return true
	case c == '.' || c == '+' || c == '-':
		return true
	case c == 'e' || c == 'E':
		return true
	}
	return false
}
// parseJSONObject scans a complete JSON object; pos points at '{'. Keys must
// be strings, values are parsed recursively, and members are comma-separated
// with no trailing comma. Input ending anywhere inside the object yields
// NeedMoreInput in partial mode and Fail otherwise.
func parseJSONObject(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip {
	pos = skipWS(ctx.Input, pos)
	if pos >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResult(Fail, start)
	}
	// Empty object.
	if ctx.Input[pos] == '}' {
		return NewParseResultRange(Success, start, pos+1)
	}
	for {
		pos = skipWS(ctx.Input, pos)
		// key
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] != '"' {
			return NewParseResult(Fail, start)
		}
		r := parseJSONString(ctx, start, pos)
		if r.Type != Success {
			// Propagate NeedMoreInput/Fail from the key scan.
			return r
		}
		pos = r.End
		pos = skipWS(ctx.Input, pos)
		// colon
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] != ':' {
			return NewParseResult(Fail, start)
		}
		pos++
		pos = skipWS(ctx.Input, pos)
		// value
		vr := parseJSONValue(ctx, start, pos)
		if vr.Type != Success {
			return vr
		}
		pos = vr.End
		pos = skipWS(ctx.Input, pos)
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] == '}' {
			return NewParseResultRange(Success, start, pos+1)
		}
		if ctx.Input[pos] != ',' {
			return NewParseResult(Fail, start)
		}
		pos++
	}
}
// parseJSONArray scans a complete JSON array; pos points at '['. Elements
// are parsed recursively and comma-separated with no trailing comma. Input
// ending inside the array yields NeedMoreInput in partial mode, Fail
// otherwise.
func parseJSONArray(ctx *ParseContext, start, pos int) ParseResult {
	pos++ // skip [
	pos = skipWS(ctx.Input, pos)
	if pos >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResult(Fail, start)
	}
	// Empty array.
	if ctx.Input[pos] == ']' {
		return NewParseResultRange(Success, start, pos+1)
	}
	for {
		pos = skipWS(ctx.Input, pos)
		vr := parseJSONValue(ctx, start, pos)
		if vr.Type != Success {
			// Propagate NeedMoreInput/Fail from the element scan.
			return vr
		}
		pos = vr.End
		pos = skipWS(ctx.Input, pos)
		if pos >= len(ctx.Input) {
			if ctx.IsPartial {
				return NewParseResultRange(NeedMoreInput, start, pos)
			}
			return NewParseResult(Fail, start)
		}
		if ctx.Input[pos] == ']' {
			return NewParseResultRange(Success, start, pos+1)
		}
		if ctx.Input[pos] != ',' {
			return NewParseResult(Fail, start)
		}
		pos++
	}
}
// parseJSONValue dispatches on the first non-whitespace byte to the
// appropriate scanner for a JSON object, array, string, literal
// (true/false/null), or number. Exhausted input yields NeedMoreInput in
// partial mode and Fail otherwise.
func parseJSONValue(ctx *ParseContext, start, pos int) ParseResult {
	pos = skipWS(ctx.Input, pos)
	if pos >= len(ctx.Input) {
		if ctx.IsPartial {
			return NewParseResultRange(NeedMoreInput, start, pos)
		}
		return NewParseResult(Fail, start)
	}
	switch ctx.Input[pos] {
	case '{':
		return parseJSONObject(ctx, start, pos)
	case '[':
		return parseJSONArray(ctx, start, pos)
	case '"':
		return parseJSONString(ctx, start, pos)
	case 't':
		return parseLiteralAt(ctx, start, pos, "true")
	case 'f':
		return parseLiteralAt(ctx, start, pos, "false")
	case 'n':
		return parseLiteralAt(ctx, start, pos, "null")
	default:
		// Numbers begin with '-' or a digit; anything else is invalid JSON.
		if ctx.Input[pos] == '-' || (ctx.Input[pos] >= '0' && ctx.Input[pos] <= '9') {
			return parseJSONNumber(ctx, start, pos)
		}
		return NewParseResult(Fail, start)
	}
}
// skipWS returns the offset of the first non-whitespace byte at or after
// pos (or len(input) when only whitespace remains).
func skipWS(input string, pos int) int {
	for pos < len(input) {
		if !isWhitespace(input[pos]) {
			return pos
		}
		pos++
	}
	return pos
}

View File

@@ -0,0 +1,777 @@
package peg_test
import (
"github.com/mudler/LocalAI/pkg/functions/peg"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// extractTags walks the parse result's AST and collects a tag -> matched-text
// map from every node that carries a semantic tag.
func extractTags(ast *peg.AstArena, result *peg.ParseResult) map[string]string {
	collected := map[string]string{}
	ast.VisitResult(result, func(node *peg.AstNode) {
		if node.Tag == "" {
			return
		}
		collected[node.Tag] = node.Text
	})
	return collected
}
// Spec suite for the PEG parser combinators: one Context per combinator,
// each exercising success, failure and (where relevant) partial-input paths.
var _ = Describe("PEG Parser", func() {
	// Exact-string matching, including the partial-input (streaming) mode.
	Context("LiteralParser", func() {
		It("succeeds on exact match", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Literal("hello")
			})
			ctx := peg.NewParseContext("hello world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.Start).To(Equal(0))
			Expect(r.End).To(Equal(5))
		})
		It("fails on mismatch", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Literal("hello")
			})
			ctx := peg.NewParseContext("world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
		It("returns NeedMoreInput in partial mode", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Literal("hello")
			})
			ctx := peg.NewParseContext("hel", true)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.NeedMoreInput))
		})
		It("fails on partial input when not in partial mode", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Literal("hello")
			})
			ctx := peg.NewParseContext("hel", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
	})
	// Ordered concatenation of sub-parsers.
	Context("SequenceParser", func() {
		It("matches full sequence", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("hello"), b.Literal(" "), b.Literal("world"))
			})
			ctx := peg.NewParseContext("hello world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(11))
		})
		It("fails midway", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("hello"), b.Literal("X"))
			})
			ctx := peg.NewParseContext("hello world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
		It("returns NeedMoreInput at boundary in partial mode", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("hello"), b.Literal(" world"))
			})
			ctx := peg.NewParseContext("hello wo", true)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.NeedMoreInput))
		})
	})
	// Ordered-choice semantics: first matching alternative wins.
	Context("ChoiceParser", func() {
		It("matches first alternative", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Choice(b.Literal("hello"), b.Literal("world"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("matches second alternative", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Choice(b.Literal("hello"), b.Literal("world"))
			})
			ctx := peg.NewParseContext("world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("fails when all alternatives fail", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Choice(b.Literal("hello"), b.Literal("world"))
			})
			ctx := peg.NewParseContext("foo", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
	})
	// ZeroOrMore / OneOrMore / Optional / bounded Repeat.
	Context("RepetitionParser", func() {
		It("handles zero or more matches", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.ZeroOrMore(b.Literal("ab"))
			})
			ctx := peg.NewParseContext("ababab", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(6))
			// ZeroOrMore succeeds with zero consumption on non-matching input.
			ctx2 := peg.NewParseContext("xyz", false)
			r2 := arena.Parse(ctx2)
			Expect(r2.Type).To(Equal(peg.Success))
			Expect(r2.End).To(Equal(0))
		})
		It("handles one or more matches", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.OneOrMore(b.Literal("ab"))
			})
			ctx := peg.NewParseContext("ababab", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(6))
			ctx2 := peg.NewParseContext("xyz", false)
			r2 := arena.Parse(ctx2)
			Expect(r2.Type).To(Equal(peg.Fail))
		})
		It("handles optional matches", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Optional(b.Literal("hello")), b.Literal("world"))
			})
			ctx := peg.NewParseContext("helloworld", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(10))
			ctx2 := peg.NewParseContext("world", false)
			r2 := arena.Parse(ctx2)
			Expect(r2.Type).To(Equal(peg.Success))
			Expect(r2.End).To(Equal(5))
		})
		It("respects bounded repetition", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Repeat(b.Literal("a"), 2, 4)
			})
			// Max bound: consumes only 4 of 5 available "a"s.
			ctx := peg.NewParseContext("aaaaa", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(4))
			// Min bound: one "a" is below the minimum of 2.
			ctx2 := peg.NewParseContext("a", false)
			r2 := arena.Parse(ctx2)
			Expect(r2.Type).To(Equal(peg.Fail))
		})
	})
	// Peek (&) and Negate (!) consume no input.
	Context("Lookahead", func() {
		It("succeeds with positive lookahead", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Peek(b.Literal("hello")), b.Literal("hello"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("fails with positive lookahead mismatch", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Peek(b.Literal("world")), b.Literal("hello"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
		It("succeeds with negative lookahead", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Negate(b.Literal("world")), b.Literal("hello"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("fails with negative lookahead match", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Negate(b.Literal("hello")), b.Literal("hello"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Fail))
		})
	})
	// Consume-until-delimiter parsers (Until / UntilOneOf / Rest).
	Context("UntilParser", func() {
		It("consumes until single delimiter", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Until("<end>"), b.Literal("<end>"))
			})
			ctx := peg.NewParseContext("content<end>", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(12))
		})
		It("consumes until first of multiple delimiters", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.UntilOneOf("<a>", "<b>")
			})
			ctx := peg.NewParseContext("content<b>more", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(7))
		})
		It("consumes rest of input", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Rest()
			})
			ctx := peg.NewParseContext("everything", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(10))
		})
		It("returns NeedMoreInput in partial mode", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Until("<end>")
			})
			ctx := peg.NewParseContext("content", true)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.NeedMoreInput))
		})
	})
	// Full JSON grammar: objects, arrays, strings, numbers, literals.
	Context("JSONParser", func() {
		It("parses objects", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`{"key": "value", "num": 42}`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(27))
		})
		It("parses arrays", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`[1, "two", true, null]`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(22))
		})
		It("parses strings with escapes", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`"hello \"world\""`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(17))
		})
		It("parses numbers", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`-123.45e10`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(10))
		})
		It("parses booleans", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`true`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(4))
		})
		It("parses null", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			ctx := peg.NewParseContext(`null`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(4))
		})
		It("parses nested structures", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.JSON()
			})
			input := `{"a": [1, {"b": true}], "c": null}`
			ctx := peg.NewParseContext(input, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(len(input)))
		})
	})
	// Tagged nodes surface matched text in the AST (see extractTags helper).
	Context("Tag extraction", func() {
		It("extracts basic tags", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Tag("greeting", b.Until(" ")),
					b.Literal(" "),
					b.Tag("name", b.Rest()),
				)
			})
			ctx := peg.NewParseContext("Hello World", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["greeting"]).To(Equal("Hello"))
			Expect(tags["name"]).To(Equal("World"))
		})
		It("extracts structured tags", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Tag("header", b.Until("\n")),
					b.Literal("\n"),
					b.Tag("body", b.Rest()),
				)
			})
			ctx := peg.NewParseContext("Title\nBody content here", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["header"]).To(Equal("Title"))
			Expect(tags["body"]).To(Equal("Body content here"))
		})
		It("overwrites duplicate tags", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Tag("item", b.Until(",")),
					b.Literal(","),
					b.Tag("item", b.Rest()),
				)
			})
			ctx := peg.NewParseContext("first,second", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["item"]).To(Equal("second"))
		})
		It("returns empty map when no tags", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Rest()
			})
			ctx := peg.NewParseContext("Hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags).To(HaveLen(0))
		})
	})
	// Named rules and (forward) references between them.
	Context("Rule and Ref", func() {
		It("handles named rules", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				word := b.Rule("word", b.Chars("[a-z]", 1, -1))
				return b.Seq(word, b.Literal(" "), word)
			})
			ctx := peg.NewParseContext("hello world", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(11))
		})
		It("handles forward references", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				// Ref before Rule: resolution happens at parse time.
				ref := b.Ref("greeting")
				b.Rule("greeting", b.Literal("hello"))
				return ref
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
	})
	// Atomic prevents partially-matched children from leaking AST nodes.
	Context("AtomicParser", func() {
		It("suppresses partial AST nodes", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Atomic(b.Tag("test", b.Literal("hello world")))
			})
			ctx := peg.NewParseContext("hello", true)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.NeedMoreInput))
			Expect(r.Nodes).To(HaveLen(0))
		})
	})
	// Input-boundary anchors.
	Context("Start and End parsers", func() {
		It("matches at start of input", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Start(), b.Literal("hello"))
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("matches at end of input", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("hello"), b.End())
			})
			ctx := peg.NewParseContext("hello", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			ctx2 := peg.NewParseContext("hello world", false)
			r2 := arena.Parse(ctx2)
			Expect(r2.Type).To(Equal(peg.Fail))
		})
	})
	// Tags are still produced for the consumed prefix in partial mode.
	Context("Partial parsing", func() {
		It("extracts tags during partial parse", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Tag("prefix", b.Until(":")),
					b.Literal(":"),
					b.Tag("value", b.Rest()),
				)
			})
			ctx := peg.NewParseContext("key:val", true)
			r := arena.Parse(ctx)
			Expect(r.Type).NotTo(Equal(peg.Fail))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["prefix"]).To(Equal("key"))
		})
	})
	// Scanning for a pattern at arbitrary offsets via ParseFrom.
	Context("ParseAnywhere", func() {
		It("finds pattern in middle of input", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Choice(b.Literal("{"), b.Literal(":")),
					b.Space(),
					b.Literal("\""),
					b.Atomic(b.Literal("fun_name")),
				)
			})
			input := `This is a very long jinja template string... <tool_call>{ "fun_name" : { "arg" : 1 }</tool_call>`
			found := false
			for i := 0; i < len(input); i++ {
				ctx := peg.NewParseContext(input, false)
				r := arena.ParseFrom(ctx, i)
				if r.Type == peg.Success {
					found = true
					break
				}
			}
			Expect(found).To(BeTrue())
		})
		It("fails when pattern is not found", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(
					b.Choice(b.Literal("{"), b.Literal(":")),
					b.Space(),
					b.Literal("\""),
					b.Atomic(b.Literal("fun_name")),
				)
			})
			input := `This is a very long jinja template string... <tool_call><fun=fun_name><arg name=arg>1</arg></tool_call>`
			found := false
			for i := 0; i < len(input); i++ {
				ctx := peg.NewParseContext(input, false)
				r := arena.ParseFrom(ctx, i)
				if r.Type == peg.Success {
					found = true
					break
				}
			}
			Expect(found).To(BeFalse())
		})
	})
	// Regex-like character classes with min/max counts (-1 = unbounded).
	Context("CharsParser", func() {
		It("matches lowercase letters", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Chars("[a-z]", 1, -1)
			})
			ctx := peg.NewParseContext("hello123", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("matches negated character class", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Chars("[^0-9]", 1, -1)
			})
			ctx := peg.NewParseContext("hello123", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
	})
	// JSONString matches the *inner* content; quotes are matched separately.
	Context("JSONStringParser", func() {
		It("parses basic strings", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
			})
			ctx := peg.NewParseContext(`"hello world"`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(13))
		})
		It("parses strings with escapes", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
			})
			ctx := peg.NewParseContext(`"hello \"world\""`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(17))
		})
		It("parses strings with unicode escapes", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
			})
			ctx := peg.NewParseContext(`"hello \u0041"`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(14))
		})
	})
	// Space matches zero-or-more whitespace (never fails on absence).
	Context("SpaceParser", func() {
		It("matches whitespace", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("a"), b.Space(), b.Literal("b"))
			})
			ctx := peg.NewParseContext("a  b", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(5))
		})
		It("matches zero whitespace", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("a"), b.Space(), b.Literal("b"))
			})
			ctx := peg.NewParseContext("ab", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(2))
		})
	})
	// Single-quoted Python-style string contents.
	Context("PythonDictStringParser", func() {
		It("parses basic single-quoted strings", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
			})
			ctx := peg.NewParseContext("'hello world'", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			Expect(r.End).To(Equal(13))
		})
		It("handles escaped single quotes", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
			})
			ctx := peg.NewParseContext(`'it\'s fine'`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("handles double quotes inside single-quoted strings", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("'"), b.PythonDictString(), b.Literal("'"))
			})
			ctx := peg.NewParseContext(`'He said "hi"'`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
	})
	// Escape decoding used inside character-class specs.
	Context("peg.ParseCharClassChar", func() {
		It("parses \\x hex escape", func() {
			r, n := peg.ParseCharClassChar(`\x41`, 0)
			Expect(r).To(Equal('A'))
			Expect(n).To(Equal(4))
		})
		It("parses \\u unicode escape", func() {
			r, n := peg.ParseCharClassChar(`\u0041`, 0)
			Expect(r).To(Equal('A'))
			Expect(n).To(Equal(6))
		})
		It("parses \\U unicode escape", func() {
			r, n := peg.ParseCharClassChar(`\U00000041`, 0)
			Expect(r).To(Equal('A'))
			Expect(n).To(Equal(10))
		})
		It("falls back on invalid hex", func() {
			// Invalid hex digits: the 'x' is returned literally, 2 bytes consumed.
			r, n := peg.ParseCharClassChar(`\xZZ`, 0)
			Expect(r).To(Equal('x'))
			Expect(n).To(Equal(2))
		})
	})
	// Convenience wrapper that scans the whole input for the first match.
	Context("ParseAnywhere method", func() {
		It("finds pattern in middle of input", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Seq(b.Literal("needle"), b.Tag("after", b.Until(".")))
			})
			ctx := peg.NewParseContext("some hay needle found.", false)
			r := arena.ParseAnywhere(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["after"]).To(Equal(" found"))
		})
		It("finds function tag with name", func() {
			haystack := "\n<tool_call>\n<function=foofoo>\n<parameter=first>\nXXXX\n</parameter>\n<parameter=second>\nYYYY\n</parameter>\n</function>\n</tool_call>\n"
			needle := "foofoo"
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Tag("fun_marker", b.Choice(
					b.Seq(
						b.Tag("fun_pre", b.Seq(b.Literal("<"), b.UntilOneOf(">", needle))),
						b.Literal(needle),
						b.Tag("fun_post", b.Seq(
							b.Seq(b.Negate(b.Seq(b.Space(), b.Literal("<"))), b.Until(">"), b.Literal(">")),
						)),
						b.Space(),
					),
					b.Seq(
						b.Tag("fun_pre", b.Seq(b.Literal("["), b.UntilOneOf("]", needle))),
						b.Literal(needle),
						b.Tag("fun_post", b.Seq(
							b.Negate(b.Seq(b.Space(), b.Seq(b.Literal("["), b.Until("]"), b.Literal("]")))),
							b.Space(),
						)),
					),
				))
			})
			ctx := peg.NewParseContext(haystack, false)
			r := arena.ParseAnywhere(ctx)
			Expect(r.Type).To(Equal(peg.Success))
			tags := extractTags(&ctx.Ast, &r)
			Expect(tags["fun_pre"]).To(Equal("<function="))
			Expect(tags["fun_post"]).To(Equal(">"))
		})
	})
	// Lazily-constructed rules allow self-recursive grammars.
	Context("LazyRule", func() {
		It("handles recursive JSON-like structures", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				b.LazyRule("value", func() peg.ParserID {
					str := b.Seq(b.Literal("\""), b.JSONString(), b.Literal("\""))
					arr := b.Seq(
						b.Literal("["), b.Space(),
						b.Ref("value"),
						b.ZeroOrMore(b.Seq(b.Space(), b.Literal(","), b.Space(), b.Ref("value"))),
						b.Space(), b.Literal("]"),
					)
					return b.Choice(str, arr)
				})
				return b.Ref("value")
			})
			ctx := peg.NewParseContext(`["hello",["world","nested"]]`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("parses python dicts", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.PythonDict()
			})
			ctx := peg.NewParseContext(`{'key': 'value', 'num': 42}`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("parses nested python values", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.PythonValue()
			})
			ctx := peg.NewParseContext(`{'outer': {'inner': [1, 2, 'three']}}`, false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("parses python booleans and None", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.PythonValue()
			})
			for _, input := range []string{"True", "False", "None"} {
				ctx := peg.NewParseContext(input, false)
				r := arena.Parse(ctx)
				Expect(r.Type).To(Equal(peg.Success))
			}
		})
	})
	// Marker matches bracketed tool-call-style delimiters.
	Context("Marker", func() {
		It("matches angle brackets", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Marker()
			})
			ctx := peg.NewParseContext("<tool_call>", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
		It("matches square brackets", func() {
			arena := peg.BuildPegParser(func(b *peg.Builder) peg.ParserID {
				return b.Marker()
			})
			ctx := peg.NewParseContext("[TOOL]", false)
			r := arena.Parse(ctx)
			Expect(r.Type).To(Equal(peg.Success))
		})
	})
})

View File

@@ -0,0 +1,13 @@
package peg_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// TestPeg is the standard `go test` entry point: it installs Gomega's
// failure handler and hands control to the Ginkgo spec runner.
func TestPeg(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "PEG Parser test suite")
}

80
pkg/functions/peg/trie.go Normal file
View File

@@ -0,0 +1,80 @@
package peg
// trie is used for multi-delimiter matching in UntilParser.
// trie is used for multi-delimiter matching in UntilParser.
// Nodes live in a flat slice; index 0 is always the root.
type trie struct {
	nodes []trieNode
}

// trieNode maps the next rune to a child node index; isWord marks the
// end of a complete delimiter.
type trieNode struct {
	children map[rune]int
	isWord   bool
}

// trieMatch classifies how much of a delimiter matched at a position.
type trieMatch int

const (
	trieNoMatch       trieMatch = 0 // no delimiter starts here
	triePartialMatch  trieMatch = 1 // a delimiter prefix runs to end of input
	trieCompleteMatch trieMatch = 2 // a full delimiter matched
)

// newTrie builds a trie containing every word in words.
func newTrie(words []string) *trie {
	t := &trie{}
	t.createNode() // index 0: root
	for _, word := range words {
		t.insert(word)
	}
	return t
}

// createNode appends a fresh empty node and returns its index.
func (t *trie) createNode() int {
	t.nodes = append(t.nodes, trieNode{children: map[rune]int{}})
	return len(t.nodes) - 1
}

// insert adds word to the trie, creating intermediate nodes as needed.
func (t *trie) insert(word string) {
	node := 0
	for _, r := range word {
		next, ok := t.nodes[node].children[r]
		if !ok {
			next = t.createNode()
			t.nodes[node].children[r] = next
		}
		node = next
	}
	t.nodes[node].isWord = true
}
// checkAt reports whether any delimiter starts at byte offset pos of input:
// trieCompleteMatch when a whole delimiter is present, triePartialMatch when
// the input (or a decodable prefix of it) ends while still inside the trie,
// and trieNoMatch otherwise. Matching stops at the first complete word, so
// the shortest matching delimiter wins.
func (t *trie) checkAt(input string, pos int) trieMatch {
	node := 0
	for off := pos; off < len(input); {
		r, size, status := parseUTF8Codepoint(input, off)
		if status != utf8Success {
			// Undecodable tail (e.g. truncated sequence): treat like
			// end-of-input and fall through to the partial check below.
			break
		}
		next, ok := t.nodes[node].children[r]
		if !ok {
			return trieNoMatch
		}
		node = next
		off += size
		if t.nodes[node].isWord {
			return trieCompleteMatch
		}
	}
	// We ran out of input while still walking a delimiter.
	if node != 0 {
		return triePartialMatch
	}
	return trieNoMatch
}

175
pkg/functions/peg/types.go Normal file
View File

@@ -0,0 +1,175 @@
package peg
import "unicode/utf8"
// ParserID is a unique identifier for a parser in the arena.
type ParserID int

// InvalidParserID marks the absence of a parser.
const InvalidParserID ParserID = -1

// AstID is a unique identifier for an AST node.
type AstID int

// InvalidAstID marks the absence of an AST node.
const InvalidAstID AstID = -1

// ParseResultType indicates the outcome of a parse attempt.
type ParseResultType int

const (
	Fail          ParseResultType = 0 // input cannot match
	Success       ParseResultType = 1 // input matched
	NeedMoreInput ParseResultType = 2 // partial input; more bytes might match
)

// String returns a stable lowercase label for the result type.
func (t ParseResultType) String() string {
	switch t {
	case Success:
		return "success"
	case NeedMoreInput:
		return "need_more_input"
	case Fail:
		return "fail"
	}
	return "unknown"
}

// ParseResult holds the result of a parse operation: the outcome type,
// the matched byte range [Start, End), and any top-level AST nodes.
type ParseResult struct {
	Type  ParseResultType
	Start int
	End   int
	Nodes []AstID
}

// NewParseResult builds a zero-width result anchored at start.
func NewParseResult(typ ParseResultType, start int) ParseResult {
	return NewParseResultRange(typ, start, start)
}

// NewParseResultRange builds a result spanning [start, end) with no nodes.
func NewParseResultRange(typ ParseResultType, start, end int) ParseResult {
	return NewParseResultNodes(typ, start, end, nil)
}

// NewParseResultNodes builds a result spanning [start, end) carrying nodes.
func NewParseResultNodes(typ ParseResultType, start, end int, nodes []AstID) ParseResult {
	return ParseResult{Type: typ, Start: start, End: end, Nodes: nodes}
}

// AstNode is a node in the parse AST. Text is the matched slice of input;
// IsPartial marks nodes produced while parsing incomplete (streaming) input.
type AstNode struct {
	ID       AstID
	Rule     string
	Tag      string
	Start    int
	End      int
	Text     string
	Children []AstID
	IsPartial bool
}

// AstArena stores AST nodes in a flat slice, addressed by AstID.
type AstArena struct {
	nodes []AstNode
}

// AddNode appends a node and returns its newly assigned ID.
func (arena *AstArena) AddNode(rule, tag string, start, end int, text string, children []AstID, isPartial bool) AstID {
	node := AstNode{
		ID:        AstID(len(arena.nodes)),
		Rule:      rule,
		Tag:       tag,
		Start:     start,
		End:       end,
		Text:      text,
		Children:  children,
		IsPartial: isPartial,
	}
	arena.nodes = append(arena.nodes, node)
	return node.ID
}

// Get returns a pointer to the node with the given ID.
// The pointer is only valid until the next AddNode (slice may reallocate).
func (arena *AstArena) Get(id AstID) *AstNode {
	return &arena.nodes[id]
}

// Size returns the number of stored nodes.
func (arena *AstArena) Size() int {
	return len(arena.nodes)
}

// Clear discards all nodes while keeping the backing storage.
func (arena *AstArena) Clear() {
	arena.nodes = arena.nodes[:0]
}

// Visit traverses the AST tree rooted at id in preorder, calling visit
// for each node. A no-op for InvalidAstID.
func (arena *AstArena) Visit(id AstID, visit func(*AstNode)) {
	if id == InvalidAstID {
		return
	}
	node := arena.Get(id)
	visit(node)
	for _, child := range node.Children {
		arena.Visit(child, visit)
	}
}

// VisitResult traverses every top-level node in a parse result (preorder).
func (arena *AstArena) VisitResult(result *ParseResult, visit func(*AstNode)) {
	for _, id := range result.Nodes {
		arena.Visit(id, visit)
	}
}

// ParseContext holds the mutable state for one parse operation over Input.
// IsPartial selects streaming semantics (NeedMoreInput instead of Fail at
// end of input); Ast accumulates nodes produced during the parse.
type ParseContext struct {
	Input     string
	IsPartial bool
	Debug     bool
	Ast       AstArena
}

// NewParseContext returns a fresh context for the given input.
func NewParseContext(input string, isPartial bool) *ParseContext {
	return &ParseContext{Input: input, IsPartial: isPartial}
}
// utf8Status classifies the outcome of decoding a single UTF-8 codepoint.
type utf8Status int

const (
	utf8Success    utf8Status = 0 // a complete codepoint was decoded
	utf8Incomplete utf8Status = 1 // input ends mid-sequence (or at pos)
	utf8Invalid    utf8Status = 2 // bytes are not valid UTF-8
)

// parseUTF8Codepoint parses a single UTF-8 codepoint at position pos.
// Returns the codepoint, the number of bytes consumed, and a status.
func parseUTF8Codepoint(input string, pos int) (rune, int, utf8Status) {
	if pos >= len(input) {
		return 0, 0, utf8Incomplete
	}
	r, size := utf8.DecodeRuneInString(input[pos:])
	// DecodeRuneInString signals an error as (RuneError, 0) for empty input
	// or (RuneError, 1) for invalid bytes. When size > 1, the input genuinely
	// contains an encoded U+FFFD, which is a valid codepoint — so only
	// size <= 1 is an error. (Checking `r == utf8.RuneError` alone, as the
	// previous version did, misclassified a literal U+FFFD as invalid.)
	if r == utf8.RuneError && size <= 1 {
		if size == 0 {
			return 0, 0, utf8Incomplete
		}
		// Distinguish a truncated multi-byte sequence (incomplete, may become
		// valid with more input) from outright invalid bytes by inspecting
		// the lead byte's expected sequence length.
		b := input[pos]
		var expectedLen int
		switch {
		case b&0x80 == 0:
			expectedLen = 1
		case b&0xE0 == 0xC0:
			expectedLen = 2
		case b&0xF0 == 0xE0:
			expectedLen = 3
		case b&0xF8 == 0xF0:
			expectedLen = 4
		default:
			// Not a legal UTF-8 lead byte.
			return 0, 0, utf8Invalid
		}
		if pos+expectedLen > len(input) {
			return 0, 0, utf8Incomplete
		}
		return 0, 0, utf8Invalid
	}
	return r, size, utf8Success
}

View File

@@ -0,0 +1,268 @@
package peg_test
import (
"encoding/json"
"github.com/mudler/LocalAI/pkg/functions/peg"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Spec suite for PEG utility helpers (quote normalization, JSON escaping)
// and the higher-level chat tool-call parser builders.
var _ = Describe("PEG Utils", func() {
	// Converting Python-style single-quoted dict/string syntax to valid JSON.
	Context("peg.NormalizeQuotesToJSON", func() {
		It("converts basic single quotes to double quotes", func() {
			input := "{'key': 'value'}"
			expected := `{"key": "value"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles escaped single quotes", func() {
			input := `{'code': 'print(\'hello\')'}`
			expected := `{"code": "print('hello')"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles double quotes inside single-quoted strings", func() {
			input := `{'msg': 'He said "hi"'}`
			expected := `{"msg": "He said \"hi\""}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles nested backslash escapes", func() {
			input := `{'path': 'C:\\Users\\test'}`
			expected := `{"path": "C:\\Users\\test"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles newline escapes", func() {
			input := `{'text': 'line1\nline2'}`
			expected := `{"text": "line1\nline2"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles mixed quotes", func() {
			input := `{"already_double": 'single_value'}`
			expected := `{"already_double": "single_value"}`
			Expect(peg.NormalizeQuotesToJSON(input)).To(Equal(expected))
		})
		It("handles embedded quotes complex case", func() {
			// Output only needs to be valid JSON with the right content, so
			// assert via round-trip rather than an exact string.
			input := `{'filename': 'foo.cpp', 'oldString': 'def foo(arg = "14"):\n    return arg + "bar"\n', 'newString': 'def foo(arg = "15"):\n    pass\n'}`
			result := peg.NormalizeQuotesToJSON(input)
			var parsed map[string]string
			err := json.Unmarshal([]byte(result), &parsed)
			Expect(err).NotTo(HaveOccurred(), "result is not valid JSON: %s", result)
			Expect(parsed["filename"]).To(Equal("foo.cpp"))
			Expect(parsed["oldString"]).NotTo(BeEmpty())
		})
	})
	// Escaping raw text so it can be embedded inside a JSON string value.
	Context("peg.EscapeJSONStringInner", func() {
		It("leaves basic strings unchanged", func() {
			Expect(peg.EscapeJSONStringInner("hello")).To(Equal("hello"))
		})
		It("escapes double quotes", func() {
			Expect(peg.EscapeJSONStringInner(`hello "world"`)).To(Equal(`hello \"world\"`))
		})
		It("escapes backslash-n sequences", func() {
			Expect(peg.EscapeJSONStringInner(`line1\nline2`)).To(Equal(`line1\\nline2`))
		})
	})
	// OpenAI-style tool calls: {"id": ..., "name": ..., "arguments": ...}.
	Context("StandardJSONTools OpenAI format", func() {
		It("parses OpenAI-style tool calls with call ID", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
					},
				},
			}
			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:    "<tool_call>",
					SectionEnd:      "</tool_call>",
					Tools:           tools,
					CallIDKey:       "id",
					ParametersOrder: []string{"id", "name", "arguments"},
				})
				return p.Seq(
					p.Content(p.Until("<tool_call>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})
			input := `Let me check the weather.<tool_call>{"id": "call_abc123", "name": "get_current_weather", "arguments": {"location": "NYC"}}</tool_call>`
			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)
			Expect(result.Type).To(Equal(peg.Success))
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result
			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			Expect(msg.ToolCalls[0].ID).To(Equal("call_abc123"))
		})
	})
	// Cohere-style tool calls with renamed keys and a numeric generated ID.
	Context("StandardJSONTools Cohere format", func() {
		It("parses Cohere-style tool calls with custom keys", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
						"unit":     {Type: "string"},
					},
				},
			}
			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:    "<|START_ACTION|>[",
					SectionEnd:      "]<|END_ACTION|>",
					Tools:           tools,
					NameKey:         "tool_name",
					ArgsKey:         "parameters",
					GenCallIDKey:    "tool_call_id",
					ParametersOrder: []string{"tool_call_id", "tool_name", "parameters"},
				})
				return p.Seq(
					p.Content(p.Until("<|START_ACTION|>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})
			input := `Let me search for that.<|START_ACTION|>[{"tool_call_id": 0, "tool_name": "get_current_weather", "parameters": {"location": "NYC", "unit": "celsius"}}]<|END_ACTION|>`
			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)
			Expect(result.Type).To(Equal(peg.Success))
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result
			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			// Numeric IDs are stringified.
			Expect(msg.ToolCalls[0].ID).To(Equal("0"))
		})
	})
	// Format where the function name is itself the top-level JSON key.
	Context("StandardJSONTools function-as-key format", func() {
		It("parses function name as JSON key", func() {
			tools := []peg.ToolDef{
				{
					Name: "get_current_weather",
					Properties: map[string]peg.PropDef{
						"location": {Type: "string"},
						"unit":     {Type: "string"},
					},
				},
			}
			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardJSONTools(peg.StandardJSONToolsOpts{
					SectionStart:  "<tool_calls>[",
					SectionEnd:    "]</tool_calls>",
					Tools:         tools,
					ArgsKey:       "args",
					FunctionIsKey: true,
					CallIDKey:     "id",
				})
				return p.Seq(
					p.Content(p.Until("<tool_calls>")),
					p.Optional(p.Seq(p.Space(), toolCall)),
					p.End(),
				)
			})
			input := `I'll call the weather function.<tool_calls>[{"get_current_weather": {"id": "call-0001", "args": {"location": "NYC", "unit": "celsius"}}}]</tool_calls>`
			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)
			Expect(result.Type).To(Equal(peg.Success))
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result
			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("get_current_weather"))
			Expect(msg.ToolCalls[0].ID).To(Equal("call-0001"))
		})
	})
	// XML-ish tagged parameter format where values may contain raw quotes.
	Context("Tagged args with embedded quotes", func() {
		It("handles embedded double quotes in tagged parameters", func() {
			tools := []peg.ToolDef{
				{
					Name: "edit",
					Properties: map[string]peg.PropDef{
						"filename":  {Type: "string"},
						"oldString": {Type: "string"},
						"newString": {Type: "string"},
					},
				},
			}
			parser := peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
				toolCall := p.StandardConstructedTools(
					map[string]string{
						"tool_call_start_marker": "<seed:tool_call>",
						"tool_call_end_marker":   "</seed:tool_call>",
						"function_opener":        "<function=",
						"function_name_suffix":   ">",
						"function_closer":        "</function>",
						"parameter_key_prefix":   "<parameter=",
						"parameter_key_suffix":   ">",
						"parameter_closer":       "</parameter>",
					},
					tools,
					false,
					true,
				)
				return p.Seq(toolCall, p.Space(), p.End())
			})
			input := "<seed:tool_call>\n" +
				"<function=edit>\n" +
				"<parameter=filename>\nfoo.cpp\n</parameter>\n" +
				"<parameter=oldString>def foo(arg = \"14\"):\n    return arg + \"bar\"\n</parameter>\n" +
				"<parameter=newString>def foo(arg = \"15\"):\n    pass\n</parameter>\n" +
				"</function>\n" +
				"</seed:tool_call>"
			ctx := peg.NewParseContext(input, false)
			result := parser.Parse(ctx)
			Expect(result.Type).To(Equal(peg.Success))
			mapper := &peg.ChatPegMapper{}
			mapper.FromAST(&ctx.Ast, &result)
			msg := mapper.Result
			Expect(msg.ToolCalls).To(HaveLen(1))
			Expect(msg.ToolCalls[0].Name).To(Equal("edit"))
			// Arguments must have been re-escaped into valid JSON.
			var parsed map[string]any
			err := json.Unmarshal([]byte(msg.ToolCalls[0].Arguments), &parsed)
			Expect(err).NotTo(HaveOccurred(), "arguments not valid JSON: %s", msg.ToolCalls[0].Arguments)
			Expect(parsed["filename"]).NotTo(BeNil())
		})
	})
})

View File

@@ -0,0 +1,538 @@
package functions
import (
"strings"
"github.com/mudler/LocalAI/pkg/functions/peg"
"github.com/mudler/xlog"
)
// PEGFormatType identifies the format type for PEG parsing.
type PEGFormatType int

const (
	// FormatJSONNative: tool calls are plain JSON objects such as
	// {"name": ..., "arguments": ...}, wrapped in section markers.
	// Zero value, used as the first candidate during auto-detection.
	FormatJSONNative    PEGFormatType = iota
	FormatTagWithJSON   // <function=name>{"key": "val"}</function>
	FormatTagWithTagged // <function=name><param=key>value</param></function>
)
// ParseFunctionCallPEG attempts to parse tool calls using the PEG parser.
//
// Strategies are tried in priority order:
//  1. auto-detected markers from the C++ backend (config.ToolFormatMarkers)
//  2. a named XML format preset (config.XMLFormatPreset)
//  3. a custom XML format (config.XMLFormat)
//  4. auto-detection across all known presets, grouped by PEG format type
//
// Returns nil if no tool calls were found by any strategy.
func ParseFunctionCallPEG(llmresult string, config FunctionsConfig) []FuncCallResults {
	xlog.Debug("[PEG] starting PEG tool call parsing")
	// If auto-detected markers from the C++ backend are available, use them first
	if config.ToolFormatMarkers != nil {
		m := config.ToolFormatMarkers
		xlog.Debug("[PEG] using auto-detected markers from C++ backend",
			"format_type", m.FormatType,
			"section_start", m.SectionStart,
			"section_end", m.SectionEnd,
			"per_call_start", m.PerCallStart,
			"per_call_end", m.PerCallEnd,
			"func_name_prefix", m.FuncNamePrefix,
			"func_name_suffix", m.FuncNameSuffix,
			"func_close", m.FuncClose,
			"arg_name_prefix", m.ArgNamePrefix,
			"arg_name_suffix", m.ArgNameSuffix,
			"arg_value_prefix", m.ArgValuePrefix,
			"arg_value_suffix", m.ArgValueSuffix,
			"arg_separator", m.ArgSeparator,
			"name_field", m.NameField,
			"args_field", m.ArgsField,
			"id_field", m.IDField,
			"reasoning_start", m.ReasoningStart,
			"reasoning_end", m.ReasoningEnd,
		)
		arena := BuildPEGParserFromMarkers(config.ToolFormatMarkers)
		if arena != nil {
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] markers-based parser matched", "count", len(results))
				return results
			}
			xlog.Debug("[PEG] markers-based parser found no tool calls")
		} else {
			xlog.Debug("[PEG] failed to build parser from markers")
		}
	}
	// If a specific XML format preset is set, use its PEG format
	if config.XMLFormatPreset != "" {
		xlog.Debug("[PEG] trying XML format preset", "preset", config.XMLFormatPreset)
		preset := GetXMLFormatPreset(config.XMLFormatPreset)
		if preset != nil {
			pegType := classifyXMLFormat(preset)
			xlog.Debug("[PEG] classified preset", "preset", config.XMLFormatPreset, "peg_type", pegTypeName(pegType))
			arena := BuildPEGParserFromFormat(preset, pegType)
			if arena != nil {
				results := parsePEG(arena, llmresult)
				if len(results) > 0 {
					xlog.Debug("[PEG] preset parser matched", "preset", config.XMLFormatPreset, "count", len(results))
					return results
				}
				xlog.Debug("[PEG] preset parser found no tool calls", "preset", config.XMLFormatPreset)
			}
		} else {
			xlog.Debug("[PEG] unknown preset name", "preset", config.XMLFormatPreset)
		}
	}
	// If a custom XML format is set, classify and try it
	if config.XMLFormat != nil {
		pegType := classifyXMLFormat(config.XMLFormat)
		xlog.Debug("[PEG] trying custom XML format", "peg_type", pegTypeName(pegType))
		arena := BuildPEGParserFromFormat(config.XMLFormat, pegType)
		if arena != nil {
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] custom format parser matched", "count", len(results))
				return results
			}
			xlog.Debug("[PEG] custom format parser found no tool calls")
		}
	}
	// Auto-detect: try all three format types.
	// The preset list and each preset's classification are loop-invariant,
	// so compute them once instead of on every (pegType, preset) iteration.
	xlog.Debug("[PEG] auto-detecting format across all presets")
	presets := getAllXMLFormats()
	classified := make([]PEGFormatType, len(presets))
	for i := range presets {
		classified[i] = classifyXMLFormat(presets[i].format)
	}
	for _, pegType := range []PEGFormatType{FormatJSONNative, FormatTagWithJSON, FormatTagWithTagged} {
		for i, preset := range presets {
			if classified[i] != pegType {
				continue
			}
			arena := BuildPEGParserFromFormat(preset.format, pegType)
			if arena == nil {
				continue
			}
			results := parsePEG(arena, llmresult)
			if len(results) > 0 {
				xlog.Debug("[PEG] auto-detect matched", "preset", preset.name, "peg_type", pegTypeName(pegType), "count", len(results))
				return results
			}
		}
	}
	xlog.Debug("[PEG] no tool calls found by any format")
	return nil
}
// pegTypeName returns a human-readable label for a PEGFormatType,
// used only in debug logging.
func pegTypeName(t PEGFormatType) string {
	labels := map[PEGFormatType]string{
		FormatJSONNative:    "json_native",
		FormatTagWithJSON:   "tag_with_json",
		FormatTagWithTagged: "tag_with_tagged",
	}
	if label, ok := labels[t]; ok {
		return label
	}
	return "unknown"
}
// classifyXMLFormat determines the PEG format type from an XML format config.
// The decision hinges on whether arguments are JSON-only (RawArgVal explicitly
// false) and whether the format declares a tag-style function opener.
func classifyXMLFormat(f *XMLToolCallFormat) PEGFormatType {
	jsonOnlyArgs := f.RawArgVal != nil && !*f.RawArgVal
	hasTagOpener := f.ToolStart != "" && f.ToolSep != ""
	switch {
	case jsonOnlyArgs && hasTagOpener:
		// Explicit function opener like "<function=" with JSON args.
		return FormatTagWithJSON
	case jsonOnlyArgs && (f.KeyStart == "" || f.KeyStart == "\""):
		// No (or JSON-quoted) key markers: plain JSON tool calls.
		return FormatJSONNative
	case jsonOnlyArgs:
		return FormatTagWithJSON
	case f.KeyStart != "":
		// Tagged parameters like <parameter=key>value</parameter>.
		return FormatTagWithTagged
	default:
		return FormatTagWithJSON
	}
}
// BuildPEGParserFromFormat builds a PEG parser arena from an XML format config.
// Returns nil for an unrecognized format type.
func BuildPEGParserFromFormat(f *XMLToolCallFormat, pegType PEGFormatType) *peg.Arena {
	if pegType == FormatJSONNative {
		return buildJSONNativePEGParser(f)
	}
	if pegType == FormatTagWithTagged || pegType == FormatTagWithJSON {
		return buildTaggedPEGParser(f)
	}
	return nil
}
// buildTaggedPEGParser builds a PEG parser for tag-based formats
// (FormatTagWithTagged and FormatTagWithJSON): a function tag such as
// <function=name>...</function>, optionally wrapped in scope markers such as
// <tool_call>...</tool_call>, with either tagged parameters or raw JSON args.
func buildTaggedPEGParser(f *XMLToolCallFormat) *peg.Arena {
	markers := map[string]string{}
	funcOpener := f.ToolStart
	funcNameSuffix := f.ToolSep
	funcCloser := f.ToolEnd
	hasScope := f.ScopeStart != ""
	if hasScope {
		markers["tool_call_start_marker"] = f.ScopeStart
		markers["tool_call_end_marker"] = f.ScopeEnd
	}
	markers["function_opener"] = funcOpener
	markers["function_name_suffix"] = funcNameSuffix
	markers["function_closer"] = funcCloser
	// Always set parameter markers explicitly to avoid relying on defaults.
	// Formats without tagged params (e.g., functionary) need empty strings.
	markers["parameter_key_prefix"] = f.KeyStart
	markers["parameter_key_suffix"] = f.KeyValSep
	markers["parameter_closer"] = f.ValEnd
	// Determine what to use as the content delimiter: free-form text before
	// the first tool call is consumed up to this marker.
	contentDelim := f.ScopeStart
	if contentDelim == "" {
		contentDelim = f.ToolStart
	}
	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		tools := []peg.ToolDef{} // empty = accept anything
		content := p.Content(p.Until(contentDelim))
		if hasScope {
			// With scope markers: use StandardConstructedTools which wraps in scope
			toolCall := p.StandardConstructedTools(markers, tools, true, false)
			return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
		}
		// No scope markers (e.g., functionary): build tool parser directly without scope wrapper
		hasTaggedParams := f.KeyStart != ""
		var args peg.ParserID
		if hasTaggedParams {
			paramKeyPrefix := f.KeyStart
			paramKeySuffix := f.KeyValSep
			paramCloser := f.ValEnd
			// One tagged parameter: <key-prefix>name<key-suffix>value<closer>
			argRule := p.ToolArg(p.Seq(
				p.ToolArgOpen(p.Literal(paramKeyPrefix)),
				p.ToolArgName(p.Until(paramKeySuffix)),
				p.Literal(paramKeySuffix),
				p.ToolArgValue(p.Until(paramCloser)),
				p.ToolArgClose(p.Literal(paramCloser)),
			))
			args = p.ToolArgs(p.ZeroOrMore(p.Seq(argRule, p.Space())))
		} else {
			// JSON arguments
			args = p.ToolArgs(p.Until(funcCloser))
		}
		toolParser := p.Tool(p.Seq(
			p.ToolOpen(p.Seq(
				p.Literal(funcOpener),
				p.ToolName(p.Until(funcNameSuffix)),
				p.Literal(funcNameSuffix),
			)),
			p.Space(),
			args,
			p.Space(),
			p.ToolClose(p.Literal(funcCloser)),
		))
		// Accept one or more consecutive tool calls separated by whitespace.
		toolCall := p.TriggerRule("tool-call", p.OneOrMore(p.Seq(toolParser, p.Space())))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
// buildJSONNativePEGParser builds a PEG parser for JSON-native formats:
// a JSON object with "name"/"arguments" keys wrapped in section markers.
// Returns nil when no usable section markers can be derived from the config.
func buildJSONNativePEGParser(f *XMLToolCallFormat) *peg.Arena {
	sectionStart := f.ScopeStart
	sectionEnd := f.ScopeEnd
	// Fall back to the per-tool markers when scope markers are absent.
	if sectionStart == "" && f.ToolStart != "" {
		sectionStart = f.ToolStart
	}
	if sectionEnd == "" && f.ToolEnd != "" {
		sectionEnd = f.ToolEnd
	}
	if sectionStart == "" || sectionEnd == "" {
		return nil
	}
	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		// For JSON native, tool call is { "name": ..., "arguments": ... }
		// Build a generic parser that accepts any JSON tool call
		toolCall := p.TriggerRule("tool-call", p.Seq(
			p.Literal(sectionStart), p.Space(),
			p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				p.ZeroOrMore(p.Seq(
					p.Choice(
						// "name": "<function name>"
						p.Seq(
							p.Literal("\"name\""), p.Space(), p.Literal(":"), p.Space(),
							p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
						),
						// "arguments": <any JSON value>
						p.Seq(
							p.Literal("\"arguments\""), p.Space(), p.Literal(":"), p.Space(),
							p.ToolArgs(p.JSON()),
						),
						// Catch-all: skip any other "key": value pair.
						p.Seq(
							p.Literal("\""), p.JSONString(), p.Literal("\""), p.Space(),
							p.Literal(":"), p.Space(), p.JSON(),
						),
					),
					p.Optional(p.Seq(p.Space(), p.Literal(","), p.Space())),
				)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)),
			p.Space(), p.Literal(sectionEnd),
		))
		content := p.Content(p.Until(sectionStart))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
// BuildPEGParserFromMarkers builds a PEG parser from auto-detected C++ autoparser markers.
// Returns nil for an unrecognized format type.
func BuildPEGParserFromMarkers(m *ToolFormatMarkers) *peg.Arena {
	builders := map[string]func(*ToolFormatMarkers) *peg.Arena{
		"tag_with_json":   buildPEGFromMarkersTagJSON,
		"tag_with_tagged": buildPEGFromMarkersTagTagged,
		"json_native":     buildPEGFromMarkersJSONNative,
	}
	build, ok := builders[m.FormatType]
	if !ok {
		return nil
	}
	return build(m)
}
// buildPEGFromMarkersTagJSON builds a parser for tag-based formats with raw
// JSON arguments, e.g. <function=name>{...}</function>, optionally wrapped in
// scope markers.
func buildPEGFromMarkersTagJSON(m *ToolFormatMarkers) *peg.Arena {
	markers := map[string]string{}
	// Use section markers if available, otherwise fall back to per-call markers
	scopeStart, scopeEnd := effectiveScope(m)
	if scopeStart != "" {
		markers["tool_call_start_marker"] = scopeStart
		markers["tool_call_end_marker"] = scopeEnd
	}
	// Trim trailing whitespace from markers — the PEG Space() parser handles
	// whitespace between elements, so baked-in \n would cause mismatches.
	// The same trimmed values are used in BOTH branches below: previously the
	// no-scope branch used the untrimmed fields, so a marker with a trailing
	// newline failed to match right where Until(contentDelim) stopped.
	funcOpener := strings.TrimRight(m.FuncNamePrefix, " \t\n")
	funcNameSuffix := strings.TrimRight(m.FuncNameSuffix, " \t\n")
	funcCloser := strings.TrimRight(m.FuncClose, " \t\n")
	markers["function_opener"] = funcOpener
	markers["function_name_suffix"] = funcNameSuffix
	markers["function_closer"] = funcCloser
	// No tagged parameters in this format: set empty strings explicitly.
	markers["parameter_key_prefix"] = ""
	markers["parameter_key_suffix"] = ""
	markers["parameter_closer"] = ""
	if m.CallIDPosition == "between_func_and_args" {
		markers["call_id_prefix"] = m.CallIDPrefix
		markers["call_id_suffix"] = m.CallIDSuffix
	}
	// Free-form content runs up to the first scope marker, or the function
	// opener when there is no scope.
	contentDelim := scopeStart
	if contentDelim == "" {
		contentDelim = funcOpener
	}
	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		tools := []peg.ToolDef{}
		content := p.Content(p.Until(contentDelim))
		if scopeStart != "" {
			toolCall := p.StandardConstructedTools(markers, tools, true, false)
			return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
		}
		// No scope: build tool parser directly
		args := p.ToolArgs(p.Until(funcCloser))
		// Build call ID section if detected
		callIDSection := buildCallIDParser(p, m)
		toolParser := p.Tool(p.Seq(
			p.ToolOpen(p.Seq(
				p.Literal(funcOpener),
				p.ToolName(p.Until(funcNameSuffix)),
				p.Literal(funcNameSuffix),
			)),
			callIDSection,
			p.Space(),
			args,
			p.Space(),
			p.ToolClose(p.Literal(funcCloser)),
		))
		toolCall := p.TriggerRule("tool-call", p.OneOrMore(p.Seq(toolParser, p.Space())))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
// buildPEGFromMarkersTagTagged builds a parser for tag-based formats whose
// arguments are tagged parameters, e.g.
// <function=name><parameter=key>value</parameter></function>.
func buildPEGFromMarkersTagTagged(m *ToolFormatMarkers) *peg.Arena {
	// Trim trailing whitespace from markers — the PEG Space() parser
	// handles whitespace between elements, so baked-in \n would cause mismatches.
	trim := func(s string) string { return strings.TrimRight(s, " \t\n") }
	markers := map[string]string{
		"function_opener":      trim(m.FuncNamePrefix),
		"function_name_suffix": trim(m.FuncNameSuffix),
		"function_closer":      trim(m.FuncClose),
		"parameter_key_prefix": trim(m.ArgNamePrefix),
		"parameter_key_suffix": trim(m.ArgNameSuffix),
		"parameter_closer":     trim(m.ArgValueSuffix),
	}
	// Use section markers if available, otherwise fall back to per-call markers
	scopeStart, scopeEnd := effectiveScope(m)
	if scopeStart != "" {
		markers["tool_call_start_marker"] = scopeStart
		markers["tool_call_end_marker"] = scopeEnd
	}
	if m.CallIDPosition == "between_func_and_args" {
		markers["call_id_prefix"] = m.CallIDPrefix
		markers["call_id_suffix"] = m.CallIDSuffix
	}
	// Free-form content stops at the scope marker, or at the function opener
	// when no scope markers were detected.
	contentDelim := scopeStart
	if contentDelim == "" {
		contentDelim = trim(m.FuncNamePrefix)
	}
	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		leading := p.Content(p.Until(contentDelim))
		body := p.StandardConstructedTools(markers, []peg.ToolDef{}, true, false)
		return p.Seq(leading, p.Optional(p.Seq(p.Space(), body)), p.End())
	})
}
// buildPEGFromMarkersJSONNative builds a parser for JSON-native formats from
// auto-detected markers: a JSON object wrapped in section markers, with
// configurable field names for the function name, arguments, and call ID.
// Returns nil when no section markers are available.
func buildPEGFromMarkersJSONNative(m *ToolFormatMarkers) *peg.Arena {
	sectionStart, sectionEnd := effectiveScope(m)
	if sectionStart == "" || sectionEnd == "" {
		return nil
	}
	// Fall back to the conventional OpenAI field names when unset.
	nameKey := m.NameField
	if nameKey == "" {
		nameKey = "name"
	}
	argsKey := m.ArgsField
	if argsKey == "" {
		argsKey = "arguments"
	}
	idField := m.IDField
	genIDField := m.GenIDField
	return peg.BuildChatPegParser(func(p *peg.ChatBuilder) peg.ParserID {
		// Build field matchers for known keys
		knownFields := []peg.ParserID{
			p.Seq(
				p.Literal("\""+nameKey+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolName(p.JSONString()), p.Literal("\""),
			),
			p.Seq(
				p.Literal("\""+argsKey+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.ToolArgs(p.JSON()),
			),
		}
		// Add ID field matching if detected
		if idField != "" {
			knownFields = append(knownFields, p.Seq(
				p.Literal("\""+idField+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
		}
		// The generated-ID field is matched separately unless it aliases idField.
		if genIDField != "" && genIDField != idField {
			knownFields = append(knownFields, p.Seq(
				p.Literal("\""+genIDField+"\""), p.Space(), p.Literal(":"), p.Space(),
				p.Literal("\""), p.ToolID(p.JSONString()), p.Literal("\""),
			))
		}
		// Catch-all for unknown JSON fields
		knownFields = append(knownFields, p.Seq(
			p.Literal("\""), p.JSONString(), p.Literal("\""), p.Space(),
			p.Literal(":"), p.Space(), p.JSON(),
		))
		// Build a generic JSON tool call parser that accepts any tool
		toolCall := p.TriggerRule("tool-call", p.Seq(
			p.Literal(sectionStart), p.Space(),
			p.Tool(p.Seq(
				p.ToolOpen(p.Literal("{")), p.Space(),
				p.ZeroOrMore(p.Seq(
					p.Choice(knownFields...),
					p.Optional(p.Seq(p.Space(), p.Literal(","), p.Space())),
				)),
				p.Space(), p.ToolClose(p.Literal("}")),
			)),
			p.Space(), p.Literal(sectionEnd),
		))
		content := p.Content(p.Until(sectionStart))
		return p.Seq(content, p.Optional(p.Seq(p.Space(), toolCall)), p.End())
	})
}
// effectiveScope returns the scope start/end markers to use.
// Prefers section markers, falls back to per-call markers, stripping trailing
// whitespace so the PEG Space() parser can handle it flexibly.
func effectiveScope(m *ToolFormatMarkers) (string, string) {
	trim := func(s string) string { return strings.TrimRight(s, " \t\n") }
	switch {
	case m.SectionStart != "":
		return trim(m.SectionStart), trim(m.SectionEnd)
	case m.PerCallStart != "":
		return trim(m.PerCallStart), trim(m.PerCallEnd)
	default:
		return "", ""
	}
}
// buildCallIDParser creates a parser for call ID markers based on position.
// Currently only BETWEEN_FUNC_AND_ARGS is supported (matching llama.cpp behavior).
// When no usable call-ID markers are configured, an epsilon parser is returned
// so the caller can splice it in unconditionally.
func buildCallIDParser(p *peg.ChatBuilder, m *ToolFormatMarkers) peg.ParserID {
	supported := m.CallIDPosition == "between_func_and_args" &&
		m.CallIDPrefix != "" && m.CallIDSuffix != ""
	if !supported {
		return p.Eps()
	}
	// Optional "<prefix>id<suffix>" section right after the function name.
	return p.Optional(p.Seq(
		p.Literal(m.CallIDPrefix),
		p.ToolID(p.Until(m.CallIDSuffix)),
		p.Literal(m.CallIDSuffix),
	))
}
// parsePEG runs the PEG parser and extracts tool call results.
// Returns nil when the parse fails or yields no tool calls.
func parsePEG(arena *peg.Arena, input string) []FuncCallResults {
	ctx := peg.NewParseContext(input, false)
	outcome := arena.Parse(ctx)
	if outcome.Type != peg.Success {
		// Log only a bounded prefix of the input to keep debug output sane.
		const maxPreview = 200
		preview := input
		if len(preview) > maxPreview {
			preview = preview[:maxPreview] + "..."
		}
		xlog.Debug("[PEG] parse did not succeed", "result_type", outcome.Type, "input_preview", preview)
		return nil
	}
	mapper := &peg.ChatPegMapper{}
	mapper.FromAST(&ctx.Ast, &outcome)
	msg := mapper.Result
	xlog.Debug("[PEG] parse succeeded", "content_len", len(msg.Content), "reasoning_len", len(msg.ReasoningContent), "tool_calls", len(msg.ToolCalls))
	if len(msg.ToolCalls) == 0 {
		return nil
	}
	results := make([]FuncCallResults, 0, len(msg.ToolCalls))
	for _, call := range msg.ToolCalls {
		xlog.Debug("[PEG] extracted tool call", "name", call.Name, "id", call.ID, "args_len", len(call.Arguments))
		results = append(results, FuncCallResults{
			Name:      call.Name,
			Arguments: call.Arguments,
			ID:        call.ID,
		})
	}
	return results
}

View File

@@ -0,0 +1,301 @@
package functions_test
import (
. "github.com/mudler/LocalAI/pkg/functions"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Integration tests for the PEG-based tool call parser: named presets,
// auto-detection, custom XML formats, the ParseFunctionCall entry point,
// the DisablePEGParser flag, and parsers built from C++ autoparser markers.
var _ = Describe("PEG Integration", func() {
	// Parsing via named XML format presets (functionary, qwen3-coder, ...).
	Context("format presets", func() {
		It("parses functionary format", func() {
			input := `I'll help you with that.<function=get_weather>{"location": "NYC", "unit": "celsius"}</function>`
			config := FunctionsConfig{
				XMLFormatPreset: "functionary",
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
		It("parses qwen3-coder format", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
		It("parses qwen3-coder format with preceding content", func() {
			input := "Let me think about this...\n<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})
		It("parses minimax-m2 format", func() {
			input := "Here's the result.\n<minimax:tool_call>\n<invoke name=\"search\">\n<parameter name=\"query\">test query</parameter>\n</invoke>\n</minimax:tool_call>"
			config := FunctionsConfig{
				XMLFormatPreset: "minimax-m2",
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("search"))
			Expect(results[0].Arguments).To(ContainSubstring(`"query"`))
		})
		It("handles glm-4.5 format gracefully", func() {
			input := "<tool_call><arg_key>location</arg_key><arg_value>NYC</arg_value></tool_call>"
			config := FunctionsConfig{
				XMLFormatPreset: "glm-4.5",
			}
			results := ParseFunctionCallPEG(input, config)
			// GLM-4.5 uses tool_call as both scope and tool start with no function name separator,
			// so the PEG parser may not handle it perfectly.
			if len(results) > 0 {
				Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
			}
		})
	})
	// Format auto-detection when no preset or custom format is configured.
	Context("auto-detect", func() {
		It("detects format without preset", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})
	})
	// A fully user-specified XMLToolCallFormat instead of a preset.
	Context("custom XML format", func() {
		It("parses with custom format config", func() {
			input := "<tool_call>\n<function=edit>\n<parameter=filename>\ntest.py\n</parameter>\n<parameter=content>\nhello world\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{
				XMLFormat: &XMLToolCallFormat{
					ScopeStart:    "<tool_call>",
					ToolStart:     "<function=",
					ToolSep:       ">",
					KeyStart:      "<parameter=",
					KeyValSep:     ">",
					ValEnd:        "</parameter>",
					ToolEnd:       "</function>",
					ScopeEnd:      "</tool_call>",
					TrimRawArgVal: true,
				},
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("edit"))
			Expect(results[0].Arguments).To(ContainSubstring(`"filename"`))
		})
	})
	Context("no tool calls", func() {
		It("returns empty results for plain text", func() {
			input := "This is just a regular response with no tool calls."
			config := FunctionsConfig{
				XMLFormatPreset: "qwen3-coder",
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).To(BeEmpty())
		})
	})
	// PEG parsing reached through the top-level ParseFunctionCall flow.
	Context("ParseFunctionCall integration", func() {
		It("finds tool calls via PEG in ParseFunctionCall flow", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{}
			results := ParseFunctionCall(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
		})
		It("finds functionary tool calls via ParseFunctionCall", func() {
			input := `Sure!<function=calculator>{"expression": "2+2"}</function>`
			config := FunctionsConfig{
				XMLFormatPreset: "functionary",
			}
			results := ParseFunctionCall(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("calculator"))
			Expect(results[0].Arguments).To(ContainSubstring(`"expression"`))
		})
	})
	// The DisablePEGParser flag only affects the ParseFunctionCall flow;
	// direct calls and the iterative XML parser keep working.
	Context("DisablePEGParser", func() {
		It("still works when called directly but skips PEG in ParseFunctionCall", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>"
			config := FunctionsConfig{
				DisablePEGParser: true,
			}
			// ParseFunctionCallPEG should still work when called directly
			pegResults := ParseFunctionCallPEG(input, config)
			// May or may not find results depending on auto-detect
			_ = pegResults
			// ParseFunctionCall with PEG disabled should NOT find XML tool calls
			disabledResults := ParseFunctionCall(input, config)
			// May find via JSON extraction
			_ = disabledResults
			// ParseXML (iterative parser) should still find results
			xmlResults, err := ParseXML(input, nil)
			Expect(err).NotTo(HaveOccurred())
			Expect(xmlResults).NotTo(BeEmpty())
			Expect(xmlResults[0].Name).To(Equal("get_weather"))
		})
	})
	// Parsers built from markers auto-detected by the C++ backend.
	Context("markers-based parsing", func() {
		It("parses tag_with_json format from markers", func() {
			input := `Hello!<function=get_weather>{"location": "NYC"}</function>`
			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_json",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
			}
			arena := BuildPEGParserFromMarkers(markers)
			Expect(arena).NotTo(BeNil())
			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
		It("parses tag_with_tagged format from markers", func() {
			input := "<tool_call>\n<function=get_weather>\n<parameter=location>NYC</parameter>\n</function>\n</tool_call>"
			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_tagged",
				SectionStart:   "<tool_call>",
				SectionEnd:     "</tool_call>",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
				ArgNamePrefix:  "<parameter=",
				ArgNameSuffix:  ">",
				ArgValueSuffix: "</parameter>",
			}
			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
		It("parses json_native format from markers", func() {
			input := `Some content<tool_call>{"name": "get_weather", "arguments": {"location": "NYC"}}</tool_call>`
			markers := &ToolFormatMarkers{
				FormatType:   "json_native",
				SectionStart: "<tool_call>",
				SectionEnd:   "</tool_call>",
				NameField:    "name",
				ArgsField:    "arguments",
			}
			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
		It("returns nil arena for unknown format type", func() {
			markers := &ToolFormatMarkers{
				FormatType: "unknown",
			}
			arena := BuildPEGParserFromMarkers(markers)
			Expect(arena).To(BeNil())
		})
		It("parses json_native format with ID field", func() {
			input := `Some content<tool_call>{"name": "get_weather", "arguments": {"location": "NYC"}, "id": "call_123"}</tool_call>`
			markers := &ToolFormatMarkers{
				FormatType:   "json_native",
				SectionStart: "<tool_call>",
				SectionEnd:   "</tool_call>",
				NameField:    "name",
				ArgsField:    "arguments",
				IDField:      "id",
			}
			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
			Expect(results[0].ID).To(Equal("call_123"))
		})
		It("parses call ID between function name and arguments", func() {
			input := `<tool_call><function=get_weather>[call_abc]{"location": "NYC"}</function></tool_call>`
			markers := &ToolFormatMarkers{
				FormatType:     "tag_with_json",
				SectionStart:   "<tool_call>",
				SectionEnd:     "</tool_call>",
				FuncNamePrefix: "<function=",
				FuncNameSuffix: ">",
				FuncClose:      "</function>",
				CallIDPosition: "between_func_and_args",
				CallIDPrefix:   "[",
				CallIDSuffix:   "]",
			}
			config := FunctionsConfig{
				ToolFormatMarkers: markers,
			}
			results := ParseFunctionCallPEG(input, config)
			Expect(results).NotTo(BeEmpty())
			Expect(results[0].Name).To(Equal("get_weather"))
			Expect(results[0].ID).To(Equal("call_abc"))
			Expect(results[0].Arguments).To(ContainSubstring(`"location"`))
		})
	})
})