diff --git a/src/Config.zig b/src/Config.zig index 9afc71f2..500e52b2 100644 --- a/src/Config.zig +++ b/src/Config.zig @@ -235,7 +235,7 @@ pub const Agent = struct { system_prompt: ?[:0]const u8 = null, repl: bool = true, script_file: ?[]const u8 = null, - record_file: ?[]const u8 = null, + save: bool = false, self_heal: bool = false, }; @@ -950,12 +950,8 @@ fn parseAgentArgs( continue; } - if (std.mem.eql(u8, "--run", opt)) { - const str = args.next() orelse { - log.fatal(.app, "missing argument value", .{ .arg = opt }); - return error.InvalidArgument; - }; - result.script_file = str; + if (std.mem.eql(u8, "--save", opt)) { + result.save = true; continue; } @@ -977,9 +973,8 @@ fn parseAgentArgs( continue; } - // Positional argument: recording file for REPL mode (e.g. `agent --repl my_workflow.panda`) if (!std.mem.startsWith(u8, opt, "-")) { - result.record_file = opt; + result.script_file = opt; continue; } diff --git a/src/agent/Agent.zig b/src/agent/Agent.zig index 2be981cb..af74edf9 100644 --- a/src/agent/Agent.zig +++ b/src/agent/Agent.zig @@ -19,7 +19,7 @@ const default_system_prompt = \\click links, and extract information. \\ \\When helping the user, navigate to relevant pages and extract information. - \\Use the semantic_tree or interactiveElements tools to understand page structure + \\Use the semanticTree or interactiveElements tools to understand page structure \\before clicking or filling forms. Be concise in your responses. \\ \\IMPORTANT RULES: @@ -75,11 +75,10 @@ tools: []const zenai.provider.Tool, model: []const u8, system_prompt: []const u8, script_file: ?[]const u8, -record_file: ?[]const u8, self_heal: bool, pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Self { - const is_script_mode = opts.script_file != null; + const is_script_mode = opts.script_file != null and !opts.save; // API key is only required for REPL mode and self-healing const api_key: ?[:0]const u8 = getEnvApiKey(opts.provider) orelse if (!is_script_mode) { @@ -118,14 +117,13 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Self .tool_executor = tool_executor, .terminal = Terminal.init(null), .cmd_executor = undefined, - .recorder = Recorder.init(opts.record_file), + .recorder = Recorder.init(if (opts.save) opts.script_file else null), .messages = .empty, .message_arena = std.heap.ArenaAllocator.init(allocator), .tools = tools, .model = opts.model orelse defaultModel(opts.provider), .system_prompt = opts.system_prompt orelse default_system_prompt, - .script_file = opts.script_file, - .record_file = opts.record_file, + .script_file = if (!opts.save) opts.script_file else null, .self_heal = opts.self_heal, }; @@ -151,8 +149,8 @@ pub fn deinit(self: *Self) void { } pub fn run(self: *Self) void { - if (self.script_file) |script_file| { - self.runScript(script_file); + if (self.script_file) |path| { + self.runScript(path); } else { self.runRepl(); } @@ -183,14 +181,12 @@ fn runRepl(self: *Self) void { .comment => continue, .login => { self.processUserMessage(login_prompt, line) catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "LOGIN failed: {s}", .{@errorName(err)}) catch "LOGIN failed"; - self.terminal.printError(msg); + self.printAllocError("LOGIN failed: {s}", .{@errorName(err)}); }; }, .accept_cookies => { self.processUserMessage(accept_cookies_prompt, line) catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "ACCEPT_COOKIES failed: {s}", .{@errorName(err)}) catch "ACCEPT_COOKIES failed"; - self.terminal.printError(msg); + self.printAllocError("ACCEPT_COOKIES failed: {s}", .{@errorName(err)}); }; }, .natural_language => { @@ -198,8 +194,7 @@ fn runRepl(self: *Self) void { if (std.mem.eql(u8, line, "quit")) break; self.processUserMessage(line, line) catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "Request failed: {s}", .{@errorName(err)}) catch "Request failed"; - self.terminal.printError(msg); + self.printAllocError("Request failed: {s}", .{@errorName(err)}); }; }, else => { @@ -212,17 +207,24 @@ fn runRepl(self: *Self) void { self.terminal.printInfo("Goodbye!"); } +fn printAllocError(self: *Self, comptime fmt: []const u8, args: anytype) void { + const msg = std.fmt.allocPrint(self.allocator, fmt, args) catch { + self.terminal.printError(fmt); + return; + }; + defer self.allocator.free(msg); + self.terminal.printError(msg); +} + fn runScript(self: *Self, path: []const u8) void { const file = std.fs.cwd().openFile(path, .{}) catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "Failed to open script '{s}': {s}", .{ path, @errorName(err) }) catch "Failed to open script"; - self.terminal.printError(msg); + self.printAllocError("Failed to open script '{s}': {s}", .{ path, @errorName(err) }); return; }; defer file.close(); const content = file.readToEndAlloc(self.allocator, 10 * 1024 * 1024) catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "Failed to read script: {s}", .{@errorName(err)}) catch "Failed to read script"; - self.terminal.printError(msg); + self.printAllocError("Failed to read script: {s}", .{@errorName(err)}); return; }; defer self.allocator.free(content); @@ -251,28 +253,25 @@ fn runScript(self: *Self, path: []const u8) void { continue; }, .natural_language => { - const msg = std.fmt.allocPrint(self.allocator, "line {d}: unrecognized command: {s}", .{ entry.line_num, entry.raw_line }) catch "unrecognized command"; - self.terminal.printError(msg); + self.printAllocError("line {d}: unrecognized command: {s}", .{ entry.line_num, entry.raw_line }); return; }, .login, .accept_cookies => { // High-level commands require LLM if (self.ai_client == null) { - const msg = std.fmt.allocPrint(self.allocator, "line {d}: {s} requires an API key for LLM resolution", .{ + self.printAllocError("line {d}: {s} requires an API key for LLM resolution", .{ entry.line_num, entry.raw_line, - }) catch "LLM required"; - self.terminal.printError(msg); + }); return; } const prompt = if (entry.command == .login) login_prompt else accept_cookies_prompt; self.processUserMessage(prompt, "") catch |err| { - const msg = std.fmt.allocPrint(self.allocator, "line {d}: {s} failed: {s}", .{ + self.printAllocError("line {d}: {s} failed: {s}", .{ entry.line_num, entry.raw_line, @errorName(err), - }) catch "command failed"; - self.terminal.printError(msg); + }); return; }; }, @@ -297,11 +296,10 @@ fn runScript(self: *Self, path: []const u8) void { continue; } } - const msg = std.fmt.allocPrint(self.allocator, "line {d}: command failed: {s}", .{ + self.printAllocError("line {d}: command failed: {s}", .{ entry.line_num, entry.raw_line, - }) catch "command failed"; - self.terminal.printError(msg); + }); return; } }, diff --git a/src/agent/Command.zig b/src/agent/Command.zig index 328ac5d7..8fb2b1c4 100644 --- a/src/agent/Command.zig +++ b/src/agent/Command.zig @@ -76,39 +76,39 @@ pub fn parse(line: []const u8) Command { const cmd_word = trimmed[0..cmd_end]; const rest = std.mem.trim(u8, trimmed[cmd_end..], &std.ascii.whitespace); - if (eqlIgnoreCase(cmd_word, "GOTO")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "GOTO")) { if (rest.len == 0) return .{ .natural_language = trimmed }; return .{ .goto = rest }; } - if (eqlIgnoreCase(cmd_word, "CLICK")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "CLICK")) { const arg = extractQuoted(rest) orelse rest; if (arg.len == 0) return .{ .natural_language = trimmed }; return .{ .click = arg }; } - if (eqlIgnoreCase(cmd_word, "TYPE")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "TYPE")) { const first = extractQuotedWithRemainder(rest) orelse return .{ .natural_language = trimmed }; const second_arg = std.mem.trim(u8, first.remainder, &std.ascii.whitespace); const second = extractQuoted(second_arg) orelse return .{ .natural_language = trimmed }; return .{ .type_cmd = .{ .selector = first.value, .value = second } }; } - if (eqlIgnoreCase(cmd_word, "WAIT")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "WAIT")) { const arg = extractQuoted(rest) orelse rest; if (arg.len == 0) return .{ .natural_language = trimmed }; return .{ .wait = arg }; } - if (eqlIgnoreCase(cmd_word, "TREE")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "TREE")) { return .{ .tree = {} }; } - if (eqlIgnoreCase(cmd_word, "MARKDOWN") or eqlIgnoreCase(cmd_word, "MD")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "MARKDOWN") or std.ascii.eqlIgnoreCase(cmd_word, "MD")) { return .{ .markdown = {} }; } - if (eqlIgnoreCase(cmd_word, "EXTRACT")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "EXTRACT")) { const selector = extractQuoted(rest) orelse { if (rest.len == 0) return .{ .natural_language = trimmed }; return .{ .extract = .{ .selector = rest, .file = null } }; @@ -123,21 +123,21 @@ pub fn parse(line: []const u8) Command { return .{ .extract = .{ .selector = selector, .file = null } }; } - if (eqlIgnoreCase(cmd_word, "EVAL")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "EVAL")) { if (rest.len == 0) return .{ .natural_language = trimmed }; const arg = extractQuoted(rest) orelse rest; return .{ .eval_js = arg }; } - if (eqlIgnoreCase(cmd_word, "LOGIN")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "LOGIN")) { return .{ .login = {} }; } - if (eqlIgnoreCase(cmd_word, "ACCEPT_COOKIES") or eqlIgnoreCase(cmd_word, "ACCEPT-COOKIES")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "ACCEPT_COOKIES") or std.ascii.eqlIgnoreCase(cmd_word, "ACCEPT-COOKIES")) { return .{ .accept_cookies = {} }; } - if (eqlIgnoreCase(cmd_word, "EXIT")) { + if (std.ascii.eqlIgnoreCase(cmd_word, "EXIT")) { return .{ .exit = {} }; } @@ -202,7 +202,7 @@ pub const ScriptIterator = struct { fn isEvalTripleQuote(line: []const u8) bool { const cmd_end = std.mem.indexOfAny(u8, line, &std.ascii.whitespace) orelse line.len; const cmd_word = line[0..cmd_end]; - if (!eqlIgnoreCase(cmd_word, "EVAL")) return false; + if (!std.ascii.eqlIgnoreCase(cmd_word, "EVAL")) return false; const rest = std.mem.trim(u8, line[cmd_end..], &std.ascii.whitespace); return std.mem.startsWith(u8, rest, "\"\"\"") or std.mem.startsWith(u8, rest, "'''"); } @@ -248,14 +248,6 @@ fn extractQuoted(s: []const u8) ?[]const u8 { return result.value; } -pub fn eqlIgnoreCase(a: []const u8, comptime upper: []const u8) bool { - if (a.len != upper.len) return false; - for (a, upper) |ac, uc| { - if (std.ascii.toUpper(ac) != uc) return false; - } - return true; -} - // --- Tests --- test "parse GOTO" { diff --git a/src/agent/CommandExecutor.zig b/src/agent/CommandExecutor.zig index c9f6d5a3..a266086c 100644 --- a/src/agent/CommandExecutor.zig +++ b/src/agent/CommandExecutor.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const browser_tools = @import("lightpanda").tools; const Command = @import("Command.zig"); const ToolExecutor = @import("ToolExecutor.zig"); const Terminal = @import("Terminal.zig"); @@ -126,43 +127,7 @@ fn execExtract(self: *Self, arena: std.mem.Allocator, args: Command.ExtractArgs) return .{ .output = result, .failed = false }; } -/// Substitute $VAR_NAME references with values from the environment. -fn substituteEnvVars(arena: std.mem.Allocator, input: []const u8) []const u8 { - // Quick scan: if no $ present, return as-is - if (std.mem.indexOfScalar(u8, input, '$') == null) return input; - - var result: std.ArrayListUnmanaged(u8) = .empty; - var i: usize = 0; - while (i < input.len) { - if (input[i] == '$') { - // Find the end of the variable name (alphanumeric + underscore) - const var_start = i + 1; - var var_end = var_start; - while (var_end < input.len and (std.ascii.isAlphanumeric(input[var_end]) or input[var_end] == '_')) { - var_end += 1; - } - if (var_end > var_start) { - const var_name = input[var_start..var_end]; - // We need a null-terminated string for getenv - const var_name_z = arena.dupeZ(u8, var_name) catch return input; - if (std.posix.getenv(var_name_z)) |env_val| { - result.appendSlice(arena, env_val) catch return input; - } else { - // Keep the original $VAR if not found - result.appendSlice(arena, input[i..var_end]) catch return input; - } - i = var_end; - } else { - result.append(arena, '$') catch return input; - i += 1; - } - } else { - result.append(arena, input[i]) catch return input; - i += 1; - } - } - return result.toOwnedSlice(arena) catch input; -} +const substituteEnvVars = browser_tools.substituteEnvVars; /// Escape a string for safe interpolation inside a JS double-quoted string literal. fn escapeJs(arena: std.mem.Allocator, input: []const u8) []const u8 { diff --git a/src/browser/tools.zig b/src/browser/tools.zig index ba86dc2b..110f3200 100644 --- a/src/browser/tools.zig +++ b/src/browser/tools.zig @@ -475,12 +475,10 @@ fn execNodeDetails(session: *lp.Session, registry: *CDPNode.Registry, arena: std const Params = struct { backendNodeId: CDPNode.Id }; const args = parseArgsOrErr(Params, arena, arguments) orelse return ToolError.InvalidParams; - _ = session.currentPage() orelse return ToolError.PageNotLoaded; + const page = session.currentPage() orelse return ToolError.PageNotLoaded; const node = registry.lookup_by_id.get(args.backendNodeId) orelse return ToolError.NodeNotFound; - - const page = session.currentPage().?; const details = lp.SemanticTree.getNodeDetails(arena, node.dom, registry, page) catch return ToolError.InternalError; @@ -802,7 +800,7 @@ fn execFindElement(session: *lp.Session, registry: *CDPNode.Registry, arena: std } if (args.name) |name| { const el_name = el.name orelse continue; - if (!containsIgnoreCase(el_name, name)) continue; + if (std.ascii.indexOfIgnoreCase(el_name, name) == null) continue; } matches.append(arena, el) catch return ToolError.InternalError; } @@ -914,7 +912,7 @@ fn parseArgsOrErr(comptime T: type, arena: std.mem.Allocator, arguments: ?std.js } /// Substitute $VAR_NAME references with values from the environment. -fn substituteEnvVars(arena: std.mem.Allocator, input: []const u8) []const u8 { +pub fn substituteEnvVars(arena: std.mem.Allocator, input: []const u8) []const u8 { if (std.mem.indexOfScalar(u8, input, '$') == null) return input; var result: std.ArrayListUnmanaged(u8) = .empty; @@ -945,13 +943,3 @@ fn substituteEnvVars(arena: std.mem.Allocator, input: []const u8) []const u8 { } return result.toOwnedSlice(arena) catch input; } - -pub fn containsIgnoreCase(haystack: []const u8, needle: []const u8) bool { - if (needle.len > haystack.len) return false; - if (needle.len == 0) return true; - const end = haystack.len - needle.len + 1; - for (0..end) |i| { - if (std.ascii.eqlIgnoreCase(haystack[i..][0..needle.len], needle)) return true; - } - return false; -}