From bfe223c8ad5e3bab44385d4f52cb217255f2af6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Wed, 8 Apr 2026 08:33:08 +0200 Subject: [PATCH] agent: simplify client init and tool execution --- src/agent/Agent.zig | 28 ++++------------ src/agent/CommandExecutor.zig | 61 ++++++++++++++++++----------------- src/agent/Terminal.zig | 6 ++-- 3 files changed, 42 insertions(+), 53 deletions(-) diff --git a/src/agent/Agent.zig b/src/agent/Agent.zig index b985ae6b..4030181e 100644 --- a/src/agent/Agent.zig +++ b/src/agent/Agent.zig @@ -99,27 +99,13 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Self errdefer allocator.destroy(self); const ai_client: ?AiClient = if (api_key) |key| switch (opts.provider) { - .anthropic => blk: { - const client = try allocator.create(zenai.anthropic.Client); - client.* = zenai.anthropic.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{}); - break :blk .{ .anthropic = client }; - }, - .openai => blk: { - const client = try allocator.create(zenai.openai.Client); - client.* = zenai.openai.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{}); - break :blk .{ .openai = client }; - }, - .gemini => blk: { - const client = try allocator.create(zenai.gemini.Client); - client.* = zenai.gemini.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{}); - break :blk .{ .gemini = client }; - }, - .ollama => blk: { - const client = try allocator.create(zenai.openai.Client); - client.* = zenai.openai.Client.init(allocator, key, .{ - .base_url = opts.base_url orelse "http://localhost:11434/v1", - }); - break :blk .{ .ollama = client }; + inline else => |tag| blk: { + const ClientPtr = @FieldType(AiClient, @tagName(tag)); + const Client = @typeInfo(ClientPtr).pointer.child; + const client = try allocator.create(Client); + const url = opts.base_url orelse if (tag == .ollama) "http://localhost:11434/v1" else null; + client.* = Client.init(allocator, key, if (url) |u| .{ .base_url = u } else .{}); + break :blk @unionInit(AiClient, @tagName(tag), client); }, } else null; diff --git a/src/agent/CommandExecutor.zig b/src/agent/CommandExecutor.zig index 68d60fde..55d9cbfb 100644 --- a/src/agent/CommandExecutor.zig +++ b/src/agent/CommandExecutor.zig @@ -24,22 +24,17 @@ pub const ExecResult = struct { /// Execute a command and return the result with success/failure status. pub fn executeWithResult(self: *Self, a: std.mem.Allocator, cmd: Command.Command) ExecResult { - const result = switch (cmd) { + return switch (cmd) { .goto => |url| self.execGoto(a, url), .click => |target| self.execClick(a, target), .type_cmd => |args| self.execType(a, args), - .wait => |selector| self.tool_executor.call(a, "waitForSelector", buildJson(a, .{ .selector = selector })) catch "Error: wait failed", - .tree => self.tool_executor.call(a, "semantic_tree", "") catch "Error: tree failed", - .markdown => self.tool_executor.call(a, "markdown", "") catch "Error: markdown failed", + .wait => |selector| self.callTool(a, "waitForSelector", buildJson(a, .{ .selector = selector })), + .tree => self.callTool(a, "semantic_tree", ""), + .markdown => self.callTool(a, "markdown", ""), .extract => |args| self.execExtract(a, args), - .eval_js => |script| self.tool_executor.call(a, "evaluate", buildJson(a, .{ .script = script })) catch "Error: eval failed", + .eval_js => |script| self.callTool(a, "evaluate", buildJson(a, .{ .script = script })), .exit, .natural_language, .comment, .login, .accept_cookies => unreachable, }; - - return .{ - .output = result, - .failed = std.mem.startsWith(u8, result, "Error:"), - }; } pub fn execute(self: *Self, cmd: Command.Command) void { @@ -52,29 +47,35 @@ pub fn execute(self: *Self, cmd: Command.Command) void { std.debug.print("\n", .{}); } -fn execGoto(self: *Self, arena: std.mem.Allocator, raw_url: []const u8) []const u8 { - const url = substituteEnvVars(arena, raw_url); - return self.tool_executor.call(arena, "goto", buildJson(arena, .{ .url = url })) catch "Error: goto failed"; +fn callTool(self: *Self, arena: std.mem.Allocator, tool_name: []const u8, arguments_json: []const u8) ExecResult { + if (self.tool_executor.call(arena, tool_name, arguments_json)) |output| + return .{ .output = output, .failed = false } + else |err| + return .{ .output = std.fmt.allocPrint(arena, "{s} failed: {s}", .{ tool_name, @errorName(err) }) catch "tool failed", .failed = true }; } -fn execClick(self: *Self, arena: std.mem.Allocator, raw_target: []const u8) []const u8 { +fn execGoto(self: *Self, arena: std.mem.Allocator, raw_url: []const u8) ExecResult { + const url = substituteEnvVars(arena, raw_url); + return self.callTool(arena, "goto", buildJson(arena, .{ .url = url })); +} + +fn execClick(self: *Self, arena: std.mem.Allocator, raw_target: []const u8) ExecResult { const target = substituteEnvVars(arena, raw_target); - // Try as CSS selector via interactiveElements + click // First get interactive elements to find the target const elements_result = self.tool_executor.call(arena, "interactiveElements", "") catch - return "Error: failed to get interactive elements"; + return .{ .output = "failed to get interactive elements", .failed = true }; // Try to find a backendNodeId by searching the elements result for the target text if (findNodeIdByText(arena, elements_result, target)) |node_id| { const args = std.fmt.allocPrint(arena, "{{\"backendNodeId\":{d}}}", .{node_id}) catch - return "Error: failed to build click args"; - return self.tool_executor.call(arena, "click", args) catch "Error: click failed"; + return .{ .output = "failed to build click args", .failed = true }; + return self.callTool(arena, "click", args); } - return "Error: could not find element matching the target"; + return .{ .output = "could not find element matching the target", .failed = true }; } -fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) []const u8 { +fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) ExecResult { const selector = escapeJs(arena, substituteEnvVars(arena, args.selector)); const value = escapeJs(arena, substituteEnvVars(arena, args.value)); @@ -87,38 +88,38 @@ fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) []con \\ el.dispatchEvent(new Event("input", {{bubbles: true}})); \\ return "Typed into " + el.tagName; \\}})() - , .{ selector, value }) catch return "Error: failed to build type script"; + , .{ selector, value }) catch return .{ .output = "failed to build type script", .failed = true }; - return self.tool_executor.call(arena, "evaluate", buildJson(arena, .{ .script = script })) catch "Error: type failed"; + return self.callTool(arena, "evaluate", buildJson(arena, .{ .script = script })); } -fn execExtract(self: *Self, arena: std.mem.Allocator, args: Command.ExtractArgs) []const u8 { +fn execExtract(self: *Self, arena: std.mem.Allocator, args: Command.ExtractArgs) ExecResult { const selector = escapeJs(arena, substituteEnvVars(arena, args.selector)); const script = std.fmt.allocPrint(arena, \\JSON.stringify(Array.from(document.querySelectorAll("{s}")).map(el => el.textContent.trim())) - , .{selector}) catch return "Error: failed to build extract script"; + , .{selector}) catch return .{ .output = "failed to build extract script", .failed = true }; const result = self.tool_executor.call(arena, "evaluate", buildJson(arena, .{ .script = script })) catch - return "Error: extract failed"; + return .{ .output = "extract failed", .failed = true }; if (args.file) |raw_file| { const file = sanitizePath(raw_file) orelse { self.terminal.printError("Invalid output path: must be relative and not traverse above working directory"); - return result; + return .{ .output = result, .failed = false }; }; std.fs.cwd().writeFile(.{ .sub_path = file, .data = result, }) catch { self.terminal.printError("Failed to write to file"); - return result; + return .{ .output = result, .failed = false }; }; - const msg = std.fmt.allocPrint(arena, "Extracted to {s}", .{file}) catch return "Extracted."; - return msg; + const msg = std.fmt.allocPrint(arena, "Extracted to {s}", .{file}) catch "Extracted."; + return .{ .output = msg, .failed = false }; } - return result; + return .{ .output = result, .failed = false }; } /// Substitute $VAR_NAME references with values from the environment. diff --git a/src/agent/Terminal.zig b/src/agent/Terminal.zig index 1d0e8beb..c4c78ea2 100644 --- a/src/agent/Terminal.zig +++ b/src/agent/Terminal.zig @@ -49,9 +49,11 @@ pub fn printToolCall(_: *Self, name: []const u8, args: []const u8) void { std.debug.print("\n{s}{s}[tool: {s}]{s} {s}\n", .{ ansi_dim, ansi_cyan, name, ansi_reset, args }); } +const max_result_display_len = 500; + pub fn printToolResult(_: *Self, name: []const u8, result: []const u8) void { - const truncated = if (result.len > 500) result[0..500] else result; - const ellipsis: []const u8 = if (result.len > 500) "..." else ""; + const truncated = result[0..@min(result.len, max_result_display_len)]; + const ellipsis: []const u8 = if (result.len > max_result_display_len) "..." else ""; std.debug.print("{s}{s}[result: {s}]{s} {s}{s}\n", .{ ansi_dim, ansi_green, name, ansi_reset, truncated, ellipsis }); }