refactor: simplify mcp communication and tool dispatching

This commit is contained in:
Adrià Arrufat
2026-05-11 08:57:36 +02:00
parent ecc68f8780
commit 0338bf71af
6 changed files with 51 additions and 30 deletions

View File

@@ -736,6 +736,7 @@ fn attemptSelfHeal(self: *Self, arena: std.mem.Allocator, failed_command: []cons
return cmds;
}
self.messages.shrinkRetainingCapacity(msg_baseline);
break;
}
return null;
}

View File

@@ -525,8 +525,7 @@ pub fn printErrorFmt(_: *Self, comptime fmt: []const u8, args: anytype) void {
}
pub fn printInfo(self: *Self, msg: []const u8) void {
if (!self.isRepl() and !atLeast(self.verbosity, .medium)) return;
std.debug.print("{s}{s}{s}\n", .{ ansi.dim, msg, ansi.reset });
self.printInfoFmt("{s}", .{msg});
}
pub fn printInfoFmt(self: *Self, comptime fmt: []const u8, args: anytype) void {

View File

@@ -77,9 +77,17 @@ pub fn deinit(self: *Self) void {
self.allocator.destroy(self);
}
pub fn sendError(self: *Self, id: std.json.Value, code: protocol.ErrorCode, message: []const u8) !void {
return self.transport.sendError(id, code, message);
}
pub fn sendResult(self: *Self, id: std.json.Value, result: anytype) !void {
return self.transport.sendResult(id, result);
}
pub fn handleInitialize(self: *Self, req: protocol.Request) !void {
const id = req.id orelse return;
try self.transport.sendResult(id, protocol.InitializeResult{
try self.sendResult(id, protocol.InitializeResult{
.protocolVersion = @tagName(protocol.Version.default),
.capabilities = .{
.resources = .{},

View File

@@ -23,7 +23,7 @@ pub const resource_list = [_]protocol.Resource{
pub fn handleList(server: *Server, req: protocol.Request) !void {
const id = req.id orelse return;
try server.transport.sendResult(id, .{ .resources = &resource_list });
try server.sendResult(id, .{ .resources = &resource_list });
}
const ReadParams = struct {
@@ -74,20 +74,20 @@ const resource_map = std.StaticStringMap(ResourceUri).initComptime(.{
pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
if (req.params == null or req.id == null) {
return server.transport.sendError(req.id orelse .{ .integer = -1 }, .InvalidParams, "Missing params");
return server.sendError(req.id orelse .{ .integer = -1 }, .InvalidParams, "Missing params");
}
const req_id = req.id.?;
const params = std.json.parseFromValueLeaky(ReadParams, arena, req.params.?, .{ .ignore_unknown_fields = true }) catch {
return server.transport.sendError(req_id, .InvalidParams, "Invalid params");
return server.sendError(req_id, .InvalidParams, "Invalid params");
};
const uri = resource_map.get(params.uri) orelse {
return server.transport.sendError(req_id, .InvalidRequest, "Resource not found");
return server.sendError(req_id, .InvalidRequest, "Resource not found");
};
const frame = server.session.currentFrame() orelse {
return server.transport.sendError(req_id, .FrameNotLoaded, "Page not loaded");
return server.sendError(req_id, .FrameNotLoaded, "Page not loaded");
};
const format: Format = switch (uri) {
@@ -106,7 +106,7 @@ pub fn handleRead(server: *Server, arena: std.mem.Allocator, req: protocol.Reque
.text = .{ .frame = frame, .format = format },
}},
};
server.transport.sendResult(req_id, result) catch {
return server.transport.sendError(req_id, .InternalError, "Failed to serialize resource content");
server.sendResult(req_id, result) catch {
return server.sendError(req_id, .InternalError, "Failed to serialize resource content");
};
}

View File

@@ -22,7 +22,7 @@ pub fn processRequests(server: anytype, reader: *std.io.Reader) !void {
const buffered_line = reader.takeDelimiter('\n') catch |err| switch (err) {
error.StreamTooLong => {
log.err(.mcp, "Message too long", .{});
try server.transport.sendError(.null, .InvalidRequest, "Message too long");
try server.sendError(.null, .InvalidRequest, "Message too long");
continue;
},
else => return err,
@@ -62,13 +62,13 @@ pub fn handleMessage(server: anytype, arena: std.mem.Allocator, msg: []const u8)
.ignore_unknown_fields = true,
}) catch |err| {
log.warn(.mcp, "JSON Parse Error", .{ .err = err, .msg = msg });
try server.transport.sendError(.null, .ParseError, "Parse error");
try server.sendError(.null, .ParseError, "Parse error");
return;
};
const method = method_map.get(req.method) orelse {
if (req.id != null) {
try server.transport.sendError(req.id.?, .MethodNotFound, "Method not found");
try server.sendError(req.id.?, .MethodNotFound, "Method not found");
}
return;
};
@@ -88,13 +88,13 @@ fn handleOptional(server: anytype, req: protocol.Request, comptime method: []con
if (@hasDecl(@TypeOf(server.*), method)) {
try @call(.auto, @field(@TypeOf(server.*), method), .{server} ++ args);
} else if (req.id) |id| {
try server.transport.sendError(id, .MethodNotFound, "Method not supported");
try server.sendError(id, .MethodNotFound, "Method not supported");
}
}
fn handlePing(server: anytype, req: protocol.Request) !void {
const id = req.id orelse return;
try server.transport.sendResult(id, .{});
try server.sendResult(id, .{});
}
const Server = @import("Server.zig");

View File

@@ -112,25 +112,38 @@ const extra_tools = [_]protocol.Tool{
const all_tools = browser_tool_list ++ extra_tools;
/// Tools that bypass the browser-tool dispatch and have their own handlers.
const ExtraTool = enum {
record_start,
record_stop,
record_comment,
script_step,
script_heal,
};
pub fn handleList(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
_ = arena;
const id = req.id orelse return;
try server.transport.sendResult(id, .{ .tools = &all_tools });
try server.sendResult(id, .{ .tools = &all_tools });
}
pub fn handleCall(server: *Server, arena: std.mem.Allocator, req: protocol.Request) !void {
const id = req.id orelse return;
const params = req.params orelse return server.transport.sendError(id, .InvalidParams, "Missing params");
const params = req.params orelse return server.sendError(id, .InvalidParams, "Missing params");
const call_params = browser_tools.parseValue(protocol.CallParams, arena, params) catch {
return server.transport.sendError(id, .InvalidParams, "Invalid params");
return server.sendError(id, .InvalidParams, "Invalid params");
};
if (std.mem.eql(u8, call_params.name, "record_start")) return handleRecordStart(server, arena, id, call_params.arguments);
if (std.mem.eql(u8, call_params.name, "record_stop")) return handleRecordStop(server, arena, id);
if (std.mem.eql(u8, call_params.name, "record_comment")) return handleRecordComment(server, arena, id, call_params.arguments);
if (std.mem.eql(u8, call_params.name, "script_step")) return handleScriptStep(server, arena, id, call_params.arguments);
if (std.mem.eql(u8, call_params.name, "script_heal")) return handleScriptHeal(server, arena, id, call_params.arguments);
if (std.meta.stringToEnum(ExtraTool, call_params.name)) |tool| {
return switch (tool) {
.record_start => handleRecordStart(server, arena, id, call_params.arguments),
.record_stop => handleRecordStop(server, arena, id),
.record_comment => handleRecordComment(server, arena, id, call_params.arguments),
.script_step => handleScriptStep(server, arena, id, call_params.arguments),
.script_heal => handleScriptHeal(server, arena, id, call_params.arguments),
};
}
return dispatchBrowserTool(server, arena, id, call_params.name, call_params.arguments);
}
@@ -146,7 +159,7 @@ fn dispatchBrowserTool(
arguments: ?std.json.Value,
) !void {
const action = std.meta.stringToEnum(browser_tools.Action, name) orelse {
return server.transport.sendError(id, .MethodNotFound, "Tool not found");
return server.sendError(id, .MethodNotFound, "Tool not found");
};
// JS errors are returned as isError tool results, not protocol errors
@@ -162,7 +175,7 @@ fn dispatchBrowserTool(
error.NodeNotFound, error.InvalidParams => .InvalidParams,
error.NavigationFailed, error.InternalError, error.OutOfMemory => .InternalError,
};
return server.transport.sendError(id, code, @errorName(err));
return server.sendError(id, code, @errorName(err));
};
recordIfActive(server, name, arguments);
@@ -186,7 +199,7 @@ fn handleRecordStart(server: *Server, arena: std.mem.Allocator, id: std.json.Val
}
const Args = struct { path: []const u8 };
const args = browser_tools.parseArgs(Args, arena, arguments) catch {
return server.transport.sendError(id, .InvalidParams, "expected { path: string }");
return server.sendError(id, .InvalidParams, "expected { path: string }");
};
if (!script.isPathSafe(args.path)) {
@@ -228,7 +241,7 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V
}
const Args = struct { text: []const u8 };
const args = browser_tools.parseArgs(Args, arena, arguments) catch {
return server.transport.sendError(id, .InvalidParams, "expected { text: string }");
return server.sendError(id, .InvalidParams, "expected { text: string }");
};
server.recorder.?.recordComment(args.text);
@@ -239,7 +252,7 @@ fn handleRecordComment(server: *Server, arena: std.mem.Allocator, id: std.json.V
fn handleScriptStep(server: *Server, arena: std.mem.Allocator, id: std.json.Value, arguments: ?std.json.Value) !void {
const Args = struct { line: []const u8 };
const args = browser_tools.parseArgs(Args, arena, arguments) catch {
return server.transport.sendError(id, .InvalidParams, "expected { line: string }");
return server.sendError(id, .InvalidParams, "expected { line: string }");
};
const cmd = Command.parse(args.line);
@@ -305,7 +318,7 @@ fn handleScriptHeal(server: *Server, arena: std.mem.Allocator, id: std.json.Valu
replacements: []const ReplacementSpec,
};
const args = browser_tools.parseArgs(Args, arena, arguments) catch {
return server.transport.sendError(id, .InvalidParams, "expected { path: string, replacements: [{ original_line, replacement_lines }] }");
return server.sendError(id, .InvalidParams, "expected { path: string, replacements: [{ original_line, replacement_lines }] }");
};
if (!script.isPathSafe(args.path)) {
@@ -367,7 +380,7 @@ fn findLineSpan(content: []const u8, line: []const u8) error{ NotFound, Ambiguou
fn sendToolResultText(server: *Server, id: std.json.Value, msg: []const u8, is_error: bool) !void {
const content = [_]protocol.TextContent([]const u8){.{ .text = msg }};
try server.transport.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content, .isError = is_error });
try server.sendResult(id, protocol.CallToolResult([]const u8){ .content = &content, .isError = is_error });
}
fn sendErrorContent(server: *Server, id: std.json.Value, msg: []const u8) !void {