diff --git a/src/agent/Command.zig b/src/agent/Command.zig index 129c475e..feb1f36e 100644 --- a/src/agent/Command.zig +++ b/src/agent/Command.zig @@ -449,10 +449,17 @@ pub fn toToolCall(arena: std.mem.Allocator, cmd: Command, substitute: Substitute /// return null if the tool name doesn't correspond to a PandaScript /// command. Variants emitted by `toToolCall` round-trip through this. pub fn fromToolCall(arena: std.mem.Allocator, tool_name: []const u8, arguments: []const u8) ?Command { + const parsed = std.json.parseFromSliceLeaky(std.json.Value, arena, arguments, .{}) catch return null; + return fromToolCallValue(tool_name, parsed); +} + +/// Like `fromToolCall` but takes the already-parsed JSON value directly, +/// skipping the string round-trip when the caller already has it (e.g. the +/// MCP server, which dispatches off `std.json.Value`). +pub fn fromToolCallValue(tool_name: []const u8, arguments: std.json.Value) ?Command { const Action = lp.tools.Action; const action = std.meta.stringToEnum(Action, tool_name) orelse return null; - const parsed = std.json.parseFromSliceLeaky(std.json.Value, arena, arguments, .{}) catch return null; - const obj = switch (parsed) { + const obj = switch (arguments) { .object => |o| o, else => return null, }; diff --git a/src/mcp/tools.zig b/src/mcp/tools.zig index 9b4fa964..27a28d09 100644 --- a/src/mcp/tools.zig +++ b/src/mcp/tools.zig @@ -154,7 +154,7 @@ fn dispatchBrowserTool( // JS errors are returned as isError tool results, not protocol errors if (action == .eval) { const result = browser_tools.callEval(arena, server.session, &server.node_registry, arguments); - if (!result.is_error) recordIfActive(server, arena, name, arguments); + if (!result.is_error) recordIfActive(server, name, arguments); const content = [_]protocol.TextContent([]const u8){.{ .text = result.text }}; return server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content, .isError = result.is_error }); } @@ -168,7 +168,7 @@ fn dispatchBrowserTool( return server.transport.sendError(id, code, @errorName(err)); }; - recordIfActive(server, arena, name, arguments); + recordIfActive(server, name, arguments); const content = [_]protocol.TextContent([]const u8){.{ .text = result }}; try server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content }); @@ -177,11 +177,10 @@ fn dispatchBrowserTool( /// If a recorder is active and the (name, args) pair maps to a PandaScript /// Command, append it to the recording. Tools without a Command mapping /// (tree, markdown, findElement, etc.) are silently skipped. -fn recordIfActive(server: *Server, arena: std.mem.Allocator, name: []const u8, arguments: ?std.json.Value) void { +fn recordIfActive(server: *Server, name: []const u8, arguments: ?std.json.Value) void { if (server.recorder == null) return; const args_value = arguments orelse return; - const args_json = Command.stringifyJson(arena, args_value); - const cmd = Command.fromToolCall(arena, name, args_json) orelse return; + const cmd = Command.fromToolCallValue(name, args_value) orelse return; server.recorder.?.record(cmd); server.record_lines += 1; } @@ -203,11 +202,15 @@ fn handleRecordStart(server: *Server, arena: std.mem.Allocator, id: std.json.Val const path_owned = server.allocator.dupe(u8, args.path) catch return sendErrorContent(server, id, "out of memory"); errdefer server.allocator.free(path_owned); + const msg = std.fmt.allocPrint(arena, "recording started: {s}", .{path_owned}) catch { + server.allocator.free(path_owned); + return sendErrorContent(server, id, "out of memory"); + }; + server.recorder = Recorder.init(server.allocator, path_owned); server.record_path = path_owned; server.record_lines = 0; - const msg = std.fmt.allocPrint(arena, "recording started: {s}", .{path_owned}) catch return; const content = [_]protocol.TextContent([]const u8){.{ .text = msg }}; try server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content }); } @@ -219,13 +222,16 @@ fn handleRecordStop(server: *Server, arena: std.mem.Allocator, id: std.json.Valu const path = server.record_path.?; const lines = server.record_lines; + // Build the response before nulling state so an allocPrint failure doesn't + // strand `path` (record_path = null would hide it from Server.deinit). + const msg = std.fmt.allocPrint(arena, "recording stopped: {s} ({d} line(s) written)", .{ path, lines }) catch + return sendErrorContent(server, id, "out of memory"); + var r = server.recorder.?; r.deinit(); server.recorder = null; server.record_path = null; server.record_lines = 0; - - const msg = std.fmt.allocPrint(arena, "recording stopped: {s} ({d} line(s) written)", .{ path, lines }) catch return; server.allocator.free(path); const content = [_]protocol.TextContent([]const u8){.{ .text = msg }}; @@ -236,10 +242,9 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V if (server.recorder == null) { return sendErrorContent(server, id, "no recording is active"); } - _ = arena; const args_value = arguments orelse return server.transport.sendError(id, .InvalidParams, "missing arguments"); const Args = struct { text: []const u8 }; - const args = std.json.parseFromValueLeaky(Args, server.allocator, args_value, .{ .ignore_unknown_fields = true }) catch { + const args = std.json.parseFromValueLeaky(Args, arena, args_value, .{ .ignore_unknown_fields = true }) catch { return server.transport.sendError(id, .InvalidParams, "expected { text: string }"); }; @@ -352,16 +357,21 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu var splices = arena.alloc(script.Replacement, args.replacements.len) catch return sendErrorContent(server, id, "out of memory"); for (args.replacements, 0..) |spec, i| { - const span = findLineSpan(content, spec.original_line) orelse { - const msg = std.fmt.allocPrint(arena, "original_line not found verbatim: `{s}`", .{spec.original_line}) catch "original_line not found"; + const span = findLineSpan(content, spec.original_line) catch |err| { + const reason: []const u8 = switch (err) { + error.NotFound => "original_line not found verbatim", + error.Ambiguous => "original_line matches more than one line; make it unique to disambiguate", + }; + const msg = std.fmt.allocPrint(arena, "{s}: `{s}`", .{ reason, spec.original_line }) catch reason; return sendErrorContent(server, id, msg); }; var aw: std.Io.Writer.Allocating = .init(arena); - aw.writer.print("# [Auto-healed] Original: {s}\n", .{spec.original_line}) catch return sendErrorContent(server, id, "out of memory formatting heal header"); + aw.writer.print("# [Auto-healed] Original: {s}\n", .{spec.original_line}) catch |err| + return sendErrorContent(server, id, @errorName(err)); for (spec.replacement_lines) |rl| { - aw.writer.writeAll(rl) catch return sendErrorContent(server, id, "out of memory writing replacement line"); - aw.writer.writeByte('\n') catch return sendErrorContent(server, id, "out of memory writing replacement line"); + aw.writer.writeAll(rl) catch |err| return sendErrorContent(server, id, @errorName(err)); + aw.writer.writeByte('\n') catch |err| return sendErrorContent(server, id, @errorName(err)); } splices[i] = .{ .original_span = span, .new_text = aw.written() }; @@ -380,19 +390,24 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu /// Find a line in `content` that exactly equals `line` (after trimming the /// trailing newline). Returns the slice covering the line plus its /// terminating `\n` if present, ready for `script.applyReplacements`. -fn findLineSpan(content: []const u8, line: []const u8) ?[]const u8 { +/// Errors if the line is missing or matches more than once — a duplicate +/// match would silently rewrite the wrong line and break +/// applyReplacements' non-overlapping invariant. +fn findLineSpan(content: []const u8, line: []const u8) error{ NotFound, Ambiguous }![]const u8 { var pos: usize = 0; + var found: ?[]const u8 = null; while (pos <= content.len) { const nl = std.mem.indexOfScalarPos(u8, content, pos, '\n') orelse content.len; const this_line = content[pos..nl]; if (std.mem.eql(u8, this_line, line)) { + if (found != null) return error.Ambiguous; const end = if (nl < content.len) nl + 1 else nl; - return content[pos..end]; + found = content[pos..end]; } - if (nl == content.len) return null; + if (nl == content.len) break; pos = nl + 1; } - return null; + return found orelse error.NotFound; } fn currentUrl(server: *Server) ![]const u8 { @@ -439,21 +454,26 @@ test "MCP - eval error reporting" { test "MCP - findLineSpan: exact match returns line + trailing newline" { const content = "GOTO https://x\nCLICK 'old'\nWAIT '.thanks'\n"; - const span = findLineSpan(content, "CLICK 'old'").?; + const span = try findLineSpan(content, "CLICK 'old'"); try std.testing.expectEqualStrings("CLICK 'old'\n", span); } -test "MCP - findLineSpan: no match returns null" { +test "MCP - findLineSpan: no match returns NotFound" { const content = "GOTO https://x\nCLICK 'a'\n"; - try std.testing.expect(findLineSpan(content, "CLICK 'b'") == null); + try std.testing.expectError(error.NotFound, findLineSpan(content, "CLICK 'b'")); } test "MCP - findLineSpan: last line without trailing newline" { const content = "GOTO https://x\nCLICK 'last'"; - const span = findLineSpan(content, "CLICK 'last'").?; + const span = try findLineSpan(content, "CLICK 'last'"); try std.testing.expectEqualStrings("CLICK 'last'", span); } +test "MCP - findLineSpan: duplicate line returns Ambiguous" { + const content = "CLICK 'go'\nWAIT '.x'\nCLICK 'go'\n"; + try std.testing.expectError(error.Ambiguous, findLineSpan(content, "CLICK 'go'")); +} + test "MCP - record_start rejects unsafe path" { defer testing.reset(); var out: std.io.Writer.Allocating = .init(testing.arena_allocator); diff --git a/src/script.zig b/src/script.zig index a8f353d2..02f80623 100644 --- a/src/script.zig +++ b/src/script.zig @@ -90,8 +90,11 @@ pub fn applyReplacements( replacements: []const Replacement, ) error{OutOfMemory}![]u8 { const content_base = @intFromPtr(content.ptr); + // Subtract before adding so intermediate arithmetic on usize cannot + // underflow when individual replacements shrink even though the net + // delta is positive. var total = content.len; - for (replacements) |r| total = total + r.new_text.len - r.original_span.len; + for (replacements) |r| total = total - r.original_span.len + r.new_text.len; var out: std.ArrayList(u8) = .empty; errdefer out.deinit(allocator); @@ -100,6 +103,7 @@ pub fn applyReplacements( for (replacements) |r| { const r_start = @intFromPtr(r.original_span.ptr) - content_base; const r_end = r_start + r.original_span.len; + std.debug.assert(r_start >= pos and r_end <= content.len); out.appendSliceAssumeCapacity(content[pos..r_start]); out.appendSliceAssumeCapacity(r.new_text); pos = r_end;