From db385cf4db06f7630fb8c328bc13d343788daee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Wed, 3 Jun 2026 08:52:25 +0200 Subject: [PATCH] agent: optimize autocomplete and cache providers --- src/agent/Agent.zig | 147 ++++++++++++++++++++++++++--------------- src/agent/Terminal.zig | 21 +++--- 2 files changed, 103 insertions(+), 65 deletions(-) diff --git a/src/agent/Agent.zig b/src/agent/Agent.zig index e8e22921..5546a11a 100644 --- a/src/agent/Agent.zig +++ b/src/agent/Agent.zig @@ -138,8 +138,31 @@ synthetic_tool_call_id: u32 = 0, total_usage: zenai.provider.Usage = .{}, /// Set when the last turn ended in a model refusal (safety stop). last_turn_refused: bool = false, +available_providers: []const []const u8, + +fn resolveModelName(opts: Config.Agent, resolved: ?settings.ResolvedProvider, remembered: ?settings.Remembered) []const u8 { + if (opts.model) |m| return m; + if (resolved) |r| { + if (r.source == .remembered) { + if (remembered) |rem| return rem.model; + } + return zenai.provider.defaultModel(r.credentials.provider); + } + return ""; +} pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Agent { + var providers_buf: [@typeInfo(Config.AiProvider).@"enum".fields.len]Credentials = undefined; + const found_providers = settings.availableProviders(&providers_buf); + const available_providers = try allocator.alloc([]const u8, found_providers.len); + errdefer { + for (available_providers) |p| allocator.free(p); + allocator.free(available_providers); + } + for (found_providers, 0..) |f, i| { + available_providers[i] = try allocator.dupe(u8, @tagName(f.provider)); + } + if (opts.task != null and opts.script_file != null) { log.fatal(.app, "conflicting flags", .{ .hint = "--task runs a one-shot turn; drop the positional script or drop --task", @@ -193,15 +216,7 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Agent return error.MissingProvider; } - const model: []u8 = if (opts.model) |m| - try allocator.dupe(u8, m) - else if (resolved) |r| - if (r.source == .remembered) - try allocator.dupe(u8, remembered.?.model) - else - try allocator.dupe(u8, zenai.provider.defaultModel(r.credentials.provider)) - else - try allocator.dupe(u8, ""); + const model = try allocator.dupe(u8, resolveModelName(opts, resolved, remembered)); errdefer allocator.free(model); if (resolved) |r| { @@ -253,6 +268,7 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Agent .interactive = opts.interactive, .one_shot_task = opts.task, .one_shot_attachments = if (opts.attach.items.len == 0) null else opts.attach.items, + .available_providers = available_providers, }; errdefer self.node_registry.deinit(); errdefer self.terminal.deinit(); @@ -284,6 +300,9 @@ pub fn init(allocator: std.mem.Allocator, app: *App, opts: Config.Agent) !*Agent .providers = completionProviders, .models = completionModels, }; + // Warm the model-list cache so the first autocomplete keystroke doesn't + // block on the network. + if (self.model_credentials != null) _ = completionModels(self, allocator); } if (recorder_path) |p| { @@ -312,6 +331,8 @@ pub fn deinit(self: *Agent) void { self.notification.deinit(); if (self.ai_client) |ai_client| ai_client.deinit(self.allocator); self.allocator.free(self.model); + for (self.available_providers) |p| self.allocator.free(p); + self.allocator.free(self.available_providers); self.allocator.destroy(self); } @@ -675,10 +696,41 @@ fn setProvider(self: *Agent, credentials: Credentials) !void { settings.saveRemembered(credentials.provider, self.model); self.terminal.printInfo("provider: {s}", .{@tagName(credentials.provider)}); self.terminal.printInfo("model: {s}", .{self.model}); + _ = completionModels(self, self.allocator); } const SaveMode = enum { replace, append }; +const PathAndMode = struct { path: []const u8, mode: SaveMode }; + +fn resolveSavePathAndMode(self: *Agent, arena: std.mem.Allocator, filename: ?[]const u8) ?PathAndMode { + if (self.save_path) |saved| { + if (filename) |name| { + if (!std.mem.eql(u8, saved, name)) { + self.terminal.printError("already saving to {s}; use /save without a filename to append to it", .{saved}); + return null; + } + } + return .{ .path = saved, .mode = .append }; + } else if (filename) |name| { + const exists = fileExists(name) catch |err| { + self.terminal.printError("failed to inspect {s}: {s}", .{ name, @errorName(err) }); + return null; + }; + const mode = if (exists) + self.promptSaveMode(name) orelse return null + else + .replace; + return .{ .path = name, .mode = mode }; + } else { + const path = randomSaveFilename(arena) catch |err| { + self.terminal.printError("failed to choose save filename: {s}", .{@errorName(err)}); + return null; + }; + return .{ .path = path, .mode = .replace }; + } +} + fn handleSave(self: *Agent, arena: std.mem.Allocator, rest: []const u8) void { const parsed = parseSaveCommand(rest) catch |err| { const msg: []const u8 = switch (err) { @@ -700,31 +752,9 @@ fn handleSave(self: *Agent, arena: std.mem.Allocator, rest: []const u8) void { if (parsed.prompt != null) { self.terminal.printWarning("prompt ignored without an LLM; saving the recorded commands as-is", .{}); } - const filename = parsed.filename; - - const path: []const u8, const mode: SaveMode = if (self.save_path) |saved| blk: { - if (filename) |name| { - if (!std.mem.eql(u8, saved, name)) { - self.terminal.printError("already saving to {s}; use /save without a filename to append to it", .{saved}); - return; - } - } - break :blk .{ saved, .append }; - } else blk: { - const path = filename orelse randomSaveFilename(arena) catch |err| { - self.terminal.printError("failed to choose save filename: {s}", .{@errorName(err)}); - return; - }; - const exists = fileExists(path) catch |err| { - self.terminal.printError("failed to inspect {s}: {s}", .{ path, @errorName(err) }); - return; - }; - const mode: SaveMode = if (exists) - self.promptSaveMode(path) orelse return - else - .replace; - break :blk .{ path, mode }; - }; + const resolved = self.resolveSavePathAndMode(arena, parsed.filename) orelse return; + const path = resolved.path; + const mode = resolved.mode; // `path` aliases either an arena-owned string (first save) or // `self.save_path` (subsequent saves to the same destination); only @@ -979,22 +1009,20 @@ fn stripCodeFence(text: []const u8) []const u8 { return std.mem.trim(u8, body[0..close], &std.ascii.whitespace); } +fn logSaveBufferError(self: *Agent, err: anyerror) void { + self.terminal.printError("save buffer disabled: {s}", .{@errorName(err)}); +} + fn recordSaveCommand(self: *Agent, cmd: Command) void { - self.save_buffer.record(cmd) catch |err| { - self.terminal.printError("save buffer disabled: {s}", .{@errorName(err)}); - }; + self.save_buffer.record(cmd) catch |err| self.logSaveBufferError(err); } fn recordSaveComment(self: *Agent, comment: []const u8) void { - self.save_buffer.recordComment(comment) catch |err| { - self.terminal.printError("save buffer disabled: {s}", .{@errorName(err)}); - }; + self.save_buffer.recordComment(comment) catch |err| self.logSaveBufferError(err); } fn recordSaveRaw(self: *Agent, line: []const u8) void { - self.save_buffer.recordRaw(line) catch |err| { - self.terminal.printError("save buffer disabled: {s}", .{@errorName(err)}); - }; + self.save_buffer.recordRaw(line) catch |err| self.logSaveBufferError(err); } fn printSlashHelp(self: *Agent, arena: std.mem.Allocator, target: []const u8) void { @@ -1322,7 +1350,12 @@ fn processUserMessage(self: *Agent, input: TurnInput) !?[]const u8 { if (result.cancelled) return self.drainCancellation(msg_baseline); - const file_recorder: ?*Recorder = if (self.recorder) |*r| (if (r.isActive()) r else null) else null; + const file_recorder: ?*Recorder = blk: { + if (self.recorder) |*r| { + if (r.isActive()) break :blk r; + } + break :blk null; + }; const record_to_memory = input.capture_for_save; if (file_recorder != null or record_to_memory) { // When the LLM tries multiple `extract` schemas in one turn, only the @@ -1337,7 +1370,9 @@ fn processUserMessage(self: *Agent, input: TurnInput) !?[]const u8 { for (result.tool_calls_made, 0..) |tc, i| { if (tc.is_error) continue; const tool = std.meta.stringToEnum(BrowserTool, tc.name) orelse continue; - if (last_extract_idx) |idx| if (tool == .extract and idx != i) continue; + if (last_extract_idx) |idx| { + if (tool == .extract and idx != i) continue; + } const args = browser_tools.normalizeArgKeys(self.message_arena.allocator(), tool, tc.arguments) catch tc.arguments; const cmd = Command.fromToolCall(tool, args); if (!cmd.isRecorded()) continue; @@ -1351,9 +1386,11 @@ fn processUserMessage(self: *Agent, input: TurnInput) !?[]const u8 { if (file_recorder) |r| r.record(cmd); if (record_to_memory) self.recordSaveCommand(cmd); } - if (file_recorder) |r| if (!r.isActive()) { - self.terminal.printError("recording disabled (write failed); see logs", .{}); - }; + if (file_recorder) |r| { + if (!r.isActive()) { + self.terminal.printError("recording disabled (write failed); see logs", .{}); + } + } } // Dupe into `message_arena` — RunToolsResult arenas are deinited below. @@ -1539,14 +1576,14 @@ const ModelCompletions = struct { ids: []const []const u8, }; -/// `CompletionSource.providers`. Stateless — `context` (the `*Agent`) is unused, -/// present only for the shared callback shape. +/// `CompletionSource.providers`. Reuses pre-detected available providers to avoid +/// reading environment variables on every autocomplete keypress. fn completionProviders(context: *anyopaque, arena: std.mem.Allocator) []const []const u8 { - _ = context; - var buf: [@typeInfo(Config.AiProvider).@"enum".fields.len]Credentials = undefined; - const found = settings.availableProviders(&buf); - const names = arena.alloc([]const u8, found.len) catch return &.{}; - for (found, 0..) |f, i| names[i] = @tagName(f.provider); + const self: *Agent = @ptrCast(@alignCast(context)); + const names = arena.alloc([]const u8, self.available_providers.len) catch return &.{}; + for (self.available_providers, 0..) |p, i| { + names[i] = arena.dupe(u8, p) catch return &.{}; + } return names; } diff --git a/src/agent/Terminal.zig b/src/agent/Terminal.zig index 8982189e..3411162a 100644 --- a/src/agent/Terminal.zig +++ b/src/agent/Terminal.zig @@ -387,16 +387,7 @@ fn completionCallback(cenv: ?*c.ic_completion_env_t, prefix: [*c]const u8) callc const inside_block = Schema.hasUnclosedTripleQuote(input); if (input[0] == '/') { - if (has_space) { - if (!inside_block) if (Schema.parseSlashCommand(input)) |parts| { - if (Schema.findByName(parts.name)) |schema| { - addPartialKeyCompletions(cenv, input, parts.rest, schema, &buf); - } else if (SlashCommand.findMeta(parts.name)) |meta| { - self.addMetaValueCompletions(cenv, input, parts.rest, meta, &buf); - } - }; - // Fall through so `value=$LP_` picks up env completions. - } else { + if (!has_space) { const partial = input[1..]; // Trailing space on commands with params hands off to the hinter, // which renders the full ` [timeout=…]` template uniformly @@ -406,7 +397,17 @@ fn completionCallback(cenv: ?*c.ic_completion_env_t, prefix: [*c]const u8) callc addPrefixedCompletion(cenv, &buf, input, "/", name, suffix, partial); } return; + } else if (!inside_block) { + if (Schema.parseSlashCommand(input)) |parts| { + if (Schema.findByName(parts.name)) |schema| { + addPartialKeyCompletions(cenv, input, parts.rest, schema, &buf); + } else if (SlashCommand.findMeta(parts.name)) |meta| { + self.addMetaValueCompletions(cenv, input, parts.rest, meta, &buf); + } + } } + // Fall through so `value=$LP_` picks up env completions, including + // inside an unclosed `'''` block. } addEnvVarCompletions(cenv, &buf, input);