Compare commits

..

6 Commits

Author SHA1 Message Date
Parth Sareen
b1d711f8cc x: add skills spec for custom tool definitions
Add a JSON-based skills specification system that allows users to define
custom tools loaded by the experimental agent loop. Skills are auto-discovered
from standard locations (./ollama-skills.json, ~/.ollama/skills.json,
~/.config/ollama/skills.json).

Features:
- SkillSpec schema with parameters and executor configuration
- Script executor that runs commands with JSON args via stdin
- /skills reload command for runtime skill reloading
- Comprehensive validation and error handling

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 01:39:27 -08:00
ParthSareen
301b8547da update windows pathing 2026-01-05 23:10:46 -08:00
ParthSareen
cb72dbce93 use flag instead of goto 2026-01-05 23:07:30 -08:00
ParthSareen
c3a534c13c fix windows/mac separation 2026-01-05 22:39:49 -08:00
ParthSareen
35d10ba88a address comments 2026-01-05 22:09:06 -08:00
ParthSareen
dd74829c27 x: add experimental agent loop with tool approval
Add `--experimental` / `--beta` flag to enable an agent loop that allows
LLMs to use tools (bash, web_search) with interactive user approval.

Features:
- Built-in tools: bash command execution, web search via Ollama API
- Interactive approval UI with arrow key navigation
- Auto-allowlist for safe commands (pwd, git status, npm run, etc.)
- Denylist for dangerous patterns (rm -rf, sudo, credential access)
- Prefix-based allowlist for approved directories (cat src/ approves cat src/*)
- Warning box for commands targeting paths outside project directory

Architecture:
- x/tools/: Tool registry, bash executor, web search client
- x/agent/: Approval manager with TUI selector
- x/cmd/: Agent loop orchestration

The readline change fixes a stdin race condition where background goroutine
consumption caused double-keypress issues in the approval UI.
2026-01-05 21:53:46 -08:00
19 changed files with 829 additions and 1265 deletions

View File

@@ -520,7 +520,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("yolo")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -548,9 +547,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Use experimental agent loop with tools
// Use experimental agent loop with
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
}
return generateInteractive(cmd, opts)
@@ -1765,7 +1764,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -227,9 +227,9 @@ func TestOlmo3Renderer(t *testing.T) {
ID: "call_1",
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgsOrdered([]orderedArg{
{"from", "SFO"},
{"to", "NYC"},
Arguments: testArgs(map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
@@ -243,9 +243,9 @@ func TestOlmo3Renderer(t *testing.T) {
Name: "book_flight",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsOrdered([]orderedProp{
{"from", api.ToolProperty{Type: api.PropertyType{"string"}}},
{"to", api.ToolProperty{Type: api.PropertyType{"string"}}},
Properties: testPropsMap(map[string]api.ToolProperty{
"from": {Type: api.PropertyType{"string"}},
"to": {Type: api.PropertyType{"string"}},
}),
},
},

View File

@@ -34,18 +34,3 @@ func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments {
}
return args
}
// orderedProp represents a key-value pair for ordered property creation
type orderedProp struct {
Key string
Value api.ToolProperty
}
// testPropsOrdered creates a ToolPropertiesMap with a specific key order
func testPropsOrdered(pairs []orderedProp) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for _, p := range pairs {
props.Set(p.Key, p.Value)
}
return props
}

View File

@@ -6,9 +6,6 @@ import (
var ErrInterrupt = errors.New("Interrupt")
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
var ErrExpandOutput = errors.New("ExpandOutput")
type InterruptError struct {
Line []rune
}

View File

@@ -206,9 +206,6 @@ func (i *Instance) Readline() (string, error) {
buf.DeleteBefore()
case CharCtrlL:
buf.ClearScreen()
case CharCtrlO:
// Ctrl+O - expand tool output
return "", ErrExpandOutput
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:

View File

@@ -18,7 +18,6 @@ const (
CharCtrlL = 12
CharEnter = 13
CharNext = 14
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
CharPrev = 16
CharBckSearch = 18
CharFwdSearch = 19

View File

@@ -4,7 +4,6 @@ package agent
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
@@ -180,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
@@ -206,8 +204,8 @@ func extractBashPrefix(command string) string {
return ""
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
// Find the first path-like argument (must contain / or start with .)
// First pass: look for clear paths (containing / or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
@@ -217,49 +215,19 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
// Only process if it looks like a path (contains / or starts with .)
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
continue
}
// Normalize to forward slashes for consistent cross-platform matching
arg = strings.ReplaceAll(arg, "\\", "/")
// Security: reject absolute paths
if path.IsAbs(arg) {
return "" // Absolute path - don't create prefix
// If arg ends with /, it's a directory - use it directly
if strings.HasSuffix(arg, "/") {
return fmt.Sprintf("%s:%s", baseCmd, arg)
}
// Normalize the path using stdlib path.Clean (resolves . and ..)
cleaned := path.Clean(arg)
// Security: reject if cleaned path escapes to parent directory
if strings.HasPrefix(cleaned, "..") {
return "" // Path escapes - don't create prefix
}
// Security: if original had "..", verify cleaned path didn't escape to sibling
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
if strings.Contains(arg, "..") {
origBase := strings.SplitN(arg, "/", 2)[0]
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
if origBase != cleanedBase {
return "" // Path escaped to sibling directory
}
}
// Check if arg ends with / (explicit directory)
isDir := strings.HasSuffix(arg, "/")
// Get the directory part
var dir string
if isDir {
dir = cleaned
} else {
dir = path.Dir(cleaned)
}
// Get the directory part of a file path
dir := filepath.Dir(arg)
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
// Path is just a directory like "tools" or "src" (no trailing /)
return fmt.Sprintf("%s:%s/", baseCmd, arg)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
@@ -364,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
@@ -376,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return true
}
// For bash commands, check prefix matches with hierarchical path support
// For bash commands, check prefix matches
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Check exact prefix match first
if a.prefixes[prefix] {
return true
}
// Check hierarchical match: if any stored prefix is a parent of current prefix
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
if a.matchesHierarchicalPrefix(prefix) {
return true
}
if prefix != "" && a.prefixes[prefix] {
return true
}
}
}
@@ -402,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
return false
}
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
// Split prefix into command and path parts (format: "cmd:path/")
colonIdx := strings.Index(currentPrefix, ":")
if colonIdx == -1 {
return false
}
currentCmd := currentPrefix[:colonIdx]
currentPath := currentPrefix[colonIdx+1:]
for storedPrefix := range a.prefixes {
storedColonIdx := strings.Index(storedPrefix, ":")
if storedColonIdx == -1 {
continue
}
storedCmd := storedPrefix[:storedColonIdx]
storedPath := storedPrefix[storedColonIdx+1:]
// Commands must match exactly
if currentCmd != storedCmd {
continue
}
// Check if current path starts with stored path (hierarchical match)
// e.g., "tools/subdir/" starts with "tools/"
if strings.HasPrefix(currentPath, storedPath) {
return true
}
}
return false
}
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
@@ -519,12 +443,11 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}
// For web search, show query and internet notice
// For web search, show query
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
sb.WriteString(fmt.Sprintf("Query: %s", query))
return sb.String()
}
}
@@ -1028,184 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
// Returns true for Yes, false for No.
func PromptYesNo(question string) (bool, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return false, err
}
defer term.Restore(fd, oldState)
selected := 0 // 0 = Yes, 1 = No
options := []string{"Yes", "No"}
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
}
}
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
}
renderYesNo()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return false, err
}
if n == 1 {
switch buf[0] {
case 'y', 'Y':
selected = 0
renderYesNo()
case 'n', 'N':
selected = 1
renderYesNo()
case '\r', '\n': // Enter
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
return selected == 0, nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\r\033[K")
return false, nil
case 27: // Escape - could be arrow key
// Read more bytes for arrow keys
continue
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
// Arrow keys
switch buf[2] {
case 'D': // Left
if selected > 0 {
selected--
}
renderYesNo()
case 'C': // Right
if selected < len(options)-1 {
selected++
}
renderYesNo()
}
}
}
}
// CloudModelOption represents a suggested cloud model for the selection prompt.
type CloudModelOption struct {
Name string
Description string
}
// PromptModelChoice displays a model selection prompt with multiple options.
// Returns the selected model name, or empty string if user declined or cancelled.
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return "", err
}
defer term.Restore(fd, oldState)
// Build options: models + "No thanks, continue"
optionCount := len(models) + 1
selected := 0
// Total lines: question + models + "no thanks" + hint = optionCount + 2
totalLines := optionCount + 2
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
firstRender := true
render := func() {
if !firstRender {
fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
}
firstRender = false
// \r\n needed in raw mode for proper line breaks
fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)
for i, model := range models {
fmt.Fprintf(os.Stderr, "\033[K")
if i == selected {
fmt.Fprintf(os.Stderr, " \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
} else {
fmt.Fprintf(os.Stderr, " \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
}
}
fmt.Fprintf(os.Stderr, "\033[K")
if selected == len(models) {
fmt.Fprintf(os.Stderr, " \033[1;32m> No thanks, continue\033[0m\r\n")
} else {
fmt.Fprintf(os.Stderr, " \033[90mNo thanks, continue\033[0m\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
}
render()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return "", err
}
if n == 1 {
switch buf[0] {
case 'j', 'J':
if selected < optionCount-1 {
selected++
}
render()
case 'k', 'K':
if selected > 0 {
selected--
}
render()
case '\r', '\n':
fmt.Fprintf(os.Stderr, "\n")
if selected < len(models) {
return models[selected].Name, nil
}
return "", nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\n")
return "", nil
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
switch buf[2] {
case 'A': // Up
if selected > 0 {
selected--
}
render()
case 'B': // Down
if selected < optionCount-1 {
selected++
}
render()
}
}
}
}

View File

@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
command: "head -n 100",
expected: "",
},
// Path traversal security tests
{
name: "path traversal - parent escape",
command: "cat tools/../../etc/passwd",
expected: "", // Should NOT create a prefix - path escapes
},
{
name: "path traversal - deep escape",
command: "cat tools/a/b/../../../etc/passwd",
expected: "", // Normalizes to "../etc/passwd" - escapes
},
{
name: "path traversal - absolute path",
command: "cat /etc/passwd",
expected: "", // Absolute paths should not create prefix
},
{
name: "path with safe dotdot - normalized",
command: "cat tools/subdir/../file.go",
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
},
}
for _, tt := range tests {
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
}
}
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Path traversal attack: should NOT be allowed
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
t.Error("SECURITY: path traversal attack should NOT be allowed")
}
// Another traversal variant
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
t.Error("SECURITY: deep path traversal should NOT be allowed")
}
// Valid subdirectory access should still work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed")
}
// Safe ".." that normalizes to within allowed directory should work
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow subdirectories (hierarchical matching)
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
}
// Should allow deeply nested subdirectories
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
}
// Should still allow same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
t.Error("expected cat tools/another.go to be allowed")
}
// Should NOT allow different base directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should NOT allow different command even in subdirectory
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
}
// Should NOT allow similar but different directory name
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
}
}
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
am := NewApprovalManager()
// Allow with forward slashes (Unix-style)
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should work with backslashes too (Windows-style) - normalized internally
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
}
// Mixed slashes should also work
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
}
}
func TestMatchesHierarchicalPrefix(t *testing.T) {
am := NewApprovalManager()
// Add prefix for "cat:tools/"
am.prefixes["cat:tools/"] = true
tests := []struct {
name string
prefix string
expected bool
}{
{
name: "exact match",
prefix: "cat:tools/",
expected: true, // exact match also passes HasPrefix - caller handles exact match first
},
{
name: "subdirectory",
prefix: "cat:tools/subdir/",
expected: true,
},
{
name: "deeply nested",
prefix: "cat:tools/a/b/c/",
expected: true,
},
{
name: "different base directory",
prefix: "cat:src/",
expected: false,
},
{
name: "different command same path",
prefix: "ls:tools/",
expected: false,
},
{
name: "similar directory name",
prefix: "cat:toolsbin/",
expected: false,
},
{
name: "invalid prefix format",
prefix: "cattools",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := am.matchesHierarchicalPrefix(tt.prefix)
if result != tt.expected {
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
tt.prefix, result, tt.expected)
}
})
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string

View File

@@ -1,25 +0,0 @@
package agent
import (
"testing"
)
func TestCloudModelOptionStruct(t *testing.T) {
// Test that the struct is defined correctly
models := []CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
if len(models) != 2 {
t.Errorf("expected 2 models, got %d", len(models))
}
if models[0].Name != "glm-4.7:cloud" {
t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
}
if models[1].Description != "Qwen3 Coder 480B" {
t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
}
}

View File

@@ -1,41 +0,0 @@
package cmd
import (
"errors"
"testing"
)
func TestCloudModelSwitchRequest(t *testing.T) {
// Test the error type
req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}
// Test Error() method
errMsg := req.Error()
expected := "switch to model: glm-4.7:cloud"
if errMsg != expected {
t.Errorf("expected %q, got %q", expected, errMsg)
}
// Test errors.As
var err error = req
var switchReq *CloudModelSwitchRequest
if !errors.As(err, &switchReq) {
t.Error("errors.As should return true for CloudModelSwitchRequest")
}
if switchReq.Model != "glm-4.7:cloud" {
t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
}
}
func TestSuggestedCloudModels(t *testing.T) {
// Verify the suggested models are defined
if len(suggestedCloudModels) == 0 {
t.Error("suggestedCloudModels should not be empty")
}
// Check first model
if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
}
}

View File

@@ -6,12 +6,10 @@ import (
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -24,132 +22,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
localModelTokenLimit = 4000
// defaultTokenLimit is the token limit for cloud/remote models.
defaultTokenLimit = 10000
// charsPerToken is a rough estimate of characters per token.
// TODO: Estimate tokens more accurately using tokenizer if available
charsPerToken = 4
)
// suggestedCloudModels are the models suggested to users after signing in.
// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
var suggestedCloudModels = []agent.CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
// CloudModelSwitchRequest signals that the user wants to switch to a different model.
type CloudModelSwitchRequest struct {
Model string
}
func (c *CloudModelSwitchRequest) Error() string {
return fmt.Sprintf("switch to model: %s", c.Model)
}
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
return !strings.HasSuffix(modelName, "-cloud")
}
// isLocalServer checks if connecting to a local Ollama server.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
host := os.Getenv("OLLAMA_HOST")
if host == "" {
return true // Default is localhost:11434
}
// Parse the URL to check host
parsed, err := url.Parse(host)
if err != nil {
return true // If can't parse, assume local
}
hostname := parsed.Hostname()
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
var tokenLimit int
if isLocalModel(modelName) && isLocalServer() {
tokenLimit = localModelTokenLimit
} else {
tokenLimit = defaultTokenLimit
}
maxChars := tokenLimit * charsPerToken
if len(output) > maxChars {
return output[:maxChars] + "\n... (output truncated)"
}
return output
}
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
func waitForOllamaSignin(ctx context.Context) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Fprintf(os.Stderr, "\n")
return ctx.Err()
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
return nil
}
// Still waiting, show dot
fmt.Fprintf(os.Stderr, ".")
}
}
}
return err
}
return nil
}
// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
// Returns the selected model name, or empty string if user declines.
func promptCloudModelSuggestion() string {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
fmt.Fprintf(os.Stderr, "\n")
selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
if err != nil || selectedModel == "" {
return ""
}
return selectedModel
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -165,50 +37,6 @@ type RunOptions struct {
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
// YoloMode skips all tool approval prompts
YoloMode bool
// LastToolOutput stores the full output of the last tool execution
// for Ctrl+O expansion. Updated by Chat(), read by caller.
LastToolOutput *string
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
// ActiveModel points to the current model name - can be updated mid-turn
// for model switching. If nil, opts.Model is used.
ActiveModel *string
}
// getActiveModel returns the current model name, checking ActiveModel pointer first.
func getActiveModel(opts *RunOptions) string {
if opts.ActiveModel != nil && *opts.ActiveModel != "" {
return *opts.ActiveModel
}
return opts.Model
}
// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
func showModelConnection(ctx context.Context, modelName string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return err
}
if info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
// Chat runs an agent chat loop with tool support.
@@ -249,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
role := "assistant"
messages := opts.Messages
@@ -308,7 +135,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: getActiveModel(&opts),
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
@@ -332,61 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" {
return nil, &CloudModelSwitchRequest{Model: suggestedModel}
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
}
// Check for 500 errors (often tool parsing failures) - inform the model
var statusErr api.StatusError
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
consecutiveErrors++
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
if assistantContent == "" {
assistantContent = "(empty response)"
}
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
messages = append(messages,
api.Message{Role: "user", Content: errorMsg},
)
// Reset state and retry
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
continue
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
@@ -396,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, err
}
// Reset consecutive error counter on success
consecutiveErrors = 0
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
@@ -447,12 +216,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Check approval (uses prefix matching for bash commands)
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
@@ -483,27 +247,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" && opts.ActiveModel != nil {
*opts.ActiveModel = suggestedModel
showModelConnection(ctx, suggestedModel)
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
}
}
}
}
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
@@ -512,34 +258,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
})
continue
}
toolSuccess:
// Display tool output (truncated for display)
truncatedOutput := ""
if toolResult != "" {
output := toolResult
if len(output) > 300 {
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
output = output[:300] + "... (truncated)"
}
truncatedOutput = output
// Show result in grey, indented
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Store full and truncated output for Ctrl+O toggle
if opts.LastToolOutput != nil {
*opts.LastToolOutput = toolResult
}
if opts.LastToolOutputTruncated != nil {
*opts.LastToolOutputTruncated = truncatedOutput
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResultForLLM,
Content: toolResult,
ToolCallID: call.ID,
})
}
@@ -694,34 +426,30 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return out
}
// checkModelCapabilities checks if the model supports tools and thinking.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
// checkModelCapabilities checks if the model supports tools.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return false, false, err
return false, err
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return false, false, err
return false, err
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityTools {
supportsTools = true
}
if cap == model.CapabilityThinking {
supportsThinking = true
return true, nil
}
}
return supportsTools, supportsThinking, nil
return false, nil
}
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -735,26 +463,31 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
// Check if model supports tools and thinking
supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
supportsThinking = false
}
// Track if session is using thinking mode
usingThinking := think != nil && supportsThinking
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
// Load custom skills from skill files
loadedFiles, err := tools.LoadAllSkills(toolRegistry)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Error loading skills: %v\033[0m\n", err)
} else if len(loadedFiles) > 0 {
fmt.Fprintf(os.Stderr, "Loaded skills from: %s\n", strings.Join(loadedFiles, ", "))
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
@@ -766,11 +499,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
// Track last tool output for Ctrl+O toggle
var lastToolOutput string
var lastToolOutputTruncated string
var toolOutputExpanded bool
for {
line, err := scanner.Readline()
switch {
@@ -783,20 +511,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
sb.Reset()
continue
case errors.Is(err, readline.ErrExpandOutput):
// Ctrl+O pressed - toggle between expanded and collapsed tool output
if lastToolOutput == "" {
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
} else if toolOutputExpanded {
// Currently expanded, show truncated
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
toolOutputExpanded = false
} else {
// Currently collapsed, show full
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
toolOutputExpanded = true
}
continue
case err != nil:
return err
}
@@ -812,15 +526,34 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
case strings.HasPrefix(line, "/tools"):
showToolsStatus(toolRegistry, approval, supportsTools)
continue
case strings.HasPrefix(line, "/skills reload"):
if toolRegistry != nil {
loadedFiles, err := tools.LoadAllSkills(toolRegistry)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Error loading skills: %v\033[0m\n", err)
} else if len(loadedFiles) > 0 {
fmt.Fprintf(os.Stderr, "Reloaded skills from: %s\n", strings.Join(loadedFiles, ", "))
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
} else {
fmt.Fprintln(os.Stderr, "No skill files found")
}
} else {
fmt.Fprintln(os.Stderr, "Tools not available - model does not support tool calling")
}
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /skills reload Reload custom skills from skill files")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "Custom Skills:")
fmt.Fprintln(os.Stderr, " Skills are loaded from JSON files in these locations:")
fmt.Fprintln(os.Stderr, " ./ollama-skills.json")
fmt.Fprintln(os.Stderr, " ~/.ollama/skills.json")
fmt.Fprintln(os.Stderr, " ~/.config/ollama/skills.json")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
@@ -833,44 +566,25 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
toolOutputExpanded = false
retryChat:
for {
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
ActiveModel: &modelName,
}
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
}
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
var switchReq *CloudModelSwitchRequest
if errors.As(err, &switchReq) {
newModel := switchReq.Model
if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
}
continue retryChat
}
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
break retryChat
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
sb.Reset()
@@ -878,52 +592,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
}
// switchToModel handles model switching with capability checks and UI updates.
func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("could not create client: %w", err)
}
newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
if capErr != nil {
return fmt.Errorf("could not check model capabilities: %w", capErr)
}
// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
if usingThinking && !newSupportsThinking {
return fmt.Errorf("%s does not support thinking mode", newModel)
}
// Show "Connecting to X on ollama.com" for cloud models
info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
if err == nil && info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
*modelName = newModel
*supportsTools = newSupportsTools
*supportsThinking = newSupportsThinking
if *supportsTools {
if *toolRegistry == nil {
*toolRegistry = tools.DefaultRegistry()
}
if (*toolRegistry).Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
}
} else {
*toolRegistry = nil
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
return nil
}
// showToolsStatus displays the current tools and approval status.
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
if !supportsTools || registry == nil {

View File

@@ -1,180 +0,0 @@
package cmd
import (
"testing"
)
func TestIsLocalModel(t *testing.T) {
tests := []struct {
name string
modelName string
expected bool
}{
{
name: "local model without suffix",
modelName: "llama3.2",
expected: true,
},
{
name: "local model with version",
modelName: "qwen2.5:7b",
expected: true,
},
{
name: "cloud model",
modelName: "gpt-4-cloud",
expected: false,
},
{
name: "cloud model with version",
modelName: "claude-3-cloud",
expected: false,
},
{
name: "empty model name",
modelName: "",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalModel(tt.modelName)
if result != tt.expected {
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
}
})
}
}
func TestIsLocalServer(t *testing.T) {
tests := []struct {
name string
host string
expected bool
}{
{
name: "empty host (default)",
host: "",
expected: true,
},
{
name: "localhost",
host: "http://localhost:11434",
expected: true,
},
{
name: "127.0.0.1",
host: "http://127.0.0.1:11434",
expected: true,
},
{
name: "custom port on localhost",
host: "http://localhost:8080",
expected: true, // localhost is always considered local
},
{
name: "remote host",
host: "http://ollama.example.com:11434",
expected: true, // has :11434
},
{
name: "remote host different port",
host: "http://ollama.example.com:8080",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := isLocalServer()
if result != tt.expected {
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
}
})
}
}
func TestTruncateToolOutput(t *testing.T) {
// Create outputs of different sizes
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
for i := range localLimitOutput {
localLimitOutput[i] = 'a'
}
for i := range defaultLimitOutput {
defaultLimitOutput[i] = 'b'
}
tests := []struct {
name string
output string
modelName string
host string
shouldTrim bool
expectedLimit int
}{
{
name: "short output local model",
output: "hello world",
modelName: "llama3.2",
host: "",
shouldTrim: false,
expectedLimit: localModelTokenLimit,
},
{
name: "long output local model - trimmed at 4k",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "",
shouldTrim: true,
expectedLimit: localModelTokenLimit,
},
{
name: "long output cloud model - uses 10k limit",
output: string(localLimitOutput), // 20k chars, under 10k token limit
modelName: "gpt-4-cloud",
host: "",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
{
name: "very long output cloud model - trimmed at 10k",
output: string(defaultLimitOutput),
modelName: "gpt-4-cloud",
host: "",
shouldTrim: true,
expectedLimit: defaultTokenLimit,
},
{
name: "long output remote server - uses 10k limit",
output: string(localLimitOutput),
modelName: "llama3.2",
host: "http://remote.example.com:8080",
shouldTrim: false,
expectedLimit: defaultTokenLimit,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", tt.host)
result := truncateToolOutput(tt.output, tt.modelName)
if tt.shouldTrim {
maxLen := tt.expectedLimit * charsPerToken
if len(result) > maxLen+50 { // +50 for the truncation message
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
}
if result == tt.output {
t.Error("expected output to be truncated but it wasn't")
}
} else {
if result != tt.output {
t.Error("expected output to not be truncated")
}
}
})
}
}

View File

@@ -0,0 +1,44 @@
{
"version": "1",
"skills": [
{
"name": "get_time",
"description": "Get the current date and time. Use this when the user asks about the current time or date.",
"parameters": [],
"executor": {
"type": "script",
"command": "date",
"timeout": 5
}
},
{
"name": "system_info",
"description": "Get system information including OS, hostname, and uptime.",
"parameters": [],
"executor": {
"type": "script",
"command": "sh",
"args": ["-c", "echo \"Hostname: $(hostname)\"; echo \"OS: $(uname -s)\"; echo \"Kernel: $(uname -r)\"; echo \"Uptime: $(uptime -p 2>/dev/null || uptime)\""],
"timeout": 10
}
},
{
"name": "python_eval",
"description": "Evaluate a Python expression. The expression is passed via stdin as JSON with an 'expression' field.",
"parameters": [
{
"name": "expression",
"type": "string",
"description": "The Python expression to evaluate",
"required": true
}
],
"executor": {
"type": "script",
"command": "python3",
"args": ["-c", "import sys, json; data = json.load(sys.stdin); print(eval(data.get('expression', '')))"],
"timeout": 30
}
}
]
}

View File

@@ -3,7 +3,6 @@ package tools
import (
"fmt"
"os"
"sort"
"github.com/ollama/ollama/api"
@@ -89,16 +88,9 @@ func (r *Registry) Count() int {
}
// DefaultRegistry creates a registry with all built-in tools.
// Tools can be disabled via environment variables:
// - OLLAMA_AGENT_DISABLE_WEBSEARCH=1 disables web_search
// - OLLAMA_AGENT_DISABLE_BASH=1 disables bash
func DefaultRegistry() *Registry {
r := NewRegistry()
if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
r.Register(&WebSearchTool{})
}
if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" {
r.Register(&BashTool{})
}
r.Register(&WebSearchTool{})
r.Register(&BashTool{})
return r
}

View File

@@ -108,57 +108,6 @@ func TestDefaultRegistry(t *testing.T) {
}
}
func TestDefaultRegistry_DisableWebsearch(t *testing.T) {
t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1")
r := DefaultRegistry()
if r.Count() != 1 {
t.Errorf("expected 1 tool with websearch disabled, got %d", r.Count())
}
_, ok := r.Get("bash")
if !ok {
t.Error("expected bash tool in registry")
}
_, ok = r.Get("web_search")
if ok {
t.Error("expected web_search to be disabled")
}
}
func TestDefaultRegistry_DisableBash(t *testing.T) {
t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1")
r := DefaultRegistry()
if r.Count() != 1 {
t.Errorf("expected 1 tool with bash disabled, got %d", r.Count())
}
_, ok := r.Get("web_search")
if !ok {
t.Error("expected web_search tool in registry")
}
_, ok = r.Get("bash")
if ok {
t.Error("expected bash to be disabled")
}
}
func TestDefaultRegistry_DisableBoth(t *testing.T) {
t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1")
t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1")
r := DefaultRegistry()
if r.Count() != 0 {
t.Errorf("expected 0 tools with both disabled, got %d", r.Count())
}
}
func TestBashTool_Schema(t *testing.T) {
tool := &BashTool{}

318
x/tools/skills.go Normal file
View File

@@ -0,0 +1,318 @@
// Package tools provides built-in tool implementations for the agent loop.
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/api"
)
// SkillSpec defines the specification for a custom skill.
// Skills can be loaded from JSON files and registered with the tool registry.
type SkillSpec struct {
// Name is the unique identifier for the skill
Name string `json:"name"`
// Description is a human-readable description shown to the LLM
Description string `json:"description"`
// Parameters defines the input schema for the skill
Parameters []SkillParameter `json:"parameters,omitempty"`
// Executor defines how the skill is executed
Executor SkillExecutor `json:"executor"`
}
// SkillParameter defines a single parameter for a skill.
type SkillParameter struct {
// Name is the parameter name
Name string `json:"name"`
// Type is the JSON schema type (string, number, boolean, array, object)
Type string `json:"type"`
// Description explains what this parameter is for
Description string `json:"description"`
// Required indicates if this parameter must be provided
Required bool `json:"required"`
}
// SkillExecutor defines how to execute a skill.
type SkillExecutor struct {
// Type is the executor type: "script", "http", or "builtin"
Type string `json:"type"`
// Command is the command to run for "script" type
// Arguments are passed as JSON via stdin, result is read from stdout
Command string `json:"command,omitempty"`
// Args are additional arguments appended to the command
Args []string `json:"args,omitempty"`
// Timeout is the maximum execution time in seconds (default: 60)
Timeout int `json:"timeout,omitempty"`
// URL is the endpoint for "http" type executors
URL string `json:"url,omitempty"`
// Method is the HTTP method (default: POST)
Method string `json:"method,omitempty"`
}
// SkillsFile represents a file containing skill definitions.
type SkillsFile struct {
// Version is the spec version (currently "1")
Version string `json:"version"`
// Skills is the list of skill definitions
Skills []SkillSpec `json:"skills"`
}
// SkillTool wraps a SkillSpec to implement the Tool interface.
type SkillTool struct {
spec SkillSpec
}
// NewSkillTool creates a Tool from a SkillSpec.
func NewSkillTool(spec SkillSpec) *SkillTool {
return &SkillTool{spec: spec}
}
// Name returns the skill name.
func (s *SkillTool) Name() string {
return s.spec.Name
}
// Description returns the skill description.
func (s *SkillTool) Description() string {
return s.spec.Description
}
// Schema returns the tool's parameter schema for the LLM.
func (s *SkillTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
var required []string
for _, param := range s.spec.Parameters {
props.Set(param.Name, api.ToolProperty{
Type: api.PropertyType{param.Type},
Description: param.Description,
})
if param.Required {
required = append(required, param.Name)
}
}
return api.ToolFunction{
Name: s.spec.Name,
Description: s.spec.Description,
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: required,
},
}
}
// Execute runs the skill with the given arguments.
func (s *SkillTool) Execute(args map[string]any) (string, error) {
switch s.spec.Executor.Type {
case "script":
return s.executeScript(args)
case "http":
return s.executeHTTP(args)
default:
return "", fmt.Errorf("unknown executor type: %s", s.spec.Executor.Type)
}
}
// executeScript runs a script-based skill.
func (s *SkillTool) executeScript(args map[string]any) (string, error) {
if s.spec.Executor.Command == "" {
return "", fmt.Errorf("script executor requires command")
}
timeout := time.Duration(s.spec.Executor.Timeout) * time.Second
if timeout == 0 {
timeout = 60 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Build command
cmdArgs := append([]string{}, s.spec.Executor.Args...)
cmd := exec.CommandContext(ctx, s.spec.Executor.Command, cmdArgs...)
// Pass arguments as JSON via stdin
inputJSON, err := json.Marshal(args)
if err != nil {
return "", fmt.Errorf("marshaling arguments: %w", err)
}
cmd.Stdin = bytes.NewReader(inputJSON)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err = cmd.Run()
// Build output
var sb strings.Builder
if stdout.Len() > 0 {
output := stdout.String()
if len(output) > maxOutputSize {
output = output[:maxOutputSize] + "\n... (output truncated)"
}
sb.WriteString(output)
}
if stderr.Len() > 0 {
stderrOutput := stderr.String()
if len(stderrOutput) > maxOutputSize {
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
}
if sb.Len() > 0 {
sb.WriteString("\n")
}
sb.WriteString("stderr:\n")
sb.WriteString(stderrOutput)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return sb.String() + fmt.Sprintf("\n\nError: command timed out after %d seconds", s.spec.Executor.Timeout), nil
}
if exitErr, ok := err.(*exec.ExitError); ok {
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
}
return sb.String(), fmt.Errorf("executing skill: %w", err)
}
if sb.Len() == 0 {
return "(no output)", nil
}
return sb.String(), nil
}
// executeHTTP runs an HTTP-based skill.
func (s *SkillTool) executeHTTP(args map[string]any) (string, error) {
// HTTP executor is a placeholder for future implementation
return "", fmt.Errorf("http executor not yet implemented")
}
// LoadSkillsFile loads skill definitions from a JSON file.
func LoadSkillsFile(path string) (*SkillsFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading skills file: %w", err)
}
var file SkillsFile
if err := json.Unmarshal(data, &file); err != nil {
return nil, fmt.Errorf("parsing skills file: %w", err)
}
if file.Version == "" {
file.Version = "1"
}
return &file, nil
}
// RegisterSkillsFromFile loads skills from a file and registers them with the registry.
func RegisterSkillsFromFile(registry *Registry, path string) error {
file, err := LoadSkillsFile(path)
if err != nil {
return err
}
for _, spec := range file.Skills {
if err := validateSkillSpec(spec); err != nil {
return fmt.Errorf("invalid skill %q: %w", spec.Name, err)
}
registry.Register(NewSkillTool(spec))
}
return nil
}
// FindSkillsFiles searches for skill definition files in standard locations.
// It looks for:
// - ./ollama-skills.json (current directory)
// - ~/.ollama/skills.json (user config)
// - ~/.config/ollama/skills.json (XDG config)
func FindSkillsFiles() []string {
var files []string
// Current directory
if _, err := os.Stat("ollama-skills.json"); err == nil {
files = append(files, "ollama-skills.json")
}
// Home directory
home, err := os.UserHomeDir()
if err == nil {
paths := []string{
filepath.Join(home, ".ollama", "skills.json"),
filepath.Join(home, ".config", "ollama", "skills.json"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
files = append(files, p)
}
}
}
return files
}
// LoadAllSkills loads skills from all discovered skill files into the registry.
func LoadAllSkills(registry *Registry) ([]string, error) {
files := FindSkillsFiles()
var loaded []string
for _, path := range files {
if err := RegisterSkillsFromFile(registry, path); err != nil {
return loaded, fmt.Errorf("loading %s: %w", path, err)
}
loaded = append(loaded, path)
}
return loaded, nil
}
// validateSkillSpec validates a skill specification.
func validateSkillSpec(spec SkillSpec) error {
if spec.Name == "" {
return fmt.Errorf("name is required")
}
if spec.Description == "" {
return fmt.Errorf("description is required")
}
if spec.Executor.Type == "" {
return fmt.Errorf("executor.type is required")
}
switch spec.Executor.Type {
case "script":
if spec.Executor.Command == "" {
return fmt.Errorf("executor.command is required for script type")
}
case "http":
if spec.Executor.URL == "" {
return fmt.Errorf("executor.url is required for http type")
}
default:
return fmt.Errorf("unknown executor type: %s", spec.Executor.Type)
}
for _, param := range spec.Parameters {
if param.Name == "" {
return fmt.Errorf("parameter name is required")
}
if param.Type == "" {
return fmt.Errorf("parameter type is required for %s", param.Name)
}
}
return nil
}

368
x/tools/skills_test.go Normal file
View File

@@ -0,0 +1,368 @@
package tools
import (
"os"
"path/filepath"
"testing"
)
func TestValidateSkillSpec(t *testing.T) {
tests := []struct {
name string
spec SkillSpec
wantErr bool
errMsg string
}{
{
name: "valid script skill",
spec: SkillSpec{
Name: "test_skill",
Description: "A test skill",
Parameters: []SkillParameter{
{Name: "input", Type: "string", Description: "Input value", Required: true},
},
Executor: SkillExecutor{
Type: "script",
Command: "echo",
},
},
wantErr: false,
},
{
name: "valid http skill",
spec: SkillSpec{
Name: "http_skill",
Description: "An HTTP skill",
Executor: SkillExecutor{
Type: "http",
URL: "https://example.com/api",
},
},
wantErr: false,
},
{
name: "missing name",
spec: SkillSpec{
Description: "A skill without name",
Executor: SkillExecutor{Type: "script", Command: "echo"},
},
wantErr: true,
errMsg: "name is required",
},
{
name: "missing description",
spec: SkillSpec{
Name: "no_desc",
Executor: SkillExecutor{Type: "script", Command: "echo"},
},
wantErr: true,
errMsg: "description is required",
},
{
name: "missing executor type",
spec: SkillSpec{
Name: "no_exec_type",
Description: "Missing executor type",
Executor: SkillExecutor{Command: "echo"},
},
wantErr: true,
errMsg: "executor.type is required",
},
{
name: "script missing command",
spec: SkillSpec{
Name: "script_no_cmd",
Description: "Script without command",
Executor: SkillExecutor{Type: "script"},
},
wantErr: true,
errMsg: "executor.command is required",
},
{
name: "http missing url",
spec: SkillSpec{
Name: "http_no_url",
Description: "HTTP without URL",
Executor: SkillExecutor{Type: "http"},
},
wantErr: true,
errMsg: "executor.url is required",
},
{
name: "unknown executor type",
spec: SkillSpec{
Name: "unknown_type",
Description: "Unknown executor",
Executor: SkillExecutor{Type: "invalid"},
},
wantErr: true,
errMsg: "unknown executor type",
},
{
name: "parameter missing name",
spec: SkillSpec{
Name: "param_no_name",
Description: "Parameter without name",
Parameters: []SkillParameter{{Type: "string", Description: "desc"}},
Executor: SkillExecutor{Type: "script", Command: "echo"},
},
wantErr: true,
errMsg: "parameter name is required",
},
{
name: "parameter missing type",
spec: SkillSpec{
Name: "param_no_type",
Description: "Parameter without type",
Parameters: []SkillParameter{{Name: "foo", Description: "desc"}},
Executor: SkillExecutor{Type: "script", Command: "echo"},
},
wantErr: true,
errMsg: "parameter type is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSkillSpec(tt.spec)
if tt.wantErr {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.errMsg)
} else if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) {
t.Errorf("expected error containing %q, got %q", tt.errMsg, err.Error())
}
} else if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestLoadSkillsFile(t *testing.T) {
// Create a temporary skills file
tmpDir := t.TempDir()
skillsPath := filepath.Join(tmpDir, "skills.json")
content := `{
"version": "1",
"skills": [
{
"name": "echo_skill",
"description": "Echoes the input",
"parameters": [
{"name": "message", "type": "string", "description": "Message to echo", "required": true}
],
"executor": {
"type": "script",
"command": "echo",
"timeout": 30
}
}
]
}`
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
file, err := LoadSkillsFile(skillsPath)
if err != nil {
t.Fatalf("LoadSkillsFile failed: %v", err)
}
if file.Version != "1" {
t.Errorf("expected version 1, got %s", file.Version)
}
if len(file.Skills) != 1 {
t.Fatalf("expected 1 skill, got %d", len(file.Skills))
}
skill := file.Skills[0]
if skill.Name != "echo_skill" {
t.Errorf("expected name 'echo_skill', got %s", skill.Name)
}
if skill.Executor.Timeout != 30 {
t.Errorf("expected timeout 30, got %d", skill.Executor.Timeout)
}
if len(skill.Parameters) != 1 {
t.Errorf("expected 1 parameter, got %d", len(skill.Parameters))
}
}
func TestRegisterSkillsFromFile(t *testing.T) {
tmpDir := t.TempDir()
skillsPath := filepath.Join(tmpDir, "skills.json")
content := `{
"version": "1",
"skills": [
{
"name": "skill_a",
"description": "Skill A",
"executor": {"type": "script", "command": "echo"}
},
{
"name": "skill_b",
"description": "Skill B",
"executor": {"type": "script", "command": "cat"}
}
]
}`
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
registry := NewRegistry()
if err := RegisterSkillsFromFile(registry, skillsPath); err != nil {
t.Fatalf("RegisterSkillsFromFile failed: %v", err)
}
if registry.Count() != 2 {
t.Errorf("expected 2 tools, got %d", registry.Count())
}
names := registry.Names()
if names[0] != "skill_a" || names[1] != "skill_b" {
t.Errorf("unexpected tool names: %v", names)
}
}
func TestSkillToolSchema(t *testing.T) {
spec := SkillSpec{
Name: "test_tool",
Description: "A test tool",
Parameters: []SkillParameter{
{Name: "required_param", Type: "string", Description: "Required", Required: true},
{Name: "optional_param", Type: "number", Description: "Optional", Required: false},
},
Executor: SkillExecutor{Type: "script", Command: "echo"},
}
tool := NewSkillTool(spec)
if tool.Name() != "test_tool" {
t.Errorf("expected name 'test_tool', got %s", tool.Name())
}
if tool.Description() != "A test tool" {
t.Errorf("expected description 'A test tool', got %s", tool.Description())
}
schema := tool.Schema()
if schema.Name != "test_tool" {
t.Errorf("schema name mismatch")
}
if len(schema.Parameters.Required) != 1 {
t.Errorf("expected 1 required param, got %d", len(schema.Parameters.Required))
}
if schema.Parameters.Required[0] != "required_param" {
t.Errorf("wrong required param: %v", schema.Parameters.Required)
}
}
func TestSkillToolExecuteScript(t *testing.T) {
spec := SkillSpec{
Name: "echo_test",
Description: "Echo test",
Executor: SkillExecutor{
Type: "script",
Command: "cat",
Timeout: 5,
},
}
tool := NewSkillTool(spec)
// cat will read JSON from stdin and output it
result, err := tool.Execute(map[string]any{"message": "hello"})
if err != nil {
t.Fatalf("Execute failed: %v", err)
}
if !contains(result, "message") || !contains(result, "hello") {
t.Errorf("expected JSON output with 'message' and 'hello', got: %s", result)
}
}
func TestSkillToolExecuteHTTPNotImplemented(t *testing.T) {
spec := SkillSpec{
Name: "http_test",
Description: "HTTP test",
Executor: SkillExecutor{
Type: "http",
URL: "https://example.com",
},
}
tool := NewSkillTool(spec)
_, err := tool.Execute(map[string]any{})
if err == nil {
t.Error("expected error for http executor")
}
if !contains(err.Error(), "not yet implemented") {
t.Errorf("unexpected error: %v", err)
}
}
func TestLoadSkillsFileNotFound(t *testing.T) {
_, err := LoadSkillsFile("/nonexistent/path/skills.json")
if err == nil {
t.Error("expected error for nonexistent file")
}
}
func TestLoadSkillsFileInvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
skillsPath := filepath.Join(tmpDir, "invalid.json")
if err := os.WriteFile(skillsPath, []byte("not valid json"), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadSkillsFile(skillsPath)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
func TestRegisterSkillsFromFileInvalidSkill(t *testing.T) {
tmpDir := t.TempDir()
skillsPath := filepath.Join(tmpDir, "invalid_skill.json")
content := `{
"version": "1",
"skills": [
{
"name": "",
"description": "Missing name",
"executor": {"type": "script", "command": "echo"}
}
]
}`
if err := os.WriteFile(skillsPath, []byte(content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
registry := NewRegistry()
err := RegisterSkillsFromFile(registry, skillsPath)
if err == nil {
t.Error("expected error for invalid skill")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -2,19 +2,15 @@ package tools
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"os"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
)
const (
@@ -22,9 +18,6 @@ const (
webSearchTimeout = 15 * time.Second
)
// ErrWebSearchAuthRequired is returned when web search requires authentication
var ErrWebSearchAuthRequired = errors.New("web search requires authentication")
// WebSearchTool implements web search using Ollama's hosted API.
type WebSearchTool struct{}
@@ -75,13 +68,17 @@ type webSearchResult struct {
}
// Execute performs the web search.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return "", fmt.Errorf("query parameter is required")
}
apiKey := os.Getenv("OLLAMA_API_KEY")
if apiKey == "" {
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
}
// Prepare request
reqBody := webSearchRequest{
Query: query,
@@ -93,34 +90,13 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
return "", fmt.Errorf("marshaling request: %w", err)
}
// Parse URL and add timestamp for signing
searchURL, err := url.Parse(webSearchAPI)
if err != nil {
return "", fmt.Errorf("parsing search URL: %w", err)
}
q := searchURL.Query()
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
searchURL.RawQuery = q.Encode()
// Sign the request using Ollama key (~/.ollama/id_ed25519)
// This authenticates with ollama.com using the local signing key
ctx := context.Background()
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, searchURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return "", fmt.Errorf("signing request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, searchURL.String(), bytes.NewBuffer(jsonBody))
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
req.Header.Set("Authorization", "Bearer "+apiKey)
// Send request
client := &http.Client{Timeout: webSearchTimeout}
@@ -135,9 +111,6 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return "", ErrWebSearchAuthRequired
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
}

View File

@@ -1,58 +0,0 @@
package tools
import (
"errors"
"testing"
)
func TestWebSearchTool_Name(t *testing.T) {
tool := &WebSearchTool{}
if tool.Name() != "web_search" {
t.Errorf("expected name 'web_search', got '%s'", tool.Name())
}
}
func TestWebSearchTool_Description(t *testing.T) {
tool := &WebSearchTool{}
if tool.Description() == "" {
t.Error("expected non-empty description")
}
}
func TestWebSearchTool_Execute_MissingQuery(t *testing.T) {
tool := &WebSearchTool{}
// Test with no query
_, err := tool.Execute(map[string]any{})
if err == nil {
t.Error("expected error for missing query")
}
// Test with empty query
_, err = tool.Execute(map[string]any{"query": ""})
if err == nil {
t.Error("expected error for empty query")
}
}
func TestErrWebSearchAuthRequired(t *testing.T) {
// Test that the error type exists and can be checked with errors.Is
err := ErrWebSearchAuthRequired
if err == nil {
t.Fatal("ErrWebSearchAuthRequired should not be nil")
}
if err.Error() != "web search requires authentication" {
t.Errorf("unexpected error message: %s", err.Error())
}
// Test that errors.Is works
wrappedErr := errors.New("wrapped: " + err.Error())
if errors.Is(wrappedErr, ErrWebSearchAuthRequired) {
t.Error("wrapped error should not match with errors.Is")
}
if !errors.Is(ErrWebSearchAuthRequired, ErrWebSearchAuthRequired) {
t.Error("ErrWebSearchAuthRequired should match itself with errors.Is")
}
}