mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 11:29:26 -05:00
Compare commits
1 Commits
usage-anal
...
parth/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d65ccbc85c |
@@ -377,15 +377,6 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
// Usage returns usage statistics and system info.
|
||||
func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
|
||||
var ur UsageResponse
|
||||
if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ur, nil
|
||||
}
|
||||
|
||||
// Copy copies a model - creating a model with another name from an existing
|
||||
// model.
|
||||
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
||||
|
||||
27
api/types.go
27
api/types.go
@@ -792,33 +792,6 @@ type ProcessResponse struct {
|
||||
Models []ProcessModelResponse `json:"models"`
|
||||
}
|
||||
|
||||
// UsageResponse is the response from [Client.Usage].
|
||||
type UsageResponse struct {
|
||||
GPUs []GPUUsage `json:"gpus,omitempty"`
|
||||
}
|
||||
|
||||
// GPUUsage is the memory usage breakdown for a single GPU/device.
type GPUUsage struct {
	// Name is the device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090").
	Name string `json:"name"`

	// Backend is the compute backend: CUDA, ROCm, Metal, etc.
	Backend string `json:"backend"`

	// Total and Free describe the device's overall memory.
	Total uint64 `json:"total"`
	Free  uint64 `json:"free"`

	// Used is the memory used by Ollama.
	Used uint64 `json:"used"`

	// Other is the memory used by other processes.
	Other uint64 `json:"other"`
}
|
||||
|
||||
// UsageStats is a JSON-serializable summary of usage counters.
type UsageStats struct {
	// Request and token counters; always present in the JSON encoding.
	Requests     int64 `json:"requests"`
	TokensInput  int64 `json:"tokens_input"`
	TokensOutput int64 `json:"tokens_output"`
	TotalTokens  int64 `json:"total_tokens"`

	// Per-key breakdowns; omitted from the JSON encoding when empty.
	Models  map[string]int64 `json:"models,omitempty"`
	Sources map[string]int64 `json:"sources,omitempty"`

	// Feature counters; omitted from the JSON encoding when zero.
	ToolCalls        int64 `json:"tool_calls,omitempty"`
	StructuredOutput int64 `json:"structured_output,omitempty"`
}
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -520,7 +520,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
yoloMode, _ := cmd.Flags().GetBool("yolo")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
@@ -548,9 +547,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with tools
|
||||
// Use experimental agent loop with
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
@@ -1765,7 +1764,6 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
runCmd.Flags().BoolP("yolo", "y", false, "Skip all tool approval prompts (use with caution)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
@@ -1833,7 +1831,6 @@ func NewCLI() *cobra.Command {
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ListRunningHandler,
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
|
||||
@@ -206,8 +206,6 @@ var (
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
EnableVulkan = Bool("OLLAMA_VULKAN")
|
||||
// Usage enables usage statistics reporting
|
||||
Usage = Bool("OLLAMA_USAGE")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
|
||||
@@ -6,9 +6,6 @@ import (
|
||||
|
||||
var ErrInterrupt = errors.New("Interrupt")
|
||||
|
||||
// ErrExpandOutput is the sentinel error reported when the user presses
// Ctrl+O to request that tool output be expanded.
var ErrExpandOutput = errors.New("ExpandOutput")
|
||||
|
||||
type InterruptError struct {
|
||||
Line []rune
|
||||
}
|
||||
|
||||
@@ -206,9 +206,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.DeleteBefore()
|
||||
case CharCtrlL:
|
||||
buf.ClearScreen()
|
||||
case CharCtrlO:
|
||||
// Ctrl+O - expand tool output
|
||||
return "", ErrExpandOutput
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
|
||||
@@ -18,7 +18,6 @@ const (
|
||||
CharCtrlL = 12
|
||||
CharEnter = 13
|
||||
CharNext = 14
|
||||
CharCtrlO = 15 // Ctrl+O - used for expanding tool output
|
||||
CharPrev = 16
|
||||
CharBckSearch = 18
|
||||
CharFwdSearch = 19
|
||||
|
||||
@@ -20,7 +20,8 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
// It also returns numKeep, the number of tokens in system messages + tools that should be protected from truncation.
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, numKeep int, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
@@ -44,12 +45,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, 0, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, 0, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
@@ -71,7 +72,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
for cnt, msg := range msgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
||||
return "", nil, 0, errors.New("this model only supports one image while more than one image requested")
|
||||
}
|
||||
|
||||
var prefix string
|
||||
@@ -98,10 +99,40 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
// truncate any messages that do not fit into the context window
|
||||
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return "", nil, 0, err
|
||||
}
|
||||
|
||||
return p, images, nil
|
||||
// Compute numKeep: tokens for system messages + tools that should be protected from truncation
|
||||
// Re-collect all system messages from the entire conversation
|
||||
allSystemMsgs := make([]api.Message, 0)
|
||||
for _, msg := range msgs {
|
||||
if msg.Role == "system" {
|
||||
allSystemMsgs = append(allSystemMsgs, msg)
|
||||
}
|
||||
}
|
||||
protectedPrompt, err := renderPrompt(m, allSystemMsgs, tools, think)
|
||||
if err != nil {
|
||||
return "", nil, 0, err
|
||||
}
|
||||
|
||||
protectedTokens, err := tokenize(ctx, protectedPrompt)
|
||||
if err != nil {
|
||||
return "", nil, 0, err
|
||||
}
|
||||
|
||||
numKeep = len(protectedTokens)
|
||||
|
||||
// Error if system+tools leaves less than 100 tokens for conversation
|
||||
if numKeep > 0 && numKeep > opts.NumCtx-100 {
|
||||
return "", nil, 0, fmt.Errorf("system prompt and tools (%d tokens) exceed context length (%d) minus required buffer (100 tokens)", numKeep, opts.NumCtx)
|
||||
}
|
||||
|
||||
// Cap numKeep to ensure at least 200 tokens can be generated
|
||||
if opts.NumCtx > 200 {
|
||||
numKeep = min(numKeep, opts.NumCtx-200)
|
||||
}
|
||||
|
||||
return p, images, numKeep, nil
|
||||
}
|
||||
|
||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
|
||||
@@ -235,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
prompt, images, _, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
||||
140
server/routes.go
140
server/routes.go
@@ -20,7 +20,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -45,7 +44,6 @@ import (
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/server/usage"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/tools"
|
||||
@@ -84,7 +82,6 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
stats *usage.Stats
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -107,30 +104,6 @@ var (
|
||||
errBadTemplate = errors.New("template error")
|
||||
)
|
||||
|
||||
// usage records a request to usage stats if enabled.
|
||||
func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
|
||||
if s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.Record(&usage.Request{
|
||||
Endpoint: endpoint,
|
||||
Model: model,
|
||||
Architecture: architecture,
|
||||
APIType: usage.ClassifyAPIType(c.Request.URL.Path),
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
UsedTools: usedTools,
|
||||
})
|
||||
}
|
||||
|
||||
// usageError records a failed request to usage stats if enabled.
|
||||
func (s *Server) usageError() {
|
||||
if s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.RecordError()
|
||||
}
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
@@ -401,7 +374,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -486,11 +459,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
var numKeep int
|
||||
prompt, images, numKeep, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// Set numKeep to protect system prompt + tools from truncation during context shift
|
||||
opts.NumKeep = numKeep
|
||||
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
|
||||
if req.Context != nil {
|
||||
b.WriteString(prompt)
|
||||
@@ -588,7 +564,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
|
||||
|
||||
if !req.Raw {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||
@@ -708,7 +683,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -818,7 +793,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: int(totalTokens),
|
||||
}
|
||||
s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -856,7 +830,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1560,7 +1534,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.GET("/api/usage", s.UsageHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
@@ -1623,13 +1596,6 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
// Initialize usage stats if enabled
|
||||
if envconfig.Usage() {
|
||||
s.stats = usage.New()
|
||||
s.stats.Start()
|
||||
slog.Info("usage stats enabled")
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
var err error
|
||||
@@ -1669,9 +1635,6 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if s.stats != nil {
|
||||
s.stats.Stop()
|
||||
}
|
||||
srvr.Close()
|
||||
schedDone()
|
||||
sched.unloadAllRunners()
|
||||
@@ -1689,24 +1652,6 @@ func Serve(ln net.Listener) error {
|
||||
gpus := discover.GPUDevices(ctx, nil)
|
||||
discover.LogDetails(gpus)
|
||||
|
||||
// Set GPU info for usage reporting
|
||||
if s.stats != nil {
|
||||
usage.GPUInfoFunc = func() []usage.GPU {
|
||||
var result []usage.GPU
|
||||
for _, gpu := range gpus {
|
||||
result = append(result, usage.GPU{
|
||||
Name: gpu.Name,
|
||||
VRAMBytes: gpu.TotalMemory,
|
||||
ComputeMajor: gpu.ComputeMajor,
|
||||
ComputeMinor: gpu.ComputeMinor,
|
||||
DriverMajor: gpu.DriverMajor,
|
||||
DriverMinor: gpu.DriverMinor,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
var totalVRAM uint64
|
||||
for _, gpu := range gpus {
|
||||
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
|
||||
@@ -1910,63 +1855,6 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
func (s *Server) UsageHandler(c *gin.Context) {
|
||||
// Get total VRAM used by Ollama
|
||||
s.sched.loadedMu.Lock()
|
||||
var totalOllamaVRAM uint64
|
||||
for _, runner := range s.sched.loaded {
|
||||
totalOllamaVRAM += runner.vramSize
|
||||
}
|
||||
s.sched.loadedMu.Unlock()
|
||||
|
||||
var resp api.UsageResponse
|
||||
|
||||
// Get GPU/device info
|
||||
gpus := discover.GPUDevices(c.Request.Context(), nil)
|
||||
|
||||
// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
|
||||
// because unified memory means GPU and CPU share the same physical RAM pool
|
||||
var sysTotal, sysFree uint64
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
sysInfo := discover.GetSystemInfo()
|
||||
sysTotal = sysInfo.TotalMemory
|
||||
sysFree = sysInfo.FreeMemory
|
||||
}
|
||||
|
||||
for _, gpu := range gpus {
|
||||
total := gpu.TotalMemory
|
||||
free := gpu.FreeMemory
|
||||
|
||||
// On Apple Silicon, override with system memory values
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
|
||||
total = sysTotal
|
||||
free = sysFree
|
||||
}
|
||||
|
||||
used := total - free
|
||||
ollamaUsed := min(totalOllamaVRAM, used)
|
||||
otherUsed := used - ollamaUsed
|
||||
|
||||
// Use Description for Name (actual device name like "Apple M2 Max")
|
||||
// Fall back to backend name if Description is empty
|
||||
name := gpu.Description
|
||||
if name == "" {
|
||||
name = gpu.Name
|
||||
}
|
||||
|
||||
resp.GPUs = append(resp.GPUs, api.GPUUsage{
|
||||
Name: name,
|
||||
Backend: gpu.Library,
|
||||
Total: total,
|
||||
Free: free,
|
||||
Used: ollamaUsed,
|
||||
Other: otherUsed,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func toolCallId() string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
@@ -2147,7 +2035,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2191,13 +2079,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
prompt, images, numKeep, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Set numKeep to protect system prompt + tools from truncation during context shift
|
||||
opts.NumKeep = numKeep
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
if req.DebugRenderOnly {
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
@@ -2295,7 +2186,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
@@ -2405,7 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
msgs = append(msgs, msg)
|
||||
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
prompt, _, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error applying structured outputs", "error", err)
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
@@ -2471,7 +2361,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
@@ -2479,8 +2368,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
|
||||
s.usageError()
|
||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
switch {
|
||||
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestUsageHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("empty server", func(t *testing.T) {
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
loaded: make(map[string]*runnerRef),
|
||||
},
|
||||
}
|
||||
|
||||
w := createRequest(t, s.UsageHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.UsageResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// GPUs may or may not be present depending on system
|
||||
// Just verify we can decode the response
|
||||
})
|
||||
|
||||
t.Run("response structure", func(t *testing.T) {
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
loaded: make(map[string]*runnerRef),
|
||||
},
|
||||
}
|
||||
|
||||
w := createRequest(t, s.UsageHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
// Verify we can decode the response as valid JSON
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The response should be a valid object (not null)
|
||||
if resp == nil {
|
||||
t.Error("expected non-nil response")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// reportTimeout is the deadline applied to each heartbeat request.
const reportTimeout = 10 * time.Second

// usageURL is the endpoint that heartbeat payloads are POSTed to.
const usageURL = "https://ollama.com/api/usage"
|
||||
|
||||
// HeartbeatResponse is the JSON body decoded from a successful heartbeat reply.
type HeartbeatResponse struct {
	// UpdateVersion is the version string reported by the endpoint; it is
	// omitted from the JSON encoding when empty.
	UpdateVersion string `json:"update_version,omitempty"`
}
|
||||
|
||||
// UpdateAvailable returns the available update version, if any.
|
||||
func (t *Stats) UpdateAvailable() string {
|
||||
if v := t.updateAvailable.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sendHeartbeat sends usage stats and checks for updates.
|
||||
func (t *Stats) sendHeartbeat(payload *Payload) {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var heartbeat HeartbeatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.updateAvailable.Store(heartbeat.UpdateVersion)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// API type values returned by ClassifyAPIType.
const (
	APITypeOllama    = "ollama"
	APITypeOpenAI    = "openai"
	APITypeAnthropic = "anthropic"
)

// ClassifyAPIType maps a request path to an API type by prefix: paths starting
// with "/v1/messages" are Anthropic, any other path starting with "/v1/" is
// OpenAI, and everything else is Ollama.
func ClassifyAPIType(path string) string {
	switch {
	case strings.HasPrefix(path, "/v1/messages"):
		// Checked first because it is also covered by the "/v1/" prefix.
		return APITypeAnthropic
	case strings.HasPrefix(path, "/v1/"):
		return APITypeOpenAI
	default:
		return APITypeOllama
	}
}
|
||||
@@ -1,324 +0,0 @@
|
||||
// Package usage provides in-memory usage statistics collection and reporting.
|
||||
package usage
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// Stats collects usage statistics in memory and reports them periodically.
|
||||
type Stats struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Atomic counters for hot path
|
||||
requestsTotal atomic.Int64
|
||||
tokensPrompt atomic.Int64
|
||||
tokensCompletion atomic.Int64
|
||||
errorsTotal atomic.Int64
|
||||
|
||||
// Map-based counters (require lock)
|
||||
endpoints map[string]int64
|
||||
architectures map[string]int64
|
||||
apis map[string]int64
|
||||
models map[string]*ModelStats // per-model stats
|
||||
|
||||
// Feature usage
|
||||
toolCalls atomic.Int64
|
||||
structuredOutput atomic.Int64
|
||||
|
||||
// Update info (set by reporter after pinging update endpoint)
|
||||
updateAvailable atomic.Value // string
|
||||
|
||||
// Reporter
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
interval time.Duration
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// ModelStats holds the request and token counters kept for a single model.
type ModelStats struct {
	// Requests is the number of recorded requests for the model.
	Requests int64
	// TokensInput is the running sum of prompt tokens.
	TokensInput int64
	// TokensOutput is the running sum of completion tokens.
	TokensOutput int64
}
|
||||
|
||||
// Request contains the data to record for a single request.
type Request struct {
	Endpoint     string // "chat", "generate", "embed"
	Model        string // model name (e.g., "llama3.2:3b")
	Architecture string // model architecture (e.g., "llama", "qwen2")

	// APIType is the API family the request arrived through: one of
	// APITypeOllama, APITypeOpenAI, or APITypeAnthropic.
	APIType string

	PromptTokens     int
	CompletionTokens int
	UsedTools        bool
	StructuredOutput bool
}
|
||||
|
||||
// SystemInfo contains hardware information to report.
|
||||
type SystemInfo struct {
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
CPUCores int `json:"cpu_cores"`
|
||||
RAMBytes uint64 `json:"ram_bytes"`
|
||||
GPUs []GPU `json:"gpus,omitempty"`
|
||||
}
|
||||
|
||||
// GPU describes a single GPU as reported in SystemInfo.
type GPU struct {
	// Name and VRAMBytes are always present in the JSON encoding.
	Name      string `json:"name"`
	VRAMBytes uint64 `json:"vram_bytes"`

	// Compute and driver versions; each is omitted from the JSON encoding
	// when zero.
	ComputeMajor int `json:"compute_major,omitempty"`
	ComputeMinor int `json:"compute_minor,omitempty"`
	DriverMajor  int `json:"driver_major,omitempty"`
	DriverMinor  int `json:"driver_minor,omitempty"`
}
|
||||
|
||||
// Payload is the data sent to the heartbeat endpoint.
|
||||
type Payload struct {
|
||||
Version string `json:"version"`
|
||||
Time time.Time `json:"time"`
|
||||
System SystemInfo `json:"system"`
|
||||
|
||||
Totals struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Errors int64 `json:"errors"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"totals"`
|
||||
|
||||
Endpoints map[string]int64 `json:"endpoints"`
|
||||
Architectures map[string]int64 `json:"architectures"`
|
||||
APIs map[string]int64 `json:"apis"`
|
||||
|
||||
Features struct {
|
||||
ToolCalls int64 `json:"tool_calls"`
|
||||
StructuredOutput int64 `json:"structured_output"`
|
||||
} `json:"features"`
|
||||
}
|
||||
|
||||
// defaultInterval is the reporting interval New uses unless overridden.
const defaultInterval = time.Hour
|
||||
|
||||
// New creates a new Stats instance.
|
||||
func New(opts ...Option) *Stats {
|
||||
t := &Stats{
|
||||
endpoints: make(map[string]int64),
|
||||
architectures: make(map[string]int64),
|
||||
apis: make(map[string]int64),
|
||||
models: make(map[string]*ModelStats),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
interval: defaultInterval,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Option configures the Stats instance.
|
||||
type Option func(*Stats)
|
||||
|
||||
// WithInterval sets the reporting interval.
|
||||
func WithInterval(d time.Duration) Option {
|
||||
return func(t *Stats) {
|
||||
t.interval = d
|
||||
}
|
||||
}
|
||||
|
||||
// Record records a request. This is the hot path and should be fast.
|
||||
func (t *Stats) Record(r *Request) {
|
||||
t.requestsTotal.Add(1)
|
||||
t.tokensPrompt.Add(int64(r.PromptTokens))
|
||||
t.tokensCompletion.Add(int64(r.CompletionTokens))
|
||||
|
||||
if r.UsedTools {
|
||||
t.toolCalls.Add(1)
|
||||
}
|
||||
if r.StructuredOutput {
|
||||
t.structuredOutput.Add(1)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.endpoints[r.Endpoint]++
|
||||
t.architectures[r.Architecture]++
|
||||
t.apis[r.APIType]++
|
||||
|
||||
// Track per-model stats
|
||||
if r.Model != "" {
|
||||
if t.models[r.Model] == nil {
|
||||
t.models[r.Model] = &ModelStats{}
|
||||
}
|
||||
t.models[r.Model].Requests++
|
||||
t.models[r.Model].TokensInput += int64(r.PromptTokens)
|
||||
t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// RecordError records a failed request.
|
||||
func (t *Stats) RecordError() {
|
||||
t.errorsTotal.Add(1)
|
||||
}
|
||||
|
||||
// GetModelStats returns a copy of per-model statistics.
|
||||
func (t *Stats) GetModelStats() map[string]*ModelStats {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ModelStats, len(t.models))
|
||||
for k, v := range t.models {
|
||||
result[k] = &ModelStats{
|
||||
Requests: v.Requests,
|
||||
TokensInput: v.TokensInput,
|
||||
TokensOutput: v.TokensOutput,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// View returns current stats without resetting counters.
|
||||
func (t *Stats) View() *Payload {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Copy maps
|
||||
endpoints := make(map[string]int64, len(t.endpoints))
|
||||
for k, v := range t.endpoints {
|
||||
endpoints[k] = v
|
||||
}
|
||||
architectures := make(map[string]int64, len(t.architectures))
|
||||
for k, v := range t.architectures {
|
||||
architectures[k] = v
|
||||
}
|
||||
apis := make(map[string]int64, len(t.apis))
|
||||
for k, v := range t.apis {
|
||||
apis[k] = v
|
||||
}
|
||||
|
||||
p := &Payload{
|
||||
Version: version.Version,
|
||||
Time: now,
|
||||
System: getSystemInfo(),
|
||||
Endpoints: endpoints,
|
||||
Architectures: architectures,
|
||||
APIs: apis,
|
||||
}
|
||||
|
||||
p.Totals.Requests = t.requestsTotal.Load()
|
||||
p.Totals.Errors = t.errorsTotal.Load()
|
||||
p.Totals.InputTokens = t.tokensPrompt.Load()
|
||||
p.Totals.OutputTokens = t.tokensCompletion.Load()
|
||||
p.Features.ToolCalls = t.toolCalls.Load()
|
||||
p.Features.StructuredOutput = t.structuredOutput.Load()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Snapshot returns current stats and resets counters.
|
||||
func (t *Stats) Snapshot() *Payload {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
p := &Payload{
|
||||
Version: version.Version,
|
||||
Time: now,
|
||||
System: getSystemInfo(),
|
||||
Endpoints: t.endpoints,
|
||||
Architectures: t.architectures,
|
||||
APIs: t.apis,
|
||||
}
|
||||
|
||||
p.Totals.Requests = t.requestsTotal.Swap(0)
|
||||
p.Totals.Errors = t.errorsTotal.Swap(0)
|
||||
p.Totals.InputTokens = t.tokensPrompt.Swap(0)
|
||||
p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
|
||||
p.Features.ToolCalls = t.toolCalls.Swap(0)
|
||||
p.Features.StructuredOutput = t.structuredOutput.Swap(0)
|
||||
|
||||
// Reset maps
|
||||
t.endpoints = make(map[string]int64)
|
||||
t.architectures = make(map[string]int64)
|
||||
t.apis = make(map[string]int64)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// getSystemInfo collects hardware information.
|
||||
func getSystemInfo() SystemInfo {
|
||||
info := SystemInfo{
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Get CPU and memory info
|
||||
sysInfo := discover.GetSystemInfo()
|
||||
info.CPUCores = sysInfo.ThreadCount
|
||||
info.RAMBytes = sysInfo.TotalMemory
|
||||
|
||||
// Get GPU info
|
||||
gpus := getGPUInfo()
|
||||
info.GPUs = gpus
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GPUInfoFunc is a function that returns GPU information.
|
||||
// It's set by the server package after GPU discovery.
|
||||
var GPUInfoFunc func() []GPU
|
||||
|
||||
// getGPUInfo collects GPU information.
|
||||
func getGPUInfo() []GPU {
|
||||
if GPUInfoFunc != nil {
|
||||
return GPUInfoFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins the periodic reporting goroutine.
|
||||
func (t *Stats) Start() {
|
||||
go t.reportLoop()
|
||||
}
|
||||
|
||||
// Stop stops reporting and waits for the final report.
|
||||
func (t *Stats) Stop() {
|
||||
close(t.stopCh)
|
||||
<-t.doneCh
|
||||
}
|
||||
|
||||
// reportLoop runs the periodic reporting.
|
||||
func (t *Stats) reportLoop() {
|
||||
defer close(t.doneCh)
|
||||
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.report()
|
||||
case <-t.stopCh:
|
||||
// Send final report before stopping
|
||||
t.report()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// report sends usage stats and checks for updates.
|
||||
func (t *Stats) report() {
|
||||
payload := t.Snapshot()
|
||||
t.sendHeartbeat(payload)
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
stats := New()
|
||||
if stats == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
Architecture: "llama",
|
||||
APIType: "native",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
UsedTools: true,
|
||||
StructuredOutput: false,
|
||||
})
|
||||
|
||||
// Check totals
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
|
||||
}
|
||||
if payload.Totals.InputTokens != 100 {
|
||||
t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
|
||||
}
|
||||
if payload.Totals.OutputTokens != 50 {
|
||||
t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
|
||||
}
|
||||
if payload.Features.ToolCalls != 1 {
|
||||
t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
|
||||
}
|
||||
if payload.Features.StructuredOutput != 0 {
|
||||
t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelStats(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
// Record requests for multiple models
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
})
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 200,
|
||||
CompletionTokens: 100,
|
||||
})
|
||||
stats.Record(&Request{
|
||||
Model: "mistral:7b",
|
||||
PromptTokens: 50,
|
||||
CompletionTokens: 25,
|
||||
})
|
||||
|
||||
modelStats := stats.GetModelStats()
|
||||
|
||||
// Check llama3:8b stats
|
||||
llama := modelStats["llama3:8b"]
|
||||
if llama == nil {
|
||||
t.Fatal("expected llama3:8b stats")
|
||||
}
|
||||
if llama.Requests != 2 {
|
||||
t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
|
||||
}
|
||||
if llama.TokensInput != 300 {
|
||||
t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
|
||||
}
|
||||
if llama.TokensOutput != 150 {
|
||||
t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
|
||||
}
|
||||
|
||||
// Check mistral:7b stats
|
||||
mistral := modelStats["mistral:7b"]
|
||||
if mistral == nil {
|
||||
t.Fatal("expected mistral:7b stats")
|
||||
}
|
||||
if mistral.Requests != 1 {
|
||||
t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
|
||||
}
|
||||
if mistral.TokensInput != 50 {
|
||||
t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
|
||||
}
|
||||
if mistral.TokensOutput != 25 {
|
||||
t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordError(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.RecordError()
|
||||
stats.RecordError()
|
||||
|
||||
payload := stats.View()
|
||||
if payload.Totals.Errors != 2 {
|
||||
t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestView(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
Architecture: "llama",
|
||||
APIType: "native",
|
||||
})
|
||||
|
||||
// First view
|
||||
_ = stats.View()
|
||||
|
||||
// View should not reset counters
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1 {
|
||||
t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
})
|
||||
|
||||
// Snapshot should return data and reset counters
|
||||
snapshot := stats.Snapshot()
|
||||
if snapshot.Totals.Requests != 1 {
|
||||
t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
|
||||
}
|
||||
|
||||
// After snapshot, counters should be reset
|
||||
payload2 := stats.View()
|
||||
if payload2.Totals.Requests != 0 {
|
||||
t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 5,
|
||||
})
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = stats.View()
|
||||
_ = stats.GetModelStats()
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 15; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1000 {
|
||||
t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
|
||||
}
|
||||
}
|
||||
@@ -381,28 +381,6 @@ func (t templateTools) String() string {
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateArgs is a string-keyed map that prints as JSON, so a template that
// writes the value directly emits a JSON object instead of Go's map[...] form.
type templateArgs map[string]any

// String returns the JSON encoding of the map. A nil map renders as "{}".
func (t templateArgs) String() string {
	if t != nil {
		// String cannot report an error; a failed encode yields "".
		encoded, _ := json.Marshal(t)
		return string(encoded)
	}
	return "{}"
}
|
||||
|
||||
// templateProperties is a map type with JSON string output for templates.
|
||||
type templateProperties map[string]api.ToolProperty
|
||||
|
||||
func (t templateProperties) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateTool is a template-compatible representation of api.Tool
|
||||
// with Properties as a regular map for template ranging.
|
||||
type templateTool struct {
|
||||
@@ -418,11 +396,11 @@ type templateToolFunction struct {
|
||||
}
|
||||
|
||||
type templateToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties templateProperties `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||
@@ -435,7 +413,7 @@ type templateToolCall struct {
|
||||
type templateToolCallFunction struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments templateArgs
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// templateMessage is a template-compatible representation of api.Message
|
||||
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||
Defs: tool.Function.Parameters.Defs,
|
||||
Items: tool.Function.Parameters.Items,
|
||||
Required: tool.Function.Parameters.Required,
|
||||
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
|
||||
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||
Function: templateToolCallFunction{
|
||||
Index: tc.Function.Index,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "Tokyo")
|
||||
args.Set("unit", "celsius")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Arguments output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:{...}]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Properties output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsRange(t *testing.T) {
|
||||
// Test that we can range over Arguments in templates
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("city", "Tokyo")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "city=Tokyo;" {
|
||||
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesRange(t *testing.T) {
|
||||
// Test that we can range over Properties in templates
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "location:string;" {
|
||||
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ package agent
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -180,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For commands without path args, returns empty string.
|
||||
// Paths with ".." traversal that escape the base directory return empty string for security.
|
||||
func extractBashPrefix(command string) string {
|
||||
// Split command by pipes and get the first part
|
||||
parts := strings.Split(command, "|")
|
||||
@@ -206,8 +204,8 @@ func extractBashPrefix(command string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or \ or start with .)
|
||||
// First pass: look for clear paths (containing path separators or starting with .)
|
||||
// Find the first path-like argument (must contain / or start with .)
|
||||
// First pass: look for clear paths (containing / or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
@@ -217,49 +215,19 @@ func extractBashPrefix(command string) string {
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or \ or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
|
||||
// Only process if it looks like a path (contains / or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
}
|
||||
// Normalize to forward slashes for consistent cross-platform matching
|
||||
arg = strings.ReplaceAll(arg, "\\", "/")
|
||||
|
||||
// Security: reject absolute paths
|
||||
if path.IsAbs(arg) {
|
||||
return "" // Absolute path - don't create prefix
|
||||
// If arg ends with /, it's a directory - use it directly
|
||||
if strings.HasSuffix(arg, "/") {
|
||||
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||
}
|
||||
|
||||
// Normalize the path using stdlib path.Clean (resolves . and ..)
|
||||
cleaned := path.Clean(arg)
|
||||
|
||||
// Security: reject if cleaned path escapes to parent directory
|
||||
if strings.HasPrefix(cleaned, "..") {
|
||||
return "" // Path escapes - don't create prefix
|
||||
}
|
||||
|
||||
// Security: if original had "..", verify cleaned path didn't escape to sibling
|
||||
// e.g., "tools/a/b/../../../etc" -> "etc" (escaped tools/ to sibling)
|
||||
if strings.Contains(arg, "..") {
|
||||
origBase := strings.SplitN(arg, "/", 2)[0]
|
||||
cleanedBase := strings.SplitN(cleaned, "/", 2)[0]
|
||||
if origBase != cleanedBase {
|
||||
return "" // Path escaped to sibling directory
|
||||
}
|
||||
}
|
||||
|
||||
// Check if arg ends with / (explicit directory)
|
||||
isDir := strings.HasSuffix(arg, "/")
|
||||
|
||||
// Get the directory part
|
||||
var dir string
|
||||
if isDir {
|
||||
dir = cleaned
|
||||
} else {
|
||||
dir = path.Dir(cleaned)
|
||||
}
|
||||
|
||||
// Get the directory part of a file path
|
||||
dir := filepath.Dir(arg)
|
||||
if dir == "." {
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||
}
|
||||
@@ -364,8 +332,6 @@ func AllowlistKey(toolName string, args map[string]any) string {
|
||||
}
|
||||
|
||||
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||
// For bash commands, hierarchical path matching is used - if "cat:tools/" is allowed,
|
||||
// then "cat:tools/subdir/" is also allowed (subdirectories inherit parent permissions).
|
||||
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
@@ -376,20 +342,12 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// For bash commands, check prefix matches with hierarchical path support
|
||||
// For bash commands, check prefix matches
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
// Check exact prefix match first
|
||||
if a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
// Check hierarchical match: if any stored prefix is a parent of current prefix
|
||||
// e.g., stored "cat:tools/" should match current "cat:tools/subdir/"
|
||||
if a.matchesHierarchicalPrefix(prefix) {
|
||||
return true
|
||||
}
|
||||
if prefix != "" && a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -402,40 +360,6 @@ func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchesHierarchicalPrefix checks if the given prefix matches any stored prefix hierarchically.
|
||||
// For example, if "cat:tools/" is stored, it will match "cat:tools/subdir/" or "cat:tools/a/b/c/".
|
||||
func (a *ApprovalManager) matchesHierarchicalPrefix(currentPrefix string) bool {
|
||||
// Split prefix into command and path parts (format: "cmd:path/")
|
||||
colonIdx := strings.Index(currentPrefix, ":")
|
||||
if colonIdx == -1 {
|
||||
return false
|
||||
}
|
||||
currentCmd := currentPrefix[:colonIdx]
|
||||
currentPath := currentPrefix[colonIdx+1:]
|
||||
|
||||
for storedPrefix := range a.prefixes {
|
||||
storedColonIdx := strings.Index(storedPrefix, ":")
|
||||
if storedColonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
storedCmd := storedPrefix[:storedColonIdx]
|
||||
storedPath := storedPrefix[storedColonIdx+1:]
|
||||
|
||||
// Commands must match exactly
|
||||
if currentCmd != storedCmd {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if current path starts with stored path (hierarchical match)
|
||||
// e.g., "tools/subdir/" starts with "tools/"
|
||||
if strings.HasPrefix(currentPath, storedPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||
@@ -519,12 +443,11 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
// For web search, show query and internet notice
|
||||
// For web search, show query
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
@@ -1028,79 +951,3 @@ func FormatDenyResult(toolName string, reason string) string {
|
||||
}
|
||||
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||
}
|
||||
|
||||
// PromptYesNo displays a simple Yes/No prompt and returns the user's choice.
|
||||
// Returns true for Yes, false for No.
|
||||
func PromptYesNo(question string) (bool, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
selected := 0 // 0 = Yes, 1 = No
|
||||
options := []string{"Yes", "No"}
|
||||
|
||||
// Hide cursor
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h")
|
||||
|
||||
renderYesNo := func() {
|
||||
// Move to start of line and clear
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
|
||||
}
|
||||
|
||||
renderYesNo()
|
||||
|
||||
buf := make([]byte, 3)
|
||||
for {
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
switch buf[0] {
|
||||
case 'y', 'Y':
|
||||
selected = 0
|
||||
renderYesNo()
|
||||
case 'n', 'N':
|
||||
selected = 1
|
||||
renderYesNo()
|
||||
case '\r', '\n': // Enter
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K") // Clear line
|
||||
return selected == 0, nil
|
||||
case 3: // Ctrl+C
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return false, nil
|
||||
case 27: // Escape - could be arrow key
|
||||
// Read more bytes for arrow keys
|
||||
continue
|
||||
}
|
||||
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
||||
// Arrow keys
|
||||
switch buf[2] {
|
||||
case 'D': // Left
|
||||
if selected > 0 {
|
||||
selected--
|
||||
}
|
||||
renderYesNo()
|
||||
case 'C': // Right
|
||||
if selected < len(options)-1 {
|
||||
selected++
|
||||
}
|
||||
renderYesNo()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,27 +151,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
command: "head -n 100",
|
||||
expected: "",
|
||||
},
|
||||
// Path traversal security tests
|
||||
{
|
||||
name: "path traversal - parent escape",
|
||||
command: "cat tools/../../etc/passwd",
|
||||
expected: "", // Should NOT create a prefix - path escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - deep escape",
|
||||
command: "cat tools/a/b/../../../etc/passwd",
|
||||
expected: "", // Normalizes to "../etc/passwd" - escapes
|
||||
},
|
||||
{
|
||||
name: "path traversal - absolute path",
|
||||
command: "cat /etc/passwd",
|
||||
expected: "", // Absolute paths should not create prefix
|
||||
},
|
||||
{
|
||||
name: "path with safe dotdot - normalized",
|
||||
command: "cat tools/subdir/../file.go",
|
||||
expected: "cat:tools/", // Normalizes to tools/file.go - safe, creates prefix
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -185,34 +164,6 @@ func TestExtractBashPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PathTraversalBlocked(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Path traversal attack: should NOT be allowed
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../etc/passwd"}) {
|
||||
t.Error("SECURITY: path traversal attack should NOT be allowed")
|
||||
}
|
||||
|
||||
// Another traversal variant
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat tools/../../../etc/shadow"}) {
|
||||
t.Error("SECURITY: deep path traversal should NOT be allowed")
|
||||
}
|
||||
|
||||
// Valid subdirectory access should still work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed")
|
||||
}
|
||||
|
||||
// Safe ".." that normalizes to within allowed directory should work
|
||||
// tools/subdir/../other.go normalizes to tools/other.go which is under tools/
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/../other.go"}) {
|
||||
t.Error("expected cat tools/subdir/../other.go to be allowed (normalizes to tools/other.go)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
@@ -235,119 +186,6 @@ func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go" - this creates prefix "cat:tools/"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should allow subdirectories (hierarchical matching)
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/subdir/file.go"}) {
|
||||
t.Error("expected cat tools/subdir/file.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should allow deeply nested subdirectories
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/a/b/c/deep.go"}) {
|
||||
t.Error("expected cat tools/a/b/c/deep.go to be allowed via hierarchical prefix")
|
||||
}
|
||||
|
||||
// Should still allow same directory
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/another.go"}) {
|
||||
t.Error("expected cat tools/another.go to be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different base directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||
t.Error("expected cat src/main.go to NOT be allowed")
|
||||
}
|
||||
|
||||
// Should NOT allow different command even in subdirectory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "ls tools/subdir/"}) {
|
||||
t.Error("expected ls tools/subdir/ to NOT be allowed (different command)")
|
||||
}
|
||||
|
||||
// Should NOT allow similar but different directory name
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat toolsbin/file.go"}) {
|
||||
t.Error("expected cat toolsbin/file.go to NOT be allowed (different directory)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_HierarchicalPrefixAllowlist_CrossPlatform(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow with forward slashes (Unix-style)
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should work with backslashes too (Windows-style) - normalized internally
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\subdir\\file.go"}) {
|
||||
t.Error("expected cat tools\\subdir\\file.go to be allowed via hierarchical prefix (Windows path)")
|
||||
}
|
||||
|
||||
// Mixed slashes should also work
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools\\a/b\\c/deep.go"}) {
|
||||
t.Error("expected mixed slash path to be allowed via hierarchical prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesHierarchicalPrefix(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Add prefix for "cat:tools/"
|
||||
am.prefixes["cat:tools/"] = true
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
prefix: "cat:tools/",
|
||||
expected: true, // exact match also passes HasPrefix - caller handles exact match first
|
||||
},
|
||||
{
|
||||
name: "subdirectory",
|
||||
prefix: "cat:tools/subdir/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
prefix: "cat:tools/a/b/c/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different base directory",
|
||||
prefix: "cat:src/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different command same path",
|
||||
prefix: "ls:tools/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "similar directory name",
|
||||
prefix: "cat:toolsbin/",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid prefix format",
|
||||
prefix: "cattools",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := am.matchesHierarchicalPrefix(tt.prefix)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchesHierarchicalPrefix(%q) = %v, expected %v",
|
||||
tt.prefix, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatApprovalResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
263
x/cmd/run.go
263
x/cmd/run.go
@@ -6,12 +6,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
@@ -24,101 +22,6 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// Constants that bound how much tool output is kept.

// localModelTokenLimit is the token budget used for local models, which are
// assumed to have a smaller context.
const localModelTokenLimit = 4000

// defaultTokenLimit is the token budget used for cloud/remote models.
const defaultTokenLimit = 10000

// charsPerToken is a rough characters-per-token estimate used to turn a token
// budget into a length limit.
// TODO: Estimate tokens more accurately using tokenizer if available
const charsPerToken = 4
|
||||
|
||||
// isLocalModel reports whether modelName refers to a locally run model.
// The only signal used is the name: anything ending in "-cloud" is treated as
// a cloud model, everything else (including the empty string) as local.
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
	isCloud := strings.HasSuffix(modelName, "-cloud")
	return !isCloud
}
|
||||
|
||||
// isLocalServer reports whether OLLAMA_HOST points at a local Ollama server.
//
// An unset or empty OLLAMA_HOST means the default local server. Otherwise the
// value is parsed as a URL and considered local when its hostname is a
// loopback name ("localhost", "127.0.0.1", "::1") or when its host part
// contains the default port ":11434".
//
// OLLAMA_HOST is frequently given without a scheme (e.g. "localhost:11434").
// url.Parse would read "localhost" as the scheme and leave the host empty, so
// a default "http://" scheme is added before parsing when none is present.
// If the value still cannot be parsed, the server is assumed to be local.
// TODO: Could also check other indicators of local vs cloud server
func isLocalServer() bool {
	host := os.Getenv("OLLAMA_HOST")
	if host == "" {
		return true // Default is localhost:11434
	}

	// Give scheme-less values such as "localhost:11434" or "10.0.0.5:8080" a
	// scheme so the host and port land in parsed.Host instead of being
	// mistaken for a scheme/opaque pair or rejected outright.
	if !strings.Contains(host, "://") {
		host = "http://" + host
	}

	parsed, err := url.Parse(host)
	if err != nil {
		return true // If can't parse, assume local
	}

	switch parsed.Hostname() {
	case "localhost", "127.0.0.1", "::1":
		return true
	}

	// Any host on the default Ollama port is treated as local; this
	// heuristic is kept as-is because existing tests rely on it.
	return strings.Contains(parsed.Host, ":11434")
}
|
||||
|
||||
// truncateToolOutput truncates tool output to prevent context overflow.
|
||||
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
|
||||
func truncateToolOutput(output, modelName string) string {
|
||||
var tokenLimit int
|
||||
if isLocalModel(modelName) && isLocalServer() {
|
||||
tokenLimit = localModelTokenLimit
|
||||
} else {
|
||||
tokenLimit = defaultTokenLimit
|
||||
}
|
||||
|
||||
maxChars := tokenLimit * charsPerToken
|
||||
if len(output) > maxChars {
|
||||
return output[:maxChars] + "\n... (output truncated)"
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// waitForOllamaSignin shows the signin URL and polls until authentication completes.
|
||||
func waitForOllamaSignin(ctx context.Context) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get signin URL from initial Whoami call
|
||||
_, err = client.Whoami(ctx)
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
|
||||
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
|
||||
|
||||
// Poll until auth succeeds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
user, whoamiErr := client.Whoami(ctx)
|
||||
if whoamiErr == nil && user != nil && user.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
|
||||
return nil
|
||||
}
|
||||
// Still waiting, show dot
|
||||
fmt.Fprintf(os.Stderr, ".")
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
@@ -134,16 +37,6 @@ type RunOptions struct {
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
Approval *agent.ApprovalManager
|
||||
|
||||
// YoloMode skips all tool approval prompts
|
||||
YoloMode bool
|
||||
|
||||
// LastToolOutput stores the full output of the last tool execution
|
||||
// for Ctrl+O expansion. Updated by Chat(), read by caller.
|
||||
LastToolOutput *string
|
||||
|
||||
// LastToolOutputTruncated stores the truncated version shown inline
|
||||
LastToolOutputTruncated *string
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
@@ -184,7 +77,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
var consecutiveErrors int // Track consecutive 500 errors for retry limit
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
@@ -267,58 +159,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for 401 Unauthorized - prompt user to sign in
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) {
|
||||
p.StopAndClear()
|
||||
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the chat request
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
|
||||
continue // Retry the loop
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
|
||||
}
|
||||
|
||||
// Check for 500 errors (often tool parsing failures) - inform the model
|
||||
var statusErr api.StatusError
|
||||
if errors.As(err, &statusErr) && statusErr.StatusCode >= 500 {
|
||||
consecutiveErrors++
|
||||
p.StopAndClear()
|
||||
|
||||
if consecutiveErrors >= 3 {
|
||||
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
|
||||
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
|
||||
// Include both the model's response and the error so it can learn
|
||||
assistantContent := fullResponse.String()
|
||||
if assistantContent == "" {
|
||||
assistantContent = "(empty response)"
|
||||
}
|
||||
errorMsg := fmt.Sprintf("Your previous response caused an error: %s\n\nYour response was:\n%s\n\nPlease try again with a valid response.", statusErr.ErrorMessage, assistantContent)
|
||||
messages = append(messages,
|
||||
api.Message{Role: "user", Content: errorMsg},
|
||||
)
|
||||
|
||||
// Reset state and retry
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
@@ -328,9 +168,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset consecutive error counter on success
|
||||
consecutiveErrors = 0
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||
break
|
||||
@@ -379,12 +216,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
|
||||
// Check approval (uses prefix matching for bash commands)
|
||||
// In yolo mode, skip all approval prompts
|
||||
if opts.YoloMode {
|
||||
if !skipApproval {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||
@@ -418,23 +250,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
// Check if web search needs authentication
|
||||
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
|
||||
// Prompt user to sign in
|
||||
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
// Get signin URL and wait for auth completion
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the web search
|
||||
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
|
||||
toolResult, err = toolRegistry.Execute(call)
|
||||
if err == nil {
|
||||
goto toolSuccess
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
@@ -443,34 +258,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
toolSuccess:
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
truncatedOutput := ""
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
|
||||
output = output[:300] + "... (truncated)"
|
||||
}
|
||||
truncatedOutput = output
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
// Store full and truncated output for Ctrl+O toggle
|
||||
if opts.LastToolOutput != nil {
|
||||
*opts.LastToolOutput = toolResult
|
||||
}
|
||||
if opts.LastToolOutputTruncated != nil {
|
||||
*opts.LastToolOutputTruncated = truncatedOutput
|
||||
}
|
||||
|
||||
// Truncate output to prevent context overflow
|
||||
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResultForLLM,
|
||||
Content: toolResult,
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
}
|
||||
@@ -648,8 +449,7 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
// If yoloMode is true, all tool approvals are skipped.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
@@ -674,11 +474,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
if toolRegistry.Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
}
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
// Check for OLLAMA_API_KEY for web search
|
||||
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
@@ -690,11 +490,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
// Track last tool output for Ctrl+O toggle
|
||||
var lastToolOutput string
|
||||
var lastToolOutputTruncated string
|
||||
var toolOutputExpanded bool
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
@@ -707,20 +502,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case errors.Is(err, readline.ErrExpandOutput):
|
||||
// Ctrl+O pressed - toggle between expanded and collapsed tool output
|
||||
if lastToolOutput == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
|
||||
} else if toolOutputExpanded {
|
||||
// Currently expanded, show truncated
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
|
||||
toolOutputExpanded = false
|
||||
} else {
|
||||
// Currently collapsed, show full
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
|
||||
toolOutputExpanded = true
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
@@ -743,9 +524,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Keyboard Shortcuts:")
|
||||
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
@@ -759,21 +537,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
LastToolOutput: &lastToolOutput,
|
||||
LastToolOutputTruncated: &lastToolOutputTruncated,
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
}
|
||||
// Reset expanded state for new tool execution
|
||||
toolOutputExpanded = false
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "local model without suffix",
|
||||
modelName: "llama3.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "local model with version",
|
||||
modelName: "qwen2.5:7b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "cloud model",
|
||||
modelName: "gpt-4-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "cloud model with version",
|
||||
modelName: "claude-3-cloud",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
modelName: "",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isLocalModel(tt.modelName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalModel(%q) = %v, expected %v", tt.modelName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty host (default)",
|
||||
host: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "http://localhost:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "http://127.0.0.1:11434",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "custom port on localhost",
|
||||
host: "http://localhost:8080",
|
||||
expected: true, // localhost is always considered local
|
||||
},
|
||||
{
|
||||
name: "remote host",
|
||||
host: "http://ollama.example.com:11434",
|
||||
expected: true, // has :11434
|
||||
},
|
||||
{
|
||||
name: "remote host different port",
|
||||
host: "http://ollama.example.com:8080",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := isLocalServer()
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalServer() with OLLAMA_HOST=%q = %v, expected %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateToolOutput(t *testing.T) {
|
||||
// Create outputs of different sizes
|
||||
localLimitOutput := make([]byte, 20000) // > 4k tokens (16k chars)
|
||||
defaultLimitOutput := make([]byte, 50000) // > 10k tokens (40k chars)
|
||||
for i := range localLimitOutput {
|
||||
localLimitOutput[i] = 'a'
|
||||
}
|
||||
for i := range defaultLimitOutput {
|
||||
defaultLimitOutput[i] = 'b'
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
modelName string
|
||||
host string
|
||||
shouldTrim bool
|
||||
expectedLimit int
|
||||
}{
|
||||
{
|
||||
name: "short output local model",
|
||||
output: "hello world",
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output local model - trimmed at 4k",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: localModelTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output cloud model - uses 10k limit",
|
||||
output: string(localLimitOutput), // 20k chars, under 10k token limit
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "very long output cloud model - trimmed at 10k",
|
||||
output: string(defaultLimitOutput),
|
||||
modelName: "gpt-4-cloud",
|
||||
host: "",
|
||||
shouldTrim: true,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
{
|
||||
name: "long output remote server - uses 10k limit",
|
||||
output: string(localLimitOutput),
|
||||
modelName: "llama3.2",
|
||||
host: "http://remote.example.com:8080",
|
||||
shouldTrim: false,
|
||||
expectedLimit: defaultTokenLimit,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.host)
|
||||
result := truncateToolOutput(tt.output, tt.modelName)
|
||||
|
||||
if tt.shouldTrim {
|
||||
maxLen := tt.expectedLimit * charsPerToken
|
||||
if len(result) > maxLen+50 { // +50 for the truncation message
|
||||
t.Errorf("expected output to be truncated to ~%d chars, got %d", maxLen, len(result))
|
||||
}
|
||||
if result == tt.output {
|
||||
t.Error("expected output to be truncated but it wasn't")
|
||||
}
|
||||
} else {
|
||||
if result != tt.output {
|
||||
t.Error("expected output to not be truncated")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -89,16 +88,9 @@ func (r *Registry) Count() int {
|
||||
}
|
||||
|
||||
// DefaultRegistry creates a registry with all built-in tools.
|
||||
// Tools can be disabled via environment variables:
|
||||
// - OLLAMA_AGENT_DISABLE_WEBSEARCH=1 disables web_search
|
||||
// - OLLAMA_AGENT_DISABLE_BASH=1 disables bash
|
||||
func DefaultRegistry() *Registry {
|
||||
r := NewRegistry()
|
||||
if os.Getenv("OLLAMA_AGENT_DISABLE_WEBSEARCH") == "" {
|
||||
r.Register(&WebSearchTool{})
|
||||
}
|
||||
if os.Getenv("OLLAMA_AGENT_DISABLE_BASH") == "" {
|
||||
r.Register(&BashTool{})
|
||||
}
|
||||
r.Register(&WebSearchTool{})
|
||||
r.Register(&BashTool{})
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -108,57 +108,6 @@ func TestDefaultRegistry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry_DisableWebsearch(t *testing.T) {
|
||||
t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1")
|
||||
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool with websearch disabled, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Error("expected bash tool in registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("web_search")
|
||||
if ok {
|
||||
t.Error("expected web_search to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry_DisableBash(t *testing.T) {
|
||||
t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1")
|
||||
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 1 {
|
||||
t.Errorf("expected 1 tool with bash disabled, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("web_search")
|
||||
if !ok {
|
||||
t.Error("expected web_search tool in registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("bash")
|
||||
if ok {
|
||||
t.Error("expected bash to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry_DisableBoth(t *testing.T) {
|
||||
t.Setenv("OLLAMA_AGENT_DISABLE_WEBSEARCH", "1")
|
||||
t.Setenv("OLLAMA_AGENT_DISABLE_BASH", "1")
|
||||
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 0 {
|
||||
t.Errorf("expected 0 tools with both disabled, got %d", r.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_Schema(t *testing.T) {
|
||||
tool := &BashTool{}
|
||||
|
||||
|
||||
@@ -2,19 +2,15 @@ package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,9 +18,6 @@ const (
|
||||
webSearchTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// ErrWebSearchAuthRequired is returned when web search requires authentication
|
||||
var ErrWebSearchAuthRequired = errors.New("web search requires authentication")
|
||||
|
||||
// WebSearchTool implements web search using Ollama's hosted API.
|
||||
type WebSearchTool struct{}
|
||||
|
||||
@@ -75,13 +68,17 @@ type webSearchResult struct {
|
||||
}
|
||||
|
||||
// Execute performs the web search.
|
||||
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
|
||||
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return "", fmt.Errorf("query parameter is required")
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
|
||||
}
|
||||
|
||||
// Prepare request
|
||||
reqBody := webSearchRequest{
|
||||
Query: query,
|
||||
@@ -93,34 +90,13 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||
return "", fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
// Parse URL and add timestamp for signing
|
||||
searchURL, err := url.Parse(webSearchAPI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing search URL: %w", err)
|
||||
}
|
||||
|
||||
q := searchURL.Query()
|
||||
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
searchURL.RawQuery = q.Encode()
|
||||
|
||||
// Sign the request using Ollama key (~/.ollama/id_ed25519)
|
||||
// This authenticates with ollama.com using the local signing key
|
||||
ctx := context.Background()
|
||||
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, searchURL.RequestURI())
|
||||
signature, err := auth.Sign(ctx, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("signing request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, searchURL.String(), bytes.NewBuffer(jsonBody))
|
||||
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if signature != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: webSearchTimeout}
|
||||
@@ -135,9 +111,6 @@ func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||
return "", fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return "", ErrWebSearchAuthRequired
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWebSearchTool_Name(t *testing.T) {
|
||||
tool := &WebSearchTool{}
|
||||
if tool.Name() != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTool_Description(t *testing.T) {
|
||||
tool := &WebSearchTool{}
|
||||
if tool.Description() == "" {
|
||||
t.Error("expected non-empty description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTool_Execute_MissingQuery(t *testing.T) {
|
||||
tool := &WebSearchTool{}
|
||||
|
||||
// Test with no query
|
||||
_, err := tool.Execute(map[string]any{})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing query")
|
||||
}
|
||||
|
||||
// Test with empty query
|
||||
_, err = tool.Execute(map[string]any{"query": ""})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty query")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrWebSearchAuthRequired(t *testing.T) {
|
||||
// Test that the error type exists and can be checked with errors.Is
|
||||
err := ErrWebSearchAuthRequired
|
||||
if err == nil {
|
||||
t.Fatal("ErrWebSearchAuthRequired should not be nil")
|
||||
}
|
||||
|
||||
if err.Error() != "web search requires authentication" {
|
||||
t.Errorf("unexpected error message: %s", err.Error())
|
||||
}
|
||||
|
||||
// Test that errors.Is works
|
||||
wrappedErr := errors.New("wrapped: " + err.Error())
|
||||
if errors.Is(wrappedErr, ErrWebSearchAuthRequired) {
|
||||
t.Error("wrapped error should not match with errors.Is")
|
||||
}
|
||||
|
||||
if !errors.Is(ErrWebSearchAuthRequired, ErrWebSearchAuthRequired) {
|
||||
t.Error("ErrWebSearchAuthRequired should match itself with errors.Is")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user