refactor: optimize tool calls and improve script healing

This commit is contained in:
Adrià Arrufat
2026-05-07 20:30:49 +02:00
parent 622d408d03
commit 77fc818976
3 changed files with 57 additions and 26 deletions

View File

@@ -449,10 +449,17 @@ pub fn toToolCall(arena: std.mem.Allocator, cmd: Command, substitute: Substitute
/// return null if the tool name doesn't correspond to a PandaScript
/// command. Variants emitted by `toToolCall` round-trip through this.
pub fn fromToolCall(arena: std.mem.Allocator, tool_name: []const u8, arguments: []const u8) ?Command {
const parsed = std.json.parseFromSliceLeaky(std.json.Value, arena, arguments, .{}) catch return null;
return fromToolCallValue(tool_name, parsed);
}
/// Like `fromToolCall` but takes the already-parsed JSON value directly,
/// skipping the string round-trip when the caller already has it (e.g. the
/// MCP server, which dispatches off `std.json.Value`).
pub fn fromToolCallValue(tool_name: []const u8, arguments: std.json.Value) ?Command {
const Action = lp.tools.Action;
const action = std.meta.stringToEnum(Action, tool_name) orelse return null;
const parsed = std.json.parseFromSliceLeaky(std.json.Value, arena, arguments, .{}) catch return null;
const obj = switch (parsed) {
const obj = switch (arguments) {
.object => |o| o,
else => return null,
};

View File

@@ -154,7 +154,7 @@ fn dispatchBrowserTool(
// JS errors are returned as isError tool results, not protocol errors
if (action == .eval) {
const result = browser_tools.callEval(arena, server.session, &server.node_registry, arguments);
if (!result.is_error) recordIfActive(server, arena, name, arguments);
if (!result.is_error) recordIfActive(server, name, arguments);
const content = [_]protocol.TextContent([]const u8){.{ .text = result.text }};
return server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content, .isError = result.is_error });
}
@@ -168,7 +168,7 @@ fn dispatchBrowserTool(
return server.transport.sendError(id, code, @errorName(err));
};
recordIfActive(server, arena, name, arguments);
recordIfActive(server, name, arguments);
const content = [_]protocol.TextContent([]const u8){.{ .text = result }};
try server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content });
@@ -177,11 +177,10 @@ fn dispatchBrowserTool(
/// If a recorder is active and the (name, args) pair maps to a PandaScript
/// Command, append it to the recording. Tools without a Command mapping
/// (tree, markdown, findElement, etc.) are silently skipped.
fn recordIfActive(server: *Server, arena: std.mem.Allocator, name: []const u8, arguments: ?std.json.Value) void {
fn recordIfActive(server: *Server, name: []const u8, arguments: ?std.json.Value) void {
if (server.recorder == null) return;
const args_value = arguments orelse return;
const args_json = Command.stringifyJson(arena, args_value);
const cmd = Command.fromToolCall(arena, name, args_json) orelse return;
const cmd = Command.fromToolCallValue(name, args_value) orelse return;
server.recorder.?.record(cmd);
server.record_lines += 1;
}
@@ -203,11 +202,15 @@ fn handleRecordStart(server: *Server, arena: std.mem.Allocator, id: std.json.Val
const path_owned = server.allocator.dupe(u8, args.path) catch return sendErrorContent(server, id, "out of memory");
errdefer server.allocator.free(path_owned);
const msg = std.fmt.allocPrint(arena, "recording started: {s}", .{path_owned}) catch {
server.allocator.free(path_owned);
return sendErrorContent(server, id, "out of memory");
};
server.recorder = Recorder.init(server.allocator, path_owned);
server.record_path = path_owned;
server.record_lines = 0;
const msg = std.fmt.allocPrint(arena, "recording started: {s}", .{path_owned}) catch return;
const content = [_]protocol.TextContent([]const u8){.{ .text = msg }};
try server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content });
}
@@ -219,13 +222,16 @@ fn handleRecordStop(server: *Server, arena: std.mem.Allocator, id: std.json.Valu
const path = server.record_path.?;
const lines = server.record_lines;
// Build the response before nulling state so an allocPrint failure doesn't
// strand `path` (record_path = null would hide it from Server.deinit).
const msg = std.fmt.allocPrint(arena, "recording stopped: {s} ({d} line(s) written)", .{ path, lines }) catch
return sendErrorContent(server, id, "out of memory");
var r = server.recorder.?;
r.deinit();
server.recorder = null;
server.record_path = null;
server.record_lines = 0;
const msg = std.fmt.allocPrint(arena, "recording stopped: {s} ({d} line(s) written)", .{ path, lines }) catch return;
server.allocator.free(path);
const content = [_]protocol.TextContent([]const u8){.{ .text = msg }};
@@ -236,10 +242,9 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V
if (server.recorder == null) {
return sendErrorContent(server, id, "no recording is active");
}
_ = arena;
const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments");
const Args = struct { text: []const u8 };
const args = std.json.parseFromValueLeaky(Args, server.allocator, args_value, .{ .ignore_unknown_fields = true }) catch {
const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch {
return server.transport.sendError(id, .InvalidParams, "expected { text: string }");
};
@@ -352,16 +357,21 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu
var splices = arena.alloc(script.Replacement, args.replacements.len) catch return sendErrorContent(server, id, "out of memory");
for (args.replacements, 0..) |spec, i| {
const span = findLineSpan(content, spec.original_line) orelse {
const msg = std.fmt.allocPrint(arena, "original_line not found verbatim: `{s}`", .{spec.original_line}) catch "original_line not found";
const span = findLineSpan(content, spec.original_line) catch |err| {
const reason: []const u8 = switch (err) {
error.NotFound => "original_line not found verbatim",
error.Ambiguous => "original_line matches more than one line; make it unique to disambiguate",
};
const msg = std.fmt.allocPrint(arena, "{s}: `{s}`", .{ reason, spec.original_line }) catch reason;
return sendErrorContent(server, id, msg);
};
var aw: std.Io.Writer.Allocating = .init(arena);
aw.writer.print("# [Auto-healed] Original: {s}\n", .{spec.original_line}) catch return sendErrorContent(server, id, "out of memory formatting heal header");
aw.writer.print("# [Auto-healed] Original: {s}\n", .{spec.original_line}) catch |err|
return sendErrorContent(server, id, @errorName(err));
for (spec.replacement_lines) |rl| {
aw.writer.writeAll(rl) catch return sendErrorContent(server, id, "out of memory writing replacement line");
aw.writer.writeByte('\n') catch return sendErrorContent(server, id, "out of memory writing replacement line");
aw.writer.writeAll(rl) catch |err| return sendErrorContent(server, id, @errorName(err));
aw.writer.writeByte('\n') catch |err| return sendErrorContent(server, id, @errorName(err));
}
splices[i] = .{ .original_span = span, .new_text = aw.written() };
@@ -380,19 +390,24 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu
/// Find a line in `content` that exactly equals `line` (after trimming the
/// trailing newline). Returns the slice covering the line plus its
/// terminating `\n` if present, ready for `script.applyReplacements`.
fn findLineSpan(content: []const u8, line: []const u8) ?[]const u8 {
/// Errors if the line is missing or matches more than once — a duplicate
/// match would silently rewrite the wrong line and break
/// applyReplacements' non-overlapping invariant.
fn findLineSpan(content: []const u8, line: []const u8) error{ NotFound, Ambiguous }![]const u8 {
var pos: usize = 0;
var found: ?[]const u8 = null;
while (pos <= content.len) {
const nl = std.mem.indexOfScalarPos(u8, content, pos, '\n') orelse content.len;
const this_line = content[pos..nl];
if (std.mem.eql(u8, this_line, line)) {
if (found != null) return error.Ambiguous;
const end = if (nl < content.len) nl + 1 else nl;
return content[pos..end];
found = content[pos..end];
}
if (nl == content.len) return null;
if (nl == content.len) break;
pos = nl + 1;
}
return null;
return found orelse error.NotFound;
}
fn currentUrl(server: *Server) ![]const u8 {
@@ -439,21 +454,26 @@ test "MCP - eval error reporting" {
test "MCP - findLineSpan: exact match returns line + trailing newline" {
const content = "GOTO https://x\nCLICK 'old'\nWAIT '.thanks'\n";
const span = findLineSpan(content, "CLICK 'old'").?;
const span = try findLineSpan(content, "CLICK 'old'");
try std.testing.expectEqualStrings("CLICK 'old'\n", span);
}
test "MCP - findLineSpan: no match returns null" {
test "MCP - findLineSpan: no match returns NotFound" {
const content = "GOTO https://x\nCLICK 'a'\n";
try std.testing.expect(findLineSpan(content, "CLICK 'b'") == null);
try std.testing.expectError(error.NotFound, findLineSpan(content, "CLICK 'b'"));
}
test "MCP - findLineSpan: last line without trailing newline" {
const content = "GOTO https://x\nCLICK 'last'";
const span = findLineSpan(content, "CLICK 'last'").?;
const span = try findLineSpan(content, "CLICK 'last'");
try std.testing.expectEqualStrings("CLICK 'last'", span);
}
test "MCP - findLineSpan: duplicate line returns Ambiguous" {
const content = "CLICK 'go'\nWAIT '.x'\nCLICK 'go'\n";
try std.testing.expectError(error.Ambiguous, findLineSpan(content, "CLICK 'go'"));
}
test "MCP - record_start rejects unsafe path" {
defer testing.reset();
var out: std.io.Writer.Allocating = .init(testing.arena_allocator);

View File

@@ -90,8 +90,11 @@ pub fn applyReplacements(
replacements: []const Replacement,
) error{OutOfMemory}![]u8 {
const content_base = @intFromPtr(content.ptr);
// Subtract before adding so intermediate arithmetic on usize cannot
// underflow when individual replacements shrink even though the net
// delta is positive.
var total = content.len;
for (replacements) |r| total = total + r.new_text.len - r.original_span.len;
for (replacements) |r| total = total - r.original_span.len + r.new_text.len;
var out: std.ArrayList(u8) = .empty;
errdefer out.deinit(allocator);
@@ -100,6 +103,7 @@ pub fn applyReplacements(
for (replacements) |r| {
const r_start = @intFromPtr(r.original_span.ptr) - content_base;
const r_end = r_start + r.original_span.len;
std.debug.assert(r_start >= pos and r_end <= content.len);
out.appendSliceAssumeCapacity(content[pos..r_start]);
out.appendSliceAssumeCapacity(r.new_text);
pos = r_end;