agent: simplify client init and tool execution

This commit is contained in:
Adrià Arrufat
2026-04-08 08:33:08 +02:00
parent 1aca921327
commit bfe223c8ad
3 changed files with 42 additions and 53 deletions

View File

@@ -99,27 +99,13 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Self
errdefer allocator.destroy(self);
const ai_client: ?AiClient = if (api_key) |key| switch (opts.provider) {
.anthropic => blk: {
const client = try allocator.create(zenai.anthropic.Client);
client.* = zenai.anthropic.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{});
break :blk .{ .anthropic = client };
},
.openai => blk: {
const client = try allocator.create(zenai.openai.Client);
client.* = zenai.openai.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{});
break :blk .{ .openai = client };
},
.gemini => blk: {
const client = try allocator.create(zenai.gemini.Client);
client.* = zenai.gemini.Client.init(allocator, key, if (opts.base_url) |url| .{ .base_url = url } else .{});
break :blk .{ .gemini = client };
},
.ollama => blk: {
const client = try allocator.create(zenai.openai.Client);
client.* = zenai.openai.Client.init(allocator, key, .{
.base_url = opts.base_url orelse "http://localhost:11434/v1",
});
break :blk .{ .ollama = client };
inline else => |tag| blk: {
const ClientPtr = @FieldType(AiClient, @tagName(tag));
const Client = @typeInfo(ClientPtr).pointer.child;
const client = try allocator.create(Client);
const url = opts.base_url orelse if (tag == .ollama) "http://localhost:11434/v1" else null;
client.* = Client.init(allocator, key, if (url) |u| .{ .base_url = u } else .{});
break :blk @unionInit(AiClient, @tagName(tag), client);
},
} else null;

View File

@@ -24,22 +24,17 @@ pub const ExecResult = struct {
/// Execute a command and return the result with success/failure status.
pub fn executeWithResult(self: *Self, a: std.mem.Allocator, cmd: Command.Command) ExecResult {
const result = switch (cmd) {
return switch (cmd) {
.goto => |url| self.execGoto(a, url),
.click => |target| self.execClick(a, target),
.type_cmd => |args| self.execType(a, args),
.wait => |selector| self.tool_executor.call(a, "waitForSelector", buildJson(a, .{ .selector = selector })) catch "Error: wait failed",
.tree => self.tool_executor.call(a, "semantic_tree", "") catch "Error: tree failed",
.markdown => self.tool_executor.call(a, "markdown", "") catch "Error: markdown failed",
.wait => |selector| self.callTool(a, "waitForSelector", buildJson(a, .{ .selector = selector })),
.tree => self.callTool(a, "semantic_tree", ""),
.markdown => self.callTool(a, "markdown", ""),
.extract => |args| self.execExtract(a, args),
.eval_js => |script| self.tool_executor.call(a, "evaluate", buildJson(a, .{ .script = script })) catch "Error: eval failed",
.eval_js => |script| self.callTool(a, "evaluate", buildJson(a, .{ .script = script })),
.exit, .natural_language, .comment, .login, .accept_cookies => unreachable,
};
return .{
.output = result,
.failed = std.mem.startsWith(u8, result, "Error:"),
};
}
pub fn execute(self: *Self, cmd: Command.Command) void {
@@ -52,29 +47,35 @@ pub fn execute(self: *Self, cmd: Command.Command) void {
std.debug.print("\n", .{});
}
fn execGoto(self: *Self, arena: std.mem.Allocator, raw_url: []const u8) []const u8 {
const url = substituteEnvVars(arena, raw_url);
return self.tool_executor.call(arena, "goto", buildJson(arena, .{ .url = url })) catch "Error: goto failed";
fn callTool(self: *Self, arena: std.mem.Allocator, tool_name: []const u8, arguments_json: []const u8) ExecResult {
if (self.tool_executor.call(arena, tool_name, arguments_json)) |output|
return .{ .output = output, .failed = false }
else |err|
return .{ .output = std.fmt.allocPrint(arena, "{s} failed: {s}", .{ tool_name, @errorName(err) }) catch "tool failed", .failed = true };
}
fn execClick(self: *Self, arena: std.mem.Allocator, raw_target: []const u8) []const u8 {
fn execGoto(self: *Self, arena: std.mem.Allocator, raw_url: []const u8) ExecResult {
const url = substituteEnvVars(arena, raw_url);
return self.callTool(arena, "goto", buildJson(arena, .{ .url = url }));
}
fn execClick(self: *Self, arena: std.mem.Allocator, raw_target: []const u8) ExecResult {
const target = substituteEnvVars(arena, raw_target);
// Try as CSS selector via interactiveElements + click
// First get interactive elements to find the target
const elements_result = self.tool_executor.call(arena, "interactiveElements", "") catch
return "Error: failed to get interactive elements";
return .{ .output = "failed to get interactive elements", .failed = true };
// Try to find a backendNodeId by searching the elements result for the target text
if (findNodeIdByText(arena, elements_result, target)) |node_id| {
const args = std.fmt.allocPrint(arena, "{{\"backendNodeId\":{d}}}", .{node_id}) catch
return "Error: failed to build click args";
return self.tool_executor.call(arena, "click", args) catch "Error: click failed";
return .{ .output = "failed to build click args", .failed = true };
return self.callTool(arena, "click", args);
}
return "Error: could not find element matching the target";
return .{ .output = "could not find element matching the target", .failed = true };
}
fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) []const u8 {
fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) ExecResult {
const selector = escapeJs(arena, substituteEnvVars(arena, args.selector));
const value = escapeJs(arena, substituteEnvVars(arena, args.value));
@@ -87,38 +88,38 @@ fn execType(self: *Self, arena: std.mem.Allocator, args: Command.TypeArgs) []con
\\ el.dispatchEvent(new Event("input", {{bubbles: true}}));
\\ return "Typed into " + el.tagName;
\\}})()
, .{ selector, value }) catch return "Error: failed to build type script";
, .{ selector, value }) catch return .{ .output = "failed to build type script", .failed = true };
return self.tool_executor.call(arena, "evaluate", buildJson(arena, .{ .script = script })) catch "Error: type failed";
return self.callTool(arena, "evaluate", buildJson(arena, .{ .script = script }));
}
fn execExtract(self: *Self, arena: std.mem.Allocator, args: Command.ExtractArgs) []const u8 {
fn execExtract(self: *Self, arena: std.mem.Allocator, args: Command.ExtractArgs) ExecResult {
const selector = escapeJs(arena, substituteEnvVars(arena, args.selector));
const script = std.fmt.allocPrint(arena,
\\JSON.stringify(Array.from(document.querySelectorAll("{s}")).map(el => el.textContent.trim()))
, .{selector}) catch return "Error: failed to build extract script";
, .{selector}) catch return .{ .output = "failed to build extract script", .failed = true };
const result = self.tool_executor.call(arena, "evaluate", buildJson(arena, .{ .script = script })) catch
return "Error: extract failed";
return .{ .output = "extract failed", .failed = true };
if (args.file) |raw_file| {
const file = sanitizePath(raw_file) orelse {
self.terminal.printError("Invalid output path: must be relative and not traverse above working directory");
return result;
return .{ .output = result, .failed = false };
};
std.fs.cwd().writeFile(.{
.sub_path = file,
.data = result,
}) catch {
self.terminal.printError("Failed to write to file");
return result;
return .{ .output = result, .failed = false };
};
const msg = std.fmt.allocPrint(arena, "Extracted to {s}", .{file}) catch return "Extracted.";
return msg;
const msg = std.fmt.allocPrint(arena, "Extracted to {s}", .{file}) catch "Extracted.";
return .{ .output = msg, .failed = false };
}
return result;
return .{ .output = result, .failed = false };
}
/// Substitute $VAR_NAME references with values from the environment.

View File

@@ -49,9 +49,11 @@ pub fn printToolCall(_: *Self, name: []const u8, args: []const u8) void {
std.debug.print("\n{s}{s}[tool: {s}]{s} {s}\n", .{ ansi_dim, ansi_cyan, name, ansi_reset, args });
}
const max_result_display_len = 500;
pub fn printToolResult(_: *Self, name: []const u8, result: []const u8) void {
const truncated = if (result.len > 500) result[0..500] else result;
const ellipsis: []const u8 = if (result.len > 500) "..." else "";
const truncated = result[0..@min(result.len, max_result_display_len)];
const ellipsis: []const u8 = if (result.len > max_result_display_len) "..." else "";
std.debug.print("{s}{s}[result: {s}]{s} {s}{s}\n", .{ ansi_dim, ansi_green, name, ansi_reset, truncated, ellipsis });
}