From 10f7478099e074ea1112380a9bd62a26348fc2c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Sun, 10 May 2026 17:07:57 +0200 Subject: [PATCH] refactor: unify tool arg parsing and simplify string formatting --- src/agent/Agent.zig | 36 +++++++++++++++++++----------------- src/browser/tools.zig | 23 ++++++++++------------- src/mcp/tools.zig | 24 +++++++++--------------- 3 files changed, 38 insertions(+), 45 deletions(-) diff --git a/src/agent/Agent.zig b/src/agent/Agent.zig index d7cff37c..9071e07c 100644 --- a/src/agent/Agent.zig +++ b/src/agent/Agent.zig @@ -264,7 +264,6 @@ pub fn run(self: *Self) bool { /// pipe stdout to capture a clean answer. fn runTurn(self: *Self, input: TurnInput) bool { const text = self.processUserMessage(input) catch |err| switch (err) { - // buildUserMessageParts has already logged the detail. error.UnsupportedAttachment, error.AttachmentReadFailed => return false, else => { self.terminal.printErrorFmt("{s} failed: {s}", .{ input.label, @errorName(err) }); @@ -395,15 +394,16 @@ fn printSlashHelp(self: *Self, target: []const u8) void { } fn printSlashParseError(self: *Self, err: SlashCommand.ParseError, name: []const u8) void { - switch (err) { - error.UnknownTool => self.terminal.printErrorFmt("unknown tool '{s}'. Try /help.", .{name}), - error.MissingName => self.terminal.printError("missing tool name. Try /help."), - error.MissingRequired => self.terminal.printErrorFmt("{s}: missing required argument. Try /help {s}.", .{ name, name }), - error.MalformedKv => self.terminal.printErrorFmt("{s}: malformed key=value. Use key=value or {{json}}.", .{name}), - error.PositionalNotAllowed => self.terminal.printErrorFmt("{s}: positional only works for tools with one required field. Use key=value.", .{name}), - error.UnterminatedQuote => self.terminal.printErrorFmt("{s}: unterminated quote.", .{name}), - error.OutOfMemory => self.terminal.printError("out of memory"), - } + const reason: []const u8 = switch (err) { + error.UnknownTool => "unknown tool", + error.MissingName => return self.terminal.printError("missing tool name. Try /help."), + error.MissingRequired => "missing required argument", + error.MalformedKv => "malformed key=value. Use key=value or {json}", + error.PositionalNotAllowed => "positional only works for tools with one required field. Use key=value", + error.UnterminatedQuote => "unterminated quote", + error.OutOfMemory => return self.terminal.printError("out of memory"), + }; + self.terminal.printErrorFmt("{s}: {s}. Try /help {s}.", .{ name, reason, name }); } fn firstSentence(text: []const u8) []const u8 { @@ -701,11 +701,14 @@ fn runHealTurn(self: *Self, arena: std.mem.Allocator, prompt: []const u8) ![]Com } fn attemptSelfHeal(self: *Self, arena: std.mem.Allocator, failed_command: []const u8, verify_context: ?[]const u8, context_comment: ?[]const u8) ?[]Command.Command { - var aw: std.Io.Writer.Allocating = .init(self.message_arena.allocator()); - aw.writer.writeAll(self_heal_prompt_prefix) catch return null; - aw.writer.writeAll(failed_command) catch return null; - aw.writer.writeAll(self_heal_prompt_page_state) catch return null; - aw.writer.writeAll(self.tool_executor.getCurrentUrl()) catch return null; + const ma = self.message_arena.allocator(); + var aw: std.Io.Writer.Allocating = .init(ma); + aw.writer.print("{s}{s}{s}{s}", .{ + self_heal_prompt_prefix, + failed_command, + self_heal_prompt_page_state, + self.tool_executor.getCurrentUrl(), + }) catch return null; if (context_comment) |c| aw.writer.print("\n\nThe original user request that generated this command was:\n{s}", .{c}) catch return null; if (verify_context) |ctx| @@ -911,8 +914,7 @@ const tool_output_max_bytes: usize = 1 * 1024 * 1024; fn capToolOutput(allocator: std.mem.Allocator, output: []const u8) []const u8 { if (output.len <= tool_output_max_bytes) return output; const prefix = output[0..tool_output_max_bytes]; - const suffix = std.fmt.allocPrint(allocator, "\n...[truncated, original {d} bytes]", .{output.len}) catch return prefix; - return std.mem.concat(allocator, u8, &.{ prefix, suffix }) catch prefix; + return std.fmt.allocPrint(allocator, "{s}\n...[truncated, original {d} bytes]", .{ prefix, output.len }) catch prefix; } fn handleToolCall(ctx: *anyopaque, allocator: std.mem.Allocator, tool_name: []const u8, arguments: []const u8) zenai.provider.Client.ToolHandler.Result { diff --git a/src/browser/tools.zig b/src/browser/tools.zig index fa7bc4e5..404fd8b7 100644 --- a/src/browser/tools.zig +++ b/src/browser/tools.zig @@ -722,16 +722,13 @@ fn formatActionResult( page: *lp.Frame, ) ToolError![]const u8 { const page_title = page.getTitle() catch null; - var aw: std.Io.Writer.Allocating = .init(arena); - if (selector) |sel| - aw.writer.print("{s} (selector: {s}){s}. Page url: {s}, title: {s}", .{ - prefix, sel, suffix, page.url, page_title orelse "(none)", - }) catch return ToolError.InternalError + const target = if (selector) |sel| + std.fmt.allocPrint(arena, "selector: {s}", .{sel}) catch return ToolError.InternalError else - aw.writer.print("{s} (backendNodeId: {d}){s}. Page url: {s}, title: {s}", .{ - prefix, backend_node_id.?, suffix, page.url, page_title orelse "(none)", - }) catch return ToolError.InternalError; - return aw.written(); + std.fmt.allocPrint(arena, "backendNodeId: {d}", .{backend_node_id.?}) catch return ToolError.InternalError; + return std.fmt.allocPrint(arena, "{s} ({s}){s}. Page url: {s}, title: {s}", .{ + prefix, target, suffix, page.url, page_title orelse "(none)", + }) catch ToolError.InternalError; } fn execClick(arena: std.mem.Allocator, session: *lp.Session, registry: *CDPNode.Registry, arguments: ?std.json.Value) ToolError![]const u8 { @@ -1010,9 +1007,9 @@ fn resolveBySelector(session: *lp.Session, selector: []const u8) ToolError!NodeA return .{ .node = node, .page = page }; } -const ParseArgsError = error{ OutOfMemory, InvalidParams }; +pub const ParseArgsError = error{ OutOfMemory, InvalidParams }; -fn parseValue(comptime T: type, arena: std.mem.Allocator, value: std.json.Value) ParseArgsError!T { +pub fn parseValue(comptime T: type, arena: std.mem.Allocator, value: std.json.Value) ParseArgsError!T { return std.json.parseFromValueLeaky(T, arena, value, .{ .ignore_unknown_fields = true }) catch |err| switch (err) { error.OutOfMemory => error.OutOfMemory, else => error.InvalidParams, @@ -1021,12 +1018,12 @@ fn parseValue(comptime T: type, arena: std.mem.Allocator, value: std.json.Value) /// For tools where every field is optional. Missing args → default `T`; /// wrong-typed args still error (don't silently default). -fn parseArgsOrDefault(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value) ParseArgsError!T { +pub fn parseArgsOrDefault(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value) ParseArgsError!T { return parseValue(T, arena, arguments orelse return .{}); } /// Required-args parse: missing or malformed both surface as `InvalidParams`. -fn parseArgs(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value) ParseArgsError!T { +pub fn parseArgs(comptime T: type, arena: std.mem.Allocator, arguments: ?std.json.Value) ParseArgsError!T { return parseValue(T, arena, arguments orelse return error.InvalidParams); } diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index 9896e5fb..8fa2b8e1 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -122,12 +122,10 @@ pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Reque const id = req.id orelse return; const params = req.params orelse return server.transport.sendError(id, .InvalidParams, "Missing params"); - const call_params = std.json.parseFromValueLeaky(protocol.CallParams, arena, params, .{ .ignore_unknown_fields = true }) catch { + const call_params = browser_tools.parseValue(protocol.CallParams, arena, params) catch { return server.transport.sendError(id, .InvalidParams, "Invalid params"); }; - // Hand-written tools: dispatch first so they don't collide with the - // generated browser tools. if (std.mem.eql(u8, call_params.name, "record_start")) return handleRecordStart(server, arena, id, call_params.arguments); if (std.mem.eql(u8, call_params.name, "record_stop")) return handleRecordStop(server, arena, id); if (std.mem.eql(u8, call_params.name, "record_comment")) return handleRecordComment(server, arena, id, call_params.arguments); @@ -188,9 +186,8 @@ fn handleRecordStart(server: *Server, arena: std.mem.Allocator, id: std.json.Val if (server.recorder != null) { return sendErrorContent(server, id, "a recording is already active; call record_stop first"); } - const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments"); const Args = struct { path: []const u8 }; - const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch { + const args = browser_tools.parseArgs(Args, arena, arguments) catch { return server.transport.sendError(id, .InvalidParams, "expected { path: string }"); }; @@ -233,9 +230,8 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V if (server.recorder == null) { return sendErrorContent(server, id, "no recording is active"); } - const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments"); const Args = struct { text: []const u8 }; - const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch { + const args = browser_tools.parseArgs(Args, arena, arguments) catch { return server.transport.sendError(id, .InvalidParams, "expected { text: string }"); }; @@ -246,22 +242,22 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V } fn handleScriptStep(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments"); const Args = struct { line: []const u8 }; - const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch { + const args = browser_tools.parseArgs(Args, arena, arguments) catch { return server.transport.sendError(id, .InvalidParams, "expected { line: string }"); }; const cmd = Command.parse(args.line); + if (cmd.needsLlm()) { + return sendErrorContent(server, id, "LOGIN / ACCEPT_COOKIES / natural-language steps require an LLM and are not handled by lightpanda mcp; the calling agent owns those"); + } + switch (cmd) { .comment => { const content = [_]protocol.TextContent([]const u8){.{ .text = "comment" }}; return server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content }); }, - .login, .accept_cookies, .natural_language => { - return sendErrorContent(server, id, "LOGIN / ACCEPT_COOKIES / natural-language steps require an LLM and are not handled by lightpanda mcp; the calling agent owns those"); - }, .extract => |sel| { const result = browser_tools.extractText(arena, server.session, &server.node_registry, sel); const content = [_]protocol.TextContent([]const u8){.{ .text = result.text }}; @@ -309,8 +305,6 @@ fn handleScriptStep(server: *Server, arena: std.mem.Allocator, id: std.json.Valu } fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void { - const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments"); - const ReplacementSpec = struct { original_line: []const u8, replacement_lines: []const []const u8, @@ -319,7 +313,7 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu path: []const u8, replacements: []const ReplacementSpec, }; - const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch { + const args = browser_tools.parseArgs(Args, arena, arguments) catch { return server.transport.sendError(id, .InvalidParams, "expected { path: string, replacements: [{ original_line, replacement_lines }] }"); };