Mirror of https://github.com/ollama/ollama.git (synced 2026-01-17 03:49:12 -05:00)

Compare commits: usage...improve-cl (14 commits)
Commits:

6b7456ca1f
44179b7e53
359be5b658
820e51e144
8470c25fa9
c8b599bd44
59928c536b
0b4850812f
9383082070
85e48af46a
aa9a1477b3
aed714a676
064c6a984e
3aaa8d5564
@@ -377,15 +377,6 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
 	return &lr, nil
 }
 
-// Usage returns usage statistics and system info.
-func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
-	var ur UsageResponse
-	if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
-		return nil, err
-	}
-	return &ur, nil
-}
-
 // Copy copies a model - creating a model with another name from an existing
 // model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
api/types.go (27 lines changed)

@@ -792,33 +792,6 @@ type ProcessResponse struct {
 	Models []ProcessModelResponse `json:"models"`
 }
 
-// UsageResponse is the response from [Client.Usage].
-type UsageResponse struct {
-	GPUs []GPUUsage `json:"gpus,omitempty"`
-}
-
-// GPUUsage contains GPU/device memory usage breakdown.
-type GPUUsage struct {
-	Name    string `json:"name"`    // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
-	Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
-	Total   uint64 `json:"total"`
-	Free    uint64 `json:"free"`
-	Used    uint64 `json:"used"`  // Memory used by Ollama
-	Other   uint64 `json:"other"` // Memory used by other processes
-}
-
-// UsageStats contains usage statistics.
-type UsageStats struct {
-	Requests         int64            `json:"requests"`
-	TokensInput      int64            `json:"tokens_input"`
-	TokensOutput     int64            `json:"tokens_output"`
-	TotalTokens      int64            `json:"total_tokens"`
-	Models           map[string]int64 `json:"models,omitempty"`
-	Sources          map[string]int64 `json:"sources,omitempty"`
-	ToolCalls        int64            `json:"tool_calls,omitempty"`
-	StructuredOutput int64            `json:"structured_output,omitempty"`
-}
-
 // ListModelResponse is a single model description in [ListResponse].
 type ListModelResponse struct {
 	Name string `json:"name"`
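Before this change, the endpoint was also reachable through the Go client. A minimal sketch of a caller, valid only against a build that still includes the usage branch (illustrative, not part of the diff):

	package main

	import (
		"context"
		"fmt"
		"log"

		"github.com/ollama/ollama/api"
	)

	func main() {
		client, err := api.ClientFromEnvironment()
		if err != nil {
			log.Fatal(err)
		}
		resp, err := client.Usage(context.Background()) // GET /api/usage, removed by this compare
		if err != nil {
			log.Fatal(err)
		}
		for _, gpu := range resp.GPUs {
			fmt.Printf("%s [%s]: used=%d other=%d free=%d total=%d bytes\n",
				gpu.Name, gpu.Backend, gpu.Used, gpu.Other, gpu.Free, gpu.Total)
		}
	}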
@@ -1833,7 +1833,6 @@ func NewCLI() *cobra.Command {
 		PreRunE: checkServerHeartbeat,
 		RunE:    ListRunningHandler,
 	}
 
 	copyCmd := &cobra.Command{
 		Use:   "cp SOURCE DESTINATION",
 		Short: "Copy a model",
@@ -206,8 +206,6 @@ var (
 	UseAuth = Bool("OLLAMA_AUTH")
 	// Enable Vulkan backend
 	EnableVulkan = Bool("OLLAMA_VULKAN")
-	// Usage enables usage statistics reporting
-	Usage = Bool("OLLAMA_USAGE")
 )
 
 func String(s string) func() string {
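With that pair removed, OLLAMA_USAGE no longer does anything. Previously the feature was opt-in: starting the server with the variable set (e.g. OLLAMA_USAGE=1 ollama serve, given how envconfig's Bool helpers work) is what made Serve construct and start the reporter, as the server/routes.go hunks below show.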
server/routes.go (128 lines changed)

@@ -20,7 +20,6 @@ import (
 	"net/url"
 	"os"
 	"os/signal"
 	"runtime"
 	"slices"
 	"strings"
 	"sync/atomic"

@@ -45,7 +44,6 @@ import (
 	"github.com/ollama/ollama/model/renderers"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
-	"github.com/ollama/ollama/server/usage"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/thinking"
 	"github.com/ollama/ollama/tools"
@@ -84,7 +82,6 @@ type Server struct {
 	addr    net.Addr
 	sched   *Scheduler
 	lowVRAM bool
-	stats   *usage.Stats
 }
 
 func init() {
@@ -107,30 +104,6 @@ var (
 	errBadTemplate = errors.New("template error")
 )
 
-// usage records a request to usage stats if enabled.
-func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
-	if s.stats == nil {
-		return
-	}
-	s.stats.Record(&usage.Request{
-		Endpoint:         endpoint,
-		Model:            model,
-		Architecture:     architecture,
-		APIType:          usage.ClassifyAPIType(c.Request.URL.Path),
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		UsedTools:        usedTools,
-	})
-}
-
-// usageError records a failed request to usage stats if enabled.
-func (s *Server) usageError() {
-	if s.stats == nil {
-		return
-	}
-	s.stats.RecordError()
-}
-
 func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -401,7 +374,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
 	} else if err != nil {
-		s.handleScheduleError(c, req.Model, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
@@ -588,7 +561,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		res.DoneReason = cr.DoneReason.String()
 		res.TotalDuration = time.Since(checkpointStart)
 		res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-		s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
 
 		if !req.Raw {
 			tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
@@ -708,7 +680,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		s.handleScheduleError(c, req.Model, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
@@ -818,7 +790,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
 		PromptEvalCount: int(totalTokens),
 	}
-	s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
 	c.JSON(http.StatusOK, resp)
 }
@@ -856,7 +827,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 
 	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		s.handleScheduleError(c, req.Model, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
@@ -1560,7 +1531,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 
 	// Inference
 	r.GET("/api/ps", s.PsHandler)
-	r.GET("/api/usage", s.UsageHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
@@ -1623,13 +1593,6 @@ func Serve(ln net.Listener) error {
 
 	s := &Server{addr: ln.Addr()}
 
-	// Initialize usage stats if enabled
-	if envconfig.Usage() {
-		s.stats = usage.New()
-		s.stats.Start()
-		slog.Info("usage stats enabled")
-	}
-
 	var rc *ollama.Registry
 	if useClient2 {
 		var err error
@@ -1669,9 +1632,6 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if s.stats != nil {
-			s.stats.Stop()
-		}
 		srvr.Close()
 		schedDone()
 		sched.unloadAllRunners()
@@ -1689,24 +1649,6 @@ func Serve(ln net.Listener) error {
 	gpus := discover.GPUDevices(ctx, nil)
 	discover.LogDetails(gpus)
 
-	// Set GPU info for usage reporting
-	if s.stats != nil {
-		usage.GPUInfoFunc = func() []usage.GPU {
-			var result []usage.GPU
-			for _, gpu := range gpus {
-				result = append(result, usage.GPU{
-					Name:         gpu.Name,
-					VRAMBytes:    gpu.TotalMemory,
-					ComputeMajor: gpu.ComputeMajor,
-					ComputeMinor: gpu.ComputeMinor,
-					DriverMajor:  gpu.DriverMajor,
-					DriverMinor:  gpu.DriverMinor,
-				})
-			}
-			return result
-		}
-	}
-
 	var totalVRAM uint64
 	for _, gpu := range gpus {
 		totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
@@ -1910,63 +1852,6 @@ func (s *Server) PsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-func (s *Server) UsageHandler(c *gin.Context) {
-	// Get total VRAM used by Ollama
-	s.sched.loadedMu.Lock()
-	var totalOllamaVRAM uint64
-	for _, runner := range s.sched.loaded {
-		totalOllamaVRAM += runner.vramSize
-	}
-	s.sched.loadedMu.Unlock()
-
-	var resp api.UsageResponse
-
-	// Get GPU/device info
-	gpus := discover.GPUDevices(c.Request.Context(), nil)
-
-	// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
-	// because unified memory means GPU and CPU share the same physical RAM pool
-	var sysTotal, sysFree uint64
-	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
-		sysInfo := discover.GetSystemInfo()
-		sysTotal = sysInfo.TotalMemory
-		sysFree = sysInfo.FreeMemory
-	}
-
-	for _, gpu := range gpus {
-		total := gpu.TotalMemory
-		free := gpu.FreeMemory
-
-		// On Apple Silicon, override with system memory values
-		if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
-			total = sysTotal
-			free = sysFree
-		}
-
-		used := total - free
-		ollamaUsed := min(totalOllamaVRAM, used)
-		otherUsed := used - ollamaUsed
-
-		// Use Description for Name (actual device name like "Apple M2 Max")
-		// Fall back to backend name if Description is empty
-		name := gpu.Description
-		if name == "" {
-			name = gpu.Name
-		}
-
-		resp.GPUs = append(resp.GPUs, api.GPUUsage{
-			Name:    name,
-			Backend: gpu.Library,
-			Total:   total,
-			Free:    free,
-			Used:    ollamaUsed,
-			Other:   otherUsed,
-		})
-	}
-
-	c.JSON(http.StatusOK, resp)
-}
-
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
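A quick worked example of the clamp in the removed handler: if a device reports total=24 GiB and free=4 GiB, then used=20 GiB; if the scheduler meanwhile believes its loaded runners hold 22 GiB, then ollamaUsed=min(22, 20)=20 and otherUsed=0. The min keeps the unsigned subtraction used-ollamaUsed from underflowing whenever the scheduler's estimate runs ahead of measured memory.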
@@ -2147,7 +2032,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
 	} else if err != nil {
-		s.handleScheduleError(c, req.Model, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
@@ -2295,7 +2180,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
 		}
 
 		if builtinParser != nil {
@@ -2471,7 +2355,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			resp.Message.ToolCalls = toolCalls
 		}
 
-		s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
 		c.JSON(http.StatusOK, resp)
 		return
 	}
@@ -2479,8 +2362,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
-	s.usageError()
+func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
Deleted file (package server; tests for the UsageHandler endpoint):

@@ -1,60 +0,0 @@
package server

import (
	"encoding/json"
	"net/http"
	"testing"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
)

func TestUsageHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)

	t.Run("empty server", func(t *testing.T) {
		s := Server{
			sched: &Scheduler{
				loaded: make(map[string]*runnerRef),
			},
		}

		w := createRequest(t, s.UsageHandler, nil)
		if w.Code != http.StatusOK {
			t.Fatalf("expected status code 200, actual %d", w.Code)
		}

		var resp api.UsageResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		// GPUs may or may not be present depending on system
		// Just verify we can decode the response
	})

	t.Run("response structure", func(t *testing.T) {
		s := Server{
			sched: &Scheduler{
				loaded: make(map[string]*runnerRef),
			},
		}

		w := createRequest(t, s.UsageHandler, nil)
		if w.Code != http.StatusOK {
			t.Fatalf("expected status code 200, actual %d", w.Code)
		}

		// Verify we can decode the response as valid JSON
		var resp map[string]any
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		// The response should be a valid object (not null)
		if resp == nil {
			t.Error("expected non-nil response")
		}
	})
}
Deleted file (package usage; heartbeat reporting):

@@ -1,65 +0,0 @@
package usage

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/ollama/ollama/version"
)

const (
	reportTimeout = 10 * time.Second
	usageURL      = "https://ollama.com/api/usage"
)

// HeartbeatResponse is the response from the heartbeat endpoint.
type HeartbeatResponse struct {
	UpdateVersion string `json:"update_version,omitempty"`
}

// UpdateAvailable returns the available update version, if any.
func (t *Stats) UpdateAvailable() string {
	if v := t.updateAvailable.Load(); v != nil {
		return v.(string)
	}
	return ""
}

// sendHeartbeat sends usage stats and checks for updates.
func (t *Stats) sendHeartbeat(payload *Payload) {
	data, err := json.Marshal(payload)
	if err != nil {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
	if err != nil {
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return
	}

	var heartbeat HeartbeatResponse
	if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
		return
	}

	t.updateAvailable.Store(heartbeat.UpdateVersion)
}
Deleted file (package usage; API-type classification):

@@ -1,23 +0,0 @@
package usage

import (
	"strings"
)

// API type constants
const (
	APITypeOllama    = "ollama"
	APITypeOpenAI    = "openai"
	APITypeAnthropic = "anthropic"
)

// ClassifyAPIType determines the API type from the request path.
func ClassifyAPIType(path string) string {
	if strings.HasPrefix(path, "/v1/messages") {
		return APITypeAnthropic
	}
	if strings.HasPrefix(path, "/v1/") {
		return APITypeOpenAI
	}
	return APITypeOllama
}
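Note the check order in the deleted classifier: /v1/messages must be matched before the generic /v1/ prefix, or Anthropic-style requests would be classified as OpenAI-compatible. Illustrative calls (not part of the diff):

	ClassifyAPIType("/v1/messages")         // "anthropic"
	ClassifyAPIType("/v1/chat/completions") // "openai"
	ClassifyAPIType("/api/chat")            // "ollama"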
Deleted file (package usage; in-memory stats collection):

@@ -1,324 +0,0 @@
// Package usage provides in-memory usage statistics collection and reporting.
package usage

import (
	"runtime"
	"sync"
	"sync/atomic"
	"time"

	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/version"
)

// Stats collects usage statistics in memory and reports them periodically.
type Stats struct {
	mu sync.RWMutex

	// Atomic counters for hot path
	requestsTotal    atomic.Int64
	tokensPrompt     atomic.Int64
	tokensCompletion atomic.Int64
	errorsTotal      atomic.Int64

	// Map-based counters (require lock)
	endpoints     map[string]int64
	architectures map[string]int64
	apis          map[string]int64
	models        map[string]*ModelStats // per-model stats

	// Feature usage
	toolCalls        atomic.Int64
	structuredOutput atomic.Int64

	// Update info (set by reporter after pinging update endpoint)
	updateAvailable atomic.Value // string

	// Reporter
	stopCh   chan struct{}
	doneCh   chan struct{}
	interval time.Duration
	endpoint string
}

// ModelStats tracks per-model usage statistics.
type ModelStats struct {
	Requests     int64
	TokensInput  int64
	TokensOutput int64
}

// Request contains the data to record for a single request.
type Request struct {
	Endpoint         string // "chat", "generate", "embed"
	Model            string // model name (e.g., "llama3.2:3b")
	Architecture     string // model architecture (e.g., "llama", "qwen2")
	APIType          string // "native" or "openai_compat"
	PromptTokens     int
	CompletionTokens int
	UsedTools        bool
	StructuredOutput bool
}

// SystemInfo contains hardware information to report.
type SystemInfo struct {
	OS       string `json:"os"`
	Arch     string `json:"arch"`
	CPUCores int    `json:"cpu_cores"`
	RAMBytes uint64 `json:"ram_bytes"`
	GPUs     []GPU  `json:"gpus,omitempty"`
}

// GPU contains information about a GPU.
type GPU struct {
	Name         string `json:"name"`
	VRAMBytes    uint64 `json:"vram_bytes"`
	ComputeMajor int    `json:"compute_major,omitempty"`
	ComputeMinor int    `json:"compute_minor,omitempty"`
	DriverMajor  int    `json:"driver_major,omitempty"`
	DriverMinor  int    `json:"driver_minor,omitempty"`
}

// Payload is the data sent to the heartbeat endpoint.
type Payload struct {
	Version string     `json:"version"`
	Time    time.Time  `json:"time"`
	System  SystemInfo `json:"system"`

	Totals struct {
		Requests     int64 `json:"requests"`
		Errors       int64 `json:"errors"`
		InputTokens  int64 `json:"input_tokens"`
		OutputTokens int64 `json:"output_tokens"`
	} `json:"totals"`

	Endpoints     map[string]int64 `json:"endpoints"`
	Architectures map[string]int64 `json:"architectures"`
	APIs          map[string]int64 `json:"apis"`

	Features struct {
		ToolCalls        int64 `json:"tool_calls"`
		StructuredOutput int64 `json:"structured_output"`
	} `json:"features"`
}

const (
	defaultInterval = 1 * time.Hour
)

// New creates a new Stats instance.
func New(opts ...Option) *Stats {
	t := &Stats{
		endpoints:     make(map[string]int64),
		architectures: make(map[string]int64),
		apis:          make(map[string]int64),
		models:        make(map[string]*ModelStats),
		stopCh:        make(chan struct{}),
		doneCh:        make(chan struct{}),
		interval:      defaultInterval,
	}

	for _, opt := range opts {
		opt(t)
	}

	return t
}

// Option configures the Stats instance.
type Option func(*Stats)

// WithInterval sets the reporting interval.
func WithInterval(d time.Duration) Option {
	return func(t *Stats) {
		t.interval = d
	}
}

// Record records a request. This is the hot path and should be fast.
func (t *Stats) Record(r *Request) {
	t.requestsTotal.Add(1)
	t.tokensPrompt.Add(int64(r.PromptTokens))
	t.tokensCompletion.Add(int64(r.CompletionTokens))

	if r.UsedTools {
		t.toolCalls.Add(1)
	}
	if r.StructuredOutput {
		t.structuredOutput.Add(1)
	}

	t.mu.Lock()
	t.endpoints[r.Endpoint]++
	t.architectures[r.Architecture]++
	t.apis[r.APIType]++

	// Track per-model stats
	if r.Model != "" {
		if t.models[r.Model] == nil {
			t.models[r.Model] = &ModelStats{}
		}
		t.models[r.Model].Requests++
		t.models[r.Model].TokensInput += int64(r.PromptTokens)
		t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
	}
	t.mu.Unlock()
}

// RecordError records a failed request.
func (t *Stats) RecordError() {
	t.errorsTotal.Add(1)
}

// GetModelStats returns a copy of per-model statistics.
func (t *Stats) GetModelStats() map[string]*ModelStats {
	t.mu.RLock()
	defer t.mu.RUnlock()

	result := make(map[string]*ModelStats, len(t.models))
	for k, v := range t.models {
		result[k] = &ModelStats{
			Requests:     v.Requests,
			TokensInput:  v.TokensInput,
			TokensOutput: v.TokensOutput,
		}
	}
	return result
}

// View returns current stats without resetting counters.
func (t *Stats) View() *Payload {
	t.mu.RLock()
	defer t.mu.RUnlock()

	now := time.Now()

	// Copy maps
	endpoints := make(map[string]int64, len(t.endpoints))
	for k, v := range t.endpoints {
		endpoints[k] = v
	}
	architectures := make(map[string]int64, len(t.architectures))
	for k, v := range t.architectures {
		architectures[k] = v
	}
	apis := make(map[string]int64, len(t.apis))
	for k, v := range t.apis {
		apis[k] = v
	}

	p := &Payload{
		Version:       version.Version,
		Time:          now,
		System:        getSystemInfo(),
		Endpoints:     endpoints,
		Architectures: architectures,
		APIs:          apis,
	}

	p.Totals.Requests = t.requestsTotal.Load()
	p.Totals.Errors = t.errorsTotal.Load()
	p.Totals.InputTokens = t.tokensPrompt.Load()
	p.Totals.OutputTokens = t.tokensCompletion.Load()
	p.Features.ToolCalls = t.toolCalls.Load()
	p.Features.StructuredOutput = t.structuredOutput.Load()

	return p
}

// Snapshot returns current stats and resets counters.
func (t *Stats) Snapshot() *Payload {
	t.mu.Lock()
	defer t.mu.Unlock()

	now := time.Now()
	p := &Payload{
		Version:       version.Version,
		Time:          now,
		System:        getSystemInfo(),
		Endpoints:     t.endpoints,
		Architectures: t.architectures,
		APIs:          t.apis,
	}

	p.Totals.Requests = t.requestsTotal.Swap(0)
	p.Totals.Errors = t.errorsTotal.Swap(0)
	p.Totals.InputTokens = t.tokensPrompt.Swap(0)
	p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
	p.Features.ToolCalls = t.toolCalls.Swap(0)
	p.Features.StructuredOutput = t.structuredOutput.Swap(0)

	// Reset maps
	t.endpoints = make(map[string]int64)
	t.architectures = make(map[string]int64)
	t.apis = make(map[string]int64)

	return p
}

// getSystemInfo collects hardware information.
func getSystemInfo() SystemInfo {
	info := SystemInfo{
		OS:   runtime.GOOS,
		Arch: runtime.GOARCH,
	}

	// Get CPU and memory info
	sysInfo := discover.GetSystemInfo()
	info.CPUCores = sysInfo.ThreadCount
	info.RAMBytes = sysInfo.TotalMemory

	// Get GPU info
	gpus := getGPUInfo()
	info.GPUs = gpus

	return info
}

// GPUInfoFunc is a function that returns GPU information.
// It's set by the server package after GPU discovery.
var GPUInfoFunc func() []GPU

// getGPUInfo collects GPU information.
func getGPUInfo() []GPU {
	if GPUInfoFunc != nil {
		return GPUInfoFunc()
	}
	return nil
}

// Start begins the periodic reporting goroutine.
func (t *Stats) Start() {
	go t.reportLoop()
}

// Stop stops reporting and waits for the final report.
func (t *Stats) Stop() {
	close(t.stopCh)
	<-t.doneCh
}

// reportLoop runs the periodic reporting.
func (t *Stats) reportLoop() {
	defer close(t.doneCh)

	ticker := time.NewTicker(t.interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			t.report()
		case <-t.stopCh:
			// Send final report before stopping
			t.report()
			return
		}
	}
}

// report sends usage stats and checks for updates.
func (t *Stats) report() {
	payload := t.Snapshot()
	t.sendHeartbeat(payload)
}
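Two design points worth noting in the deleted package: the hot path in Record touches four atomics plus one short mutex section, while View copies the maps under a read lock and Snapshot avoids copying altogether by handing off the live maps and allocating fresh ones. A minimal sketch of the intended wiring, using only the API above (interval and values are illustrative):

	stats := usage.New(usage.WithInterval(30 * time.Minute))
	stats.Start() // reportLoop: Snapshot + sendHeartbeat once per interval

	// per served request:
	stats.Record(&usage.Request{
		Endpoint:         "chat",
		Model:            "llama3:8b",
		Architecture:     "llama",
		APIType:          usage.ClassifyAPIType("/api/chat"), // "ollama"
		PromptTokens:     128,
		CompletionTokens: 64,
	})

	stats.Stop() // flush a final report, then wait for the goroutine to exit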
Deleted file (package usage; tests):

@@ -1,194 +0,0 @@
package usage

import (
	"testing"
)

func TestNew(t *testing.T) {
	stats := New()
	if stats == nil {
		t.Fatal("New() returned nil")
	}
}

func TestRecord(t *testing.T) {
	stats := New()

	stats.Record(&Request{
		Model:            "llama3:8b",
		Endpoint:         "chat",
		Architecture:     "llama",
		APIType:          "native",
		PromptTokens:     100,
		CompletionTokens: 50,
		UsedTools:        true,
		StructuredOutput: false,
	})

	// Check totals
	payload := stats.View()
	if payload.Totals.Requests != 1 {
		t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
	}
	if payload.Totals.InputTokens != 100 {
		t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
	}
	if payload.Totals.OutputTokens != 50 {
		t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
	}
	if payload.Features.ToolCalls != 1 {
		t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
	}
	if payload.Features.StructuredOutput != 0 {
		t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
	}
}

func TestGetModelStats(t *testing.T) {
	stats := New()

	// Record requests for multiple models
	stats.Record(&Request{
		Model:            "llama3:8b",
		PromptTokens:     100,
		CompletionTokens: 50,
	})
	stats.Record(&Request{
		Model:            "llama3:8b",
		PromptTokens:     200,
		CompletionTokens: 100,
	})
	stats.Record(&Request{
		Model:            "mistral:7b",
		PromptTokens:     50,
		CompletionTokens: 25,
	})

	modelStats := stats.GetModelStats()

	// Check llama3:8b stats
	llama := modelStats["llama3:8b"]
	if llama == nil {
		t.Fatal("expected llama3:8b stats")
	}
	if llama.Requests != 2 {
		t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
	}
	if llama.TokensInput != 300 {
		t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
	}
	if llama.TokensOutput != 150 {
		t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
	}

	// Check mistral:7b stats
	mistral := modelStats["mistral:7b"]
	if mistral == nil {
		t.Fatal("expected mistral:7b stats")
	}
	if mistral.Requests != 1 {
		t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
	}
	if mistral.TokensInput != 50 {
		t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
	}
	if mistral.TokensOutput != 25 {
		t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
	}
}

func TestRecordError(t *testing.T) {
	stats := New()

	stats.RecordError()
	stats.RecordError()

	payload := stats.View()
	if payload.Totals.Errors != 2 {
		t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
	}
}

func TestView(t *testing.T) {
	stats := New()

	stats.Record(&Request{
		Model:        "llama3:8b",
		Endpoint:     "chat",
		Architecture: "llama",
		APIType:      "native",
	})

	// First view
	_ = stats.View()

	// View should not reset counters
	payload := stats.View()
	if payload.Totals.Requests != 1 {
		t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
	}
}

func TestSnapshot(t *testing.T) {
	stats := New()

	stats.Record(&Request{
		Model:            "llama3:8b",
		Endpoint:         "chat",
		PromptTokens:     100,
		CompletionTokens: 50,
	})

	// Snapshot should return data and reset counters
	snapshot := stats.Snapshot()
	if snapshot.Totals.Requests != 1 {
		t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
	}

	// After snapshot, counters should be reset
	payload2 := stats.View()
	if payload2.Totals.Requests != 0 {
		t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
	}
}

func TestConcurrentAccess(t *testing.T) {
	stats := New()

	done := make(chan bool)

	// Concurrent writes
	for i := 0; i < 10; i++ {
		go func() {
			for j := 0; j < 100; j++ {
				stats.Record(&Request{
					Model:            "llama3:8b",
					PromptTokens:     10,
					CompletionTokens: 5,
				})
			}
			done <- true
		}()
	}

	// Concurrent reads
	for i := 0; i < 5; i++ {
		go func() {
			for j := 0; j < 100; j++ {
				_ = stats.View()
				_ = stats.GetModelStats()
			}
			done <- true
		}()
	}

	// Wait for all goroutines
	for i := 0; i < 15; i++ {
		<-done
	}

	payload := stats.View()
	if payload.Totals.Requests != 1000 {
		t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
	}
}
@@ -381,28 +381,6 @@ func (t templateTools) String() string {
 	return string(bts)
 }
 
-// templateArgs is a map type with JSON string output for templates.
-type templateArgs map[string]any
-
-func (t templateArgs) String() string {
-	if t == nil {
-		return "{}"
-	}
-	bts, _ := json.Marshal(t)
-	return string(bts)
-}
-
-// templateProperties is a map type with JSON string output for templates.
-type templateProperties map[string]api.ToolProperty
-
-func (t templateProperties) String() string {
-	if t == nil {
-		return "{}"
-	}
-	bts, _ := json.Marshal(t)
-	return string(bts)
-}
-
 // templateTool is a template-compatible representation of api.Tool
 // with Properties as a regular map for template ranging.
 type templateTool struct {
@@ -418,11 +396,11 @@ type templateToolFunction struct {
 }
 
 type templateToolFunctionParameters struct {
-	Type       string             `json:"type"`
-	Defs       any                `json:"$defs,omitempty"`
-	Items      any                `json:"items,omitempty"`
-	Required   []string           `json:"required,omitempty"`
-	Properties templateProperties `json:"properties"`
+	Type       string                      `json:"type"`
+	Defs       any                         `json:"$defs,omitempty"`
+	Items      any                         `json:"items,omitempty"`
+	Required   []string                    `json:"required,omitempty"`
+	Properties map[string]api.ToolProperty `json:"properties"`
 }
 
 // templateToolCall is a template-compatible representation of api.ToolCall
@@ -435,7 +413,7 @@ type templateToolCall struct {
 type templateToolCallFunction struct {
 	Index     int
 	Name      string
-	Arguments templateArgs
+	Arguments map[string]any
 }
 
 // templateMessage is a template-compatible representation of api.Message
@@ -468,7 +446,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
 					Defs:       tool.Function.Parameters.Defs,
 					Items:      tool.Function.Parameters.Items,
 					Required:   tool.Function.Parameters.Required,
-					Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
+					Properties: tool.Function.Parameters.Properties.ToMap(),
 				},
 			},
 		}
@@ -490,7 +468,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
 				Function: templateToolCallFunction{
 					Index:     tc.Function.Index,
 					Name:      tc.Function.Name,
-					Arguments: templateArgs(tc.Function.Arguments.ToMap()),
+					Arguments: tc.Function.Arguments.ToMap(),
 				},
 			})
 		}
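These reverted conversions are what the deleted tests below were guarding: templateArgs and templateProperties carry String methods that marshal to JSON, so {{ .Function.Arguments }} printed {"location":"Tokyo"}, whereas a plain map[string]any rendered by text/template prints Go's map[location:Tokyo] form. With the wrapper types gone, the JSON-output assertions can no longer pass, which is presumably why the same compare removes them next.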
@@ -613,159 +613,3 @@ func TestCollate(t *testing.T) {
 		})
 	}
 }

The remainder of this hunk deletes the following four test functions:

func TestTemplateArgumentsJSON(t *testing.T) {
	// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
	tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`

	template, err := Parse(tmpl)
	if err != nil {
		t.Fatal(err)
	}

	args := api.NewToolCallFunctionArguments()
	args.Set("location", "Tokyo")
	args.Set("unit", "celsius")

	var buf bytes.Buffer
	err = template.Execute(&buf, Values{
		Messages: []api.Message{{
			Role: "assistant",
			ToolCalls: []api.ToolCall{{
				Function: api.ToolCallFunction{
					Name:      "get_weather",
					Arguments: args,
				},
			}},
		}},
	})
	if err != nil {
		t.Fatal(err)
	}

	got := buf.String()
	// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
	if strings.HasPrefix(got, "map[") {
		t.Errorf("Arguments output as Go map format: %s", got)
	}

	var parsed map[string]any
	if err := json.Unmarshal([]byte(got), &parsed); err != nil {
		t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
	}
}

func TestTemplatePropertiesJSON(t *testing.T) {
	// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
	// Note: template must reference .Messages to trigger the modern code path that converts Tools
	tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`

	template, err := Parse(tmpl)
	if err != nil {
		t.Fatal(err)
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})

	var buf bytes.Buffer
	err = template.Execute(&buf, Values{
		Messages: []api.Message{{Role: "user", Content: "test"}},
		Tools: api.Tools{{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get weather",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
				},
			},
		}},
	})
	if err != nil {
		t.Fatal(err)
	}

	got := buf.String()
	// Should be valid JSON, not "map[location:{...}]"
	if strings.HasPrefix(got, "map[") {
		t.Errorf("Properties output as Go map format: %s", got)
	}

	var parsed map[string]any
	if err := json.Unmarshal([]byte(got), &parsed); err != nil {
		t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
	}
}

func TestTemplateArgumentsRange(t *testing.T) {
	// Test that we can range over Arguments in templates
	tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`

	template, err := Parse(tmpl)
	if err != nil {
		t.Fatal(err)
	}

	args := api.NewToolCallFunctionArguments()
	args.Set("city", "Tokyo")

	var buf bytes.Buffer
	err = template.Execute(&buf, Values{
		Messages: []api.Message{{
			Role: "assistant",
			ToolCalls: []api.ToolCall{{
				Function: api.ToolCallFunction{
					Name:      "get_weather",
					Arguments: args,
				},
			}},
		}},
	})
	if err != nil {
		t.Fatal(err)
	}

	got := buf.String()
	if got != "city=Tokyo;" {
		t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
	}
}

func TestTemplatePropertiesRange(t *testing.T) {
	// Test that we can range over Properties in templates
	// Note: template must reference .Messages to trigger the modern code path that converts Tools
	tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`

	template, err := Parse(tmpl)
	if err != nil {
		t.Fatal(err)
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})

	var buf bytes.Buffer
	err = template.Execute(&buf, Values{
		Messages: []api.Message{{Role: "user", Content: "test"}},
		Tools: api.Tools{{
			Type: "function",
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
				},
			},
		}},
	})
	if err != nil {
		t.Fatal(err)
	}

	got := buf.String()
	if got != "location:string;" {
		t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
	}
}
@@ -1104,3 +1104,108 @@ func PromptYesNo(question string) (bool, error) {
 		}
 	}
 }

The remainder of this hunk is new code:

// CloudModelOption represents a suggested cloud model for the selection prompt.
type CloudModelOption struct {
	Name        string
	Description string
}

// PromptModelChoice displays a model selection prompt with multiple options.
// Returns the selected model name, or empty string if user declined or cancelled.
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
	fd := int(os.Stdin.Fd())
	oldState, err := term.MakeRaw(fd)
	if err != nil {
		return "", err
	}
	defer term.Restore(fd, oldState)

	// Build options: models + "No thanks, continue"
	optionCount := len(models) + 1
	selected := 0

	// Total lines: question + models + "no thanks" + hint = optionCount + 2
	totalLines := optionCount + 2

	// Hide cursor
	fmt.Fprint(os.Stderr, "\033[?25l")
	defer fmt.Fprint(os.Stderr, "\033[?25h")

	firstRender := true

	render := func() {
		if !firstRender {
			fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
		}
		firstRender = false

		// \r\n needed in raw mode for proper line breaks
		fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)

		for i, model := range models {
			fmt.Fprintf(os.Stderr, "\033[K")
			if i == selected {
				fmt.Fprintf(os.Stderr, "  \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
			} else {
				fmt.Fprintf(os.Stderr, "    \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
			}
		}

		fmt.Fprintf(os.Stderr, "\033[K")
		if selected == len(models) {
			fmt.Fprintf(os.Stderr, "  \033[1;32m> No thanks, continue\033[0m\r\n")
		} else {
			fmt.Fprintf(os.Stderr, "    \033[90mNo thanks, continue\033[0m\r\n")
		}

		fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
	}

	render()

	buf := make([]byte, 3)
	for {
		n, err := os.Stdin.Read(buf)
		if err != nil {
			return "", err
		}

		if n == 1 {
			switch buf[0] {
			case 'j', 'J':
				if selected < optionCount-1 {
					selected++
				}
				render()
			case 'k', 'K':
				if selected > 0 {
					selected--
				}
				render()
			case '\r', '\n':
				fmt.Fprintf(os.Stderr, "\n")
				if selected < len(models) {
					return models[selected].Name, nil
				}
				return "", nil
			case 3: // Ctrl+C
				fmt.Fprintf(os.Stderr, "\n")
				return "", nil
			}
		} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
			switch buf[2] {
			case 'A': // Up
				if selected > 0 {
					selected--
				}
				render()
			case 'B': // Down
				if selected < optionCount-1 {
					selected++
				}
				render()
			}
		}
	}
}
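For readers decoding the escape sequences in render: these are standard VT100/ANSI codes, not anything specific to this patch. \033[?25l and \033[?25h hide and show the cursor, \033[K erases from the cursor to end of line, \033[%dA moves the cursor up N lines (which is how render redraws the menu in place), \033[36m is cyan, \033[1;32m bold green for the selected row, \033[90m gray for dimmed rows, and \033[0m resets attributes. The explicit \r\n is required because term.MakeRaw disables the output post-processing that would normally translate \n into \r\n.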
x/agent/prompt_test.go (new file, 25 lines)

@@ -0,0 +1,25 @@
package agent

import (
	"testing"
)

func TestCloudModelOptionStruct(t *testing.T) {
	// Test that the struct is defined correctly
	models := []CloudModelOption{
		{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
		{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
	}

	if len(models) != 2 {
		t.Errorf("expected 2 models, got %d", len(models))
	}

	if models[0].Name != "glm-4.7:cloud" {
		t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
	}

	if models[1].Description != "Qwen3 Coder 480B" {
		t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
	}
}
x/cmd/cloudmodel_test.go (new file, 41 lines)

@@ -0,0 +1,41 @@
package cmd

import (
	"errors"
	"testing"
)

func TestCloudModelSwitchRequest(t *testing.T) {
	// Test the error type
	req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}

	// Test Error() method
	errMsg := req.Error()
	expected := "switch to model: glm-4.7:cloud"
	if errMsg != expected {
		t.Errorf("expected %q, got %q", expected, errMsg)
	}

	// Test errors.As
	var err error = req
	var switchReq *CloudModelSwitchRequest
	if !errors.As(err, &switchReq) {
		t.Error("errors.As should return true for CloudModelSwitchRequest")
	}

	if switchReq.Model != "glm-4.7:cloud" {
		t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
	}
}

func TestSuggestedCloudModels(t *testing.T) {
	// Verify the suggested models are defined
	if len(suggestedCloudModels) == 0 {
		t.Error("suggestedCloudModels should not be empty")
	}

	// Check first model
	if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
		t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
	}
}
x/cmd/run.go (218 lines changed)

@@ -37,6 +37,22 @@ const (
 	charsPerToken = 4
 )
 
+// suggestedCloudModels are the models suggested to users after signing in.
+// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
+var suggestedCloudModels = []agent.CloudModelOption{
+	{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
+	{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
+}
+
+// CloudModelSwitchRequest signals that the user wants to switch to a different model.
+type CloudModelSwitchRequest struct {
+	Model string
+}
+
+func (c *CloudModelSwitchRequest) Error() string {
+	return fmt.Sprintf("switch to model: %s", c.Model)
+}
+
 // isLocalModel checks if the model is running locally (not a cloud model).
 // TODO: Improve local/cloud model identification - could check model metadata
 func isLocalModel(modelName string) bool {
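Using an error type as the switch signal lets the request bubble up through Chat's ordinary error return instead of threading a new return value through every call site; callers pick it out with errors.As, as GenerateInteractive does further down. The caller-side pattern, in sketch form:

	assistant, err := Chat(ctx, opts)
	var switchReq *CloudModelSwitchRequest
	if errors.As(err, &switchReq) {
		// not a real failure: the user picked a new model, restart the turn with it
	}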
@@ -119,6 +135,21 @@ func waitForOllamaSignin(ctx context.Context) error {
 	return nil
 }
 
+// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
+// Returns the selected model name, or empty string if user declines.
+func promptCloudModelSuggestion() string {
+	fmt.Fprintf(os.Stderr, "\n")
+	fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
+	fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
+	fmt.Fprintf(os.Stderr, "\n")
+
+	selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
+	if err != nil || selectedModel == "" {
+		return ""
+	}
+	return selectedModel
+}
+
 // RunOptions contains options for running an interactive agent session.
 type RunOptions struct {
 	Model string
@@ -144,6 +175,40 @@ type RunOptions struct {
 
 	// LastToolOutputTruncated stores the truncated version shown inline
 	LastToolOutputTruncated *string
+
+	// ActiveModel points to the current model name - can be updated mid-turn
+	// for model switching. If nil, opts.Model is used.
+	ActiveModel *string
 }
 
+// getActiveModel returns the current model name, checking ActiveModel pointer first.
+func getActiveModel(opts *RunOptions) string {
+	if opts.ActiveModel != nil && *opts.ActiveModel != "" {
+		return *opts.ActiveModel
+	}
+	return opts.Model
+}
+
+// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
+func showModelConnection(ctx context.Context, modelName string) error {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
+	if err != nil {
+		return err
+	}
+
+	if info.RemoteHost != "" {
+		if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
+			fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
+		} else {
+			fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
+		}
+	}
+	return nil
+}
+
 // Chat runs an agent chat loop with tool support.
@@ -243,7 +308,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
 	// Agentic loop: continue until no more tool calls
 	for {
 		req := &api.ChatRequest{
-			Model:    opts.Model,
+			Model:    getActiveModel(&opts),
 			Messages: messages,
 			Format:   json.RawMessage(opts.Format),
 			Options:  opts.Options,
@@ -267,7 +332,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
 			return nil, nil
 		}
 
-		// Check for 401 Unauthorized - prompt user to sign in
 		var authErr api.AuthorizationError
 		if errors.As(err, &authErr) {
 			p.StopAndClear()
@@ -275,9 +339,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
 			result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
 			if promptErr == nil && result {
 				if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
-					// Retry the chat request
+					suggestedModel := promptCloudModelSuggestion()
+					if suggestedModel != "" {
+						return nil, &CloudModelSwitchRequest{Model: suggestedModel}
+					}
+
 					fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
-					continue // Retry the loop
+					continue
 				}
 			}
 			return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
@@ -415,19 +483,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
 				fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
 			}
 
-			// Execute the tool
 			toolResult, err := toolRegistry.Execute(call)
 			if err != nil {
-				// Check if web search needs authentication
 				if errors.Is(err, tools.ErrWebSearchAuthRequired) {
-					// Prompt user to sign in
 					fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
 					result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
 					if promptErr == nil && result {
-						// Get signin URL and wait for auth completion
 						if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
-							// Retry the web search
-							fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
+							suggestedModel := promptCloudModelSuggestion()
+							if suggestedModel != "" && opts.ActiveModel != nil {
+								*opts.ActiveModel = suggestedModel
+								showModelConnection(ctx, suggestedModel)
+							}
+
+							fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
 							toolResult, err = toolRegistry.Execute(call)
 							if err == nil {
 								goto toolSuccess
@@ -466,7 +535,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
 			}
 
 			// Truncate output to prevent context overflow
-			toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
+			toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
 
 			toolResults = append(toolResults, api.Message{
 				Role: "tool",
@@ -625,25 +694,28 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
 	return out
 }
 
-// checkModelCapabilities checks if the model supports tools.
-func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
+// checkModelCapabilities checks if the model supports tools and thinking.
+func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		return false, err
+		return false, false, err
 	}
 
 	resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
 	if err != nil {
-		return false, err
+		return false, false, err
 	}
 
 	for _, cap := range resp.Capabilities {
 		if cap == model.CapabilityTools {
-			return true, nil
+			supportsTools = true
 		}
+		if cap == model.CapabilityThinking {
+			supportsThinking = true
+		}
 	}
 
-	return false, nil
+	return supportsTools, supportsThinking, nil
 }
 
 // GenerateInteractive runs an interactive agent session.
@@ -663,13 +735,17 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 	fmt.Print(readline.StartBracketedPaste)
 	defer fmt.Printf(readline.EndBracketedPaste)
 
-	// Check if model supports tools
-	supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
+	// Check if model supports tools and thinking
+	supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
 		supportsTools = false
+		supportsThinking = false
 	}
 
+	// Track if session is using thinking mode
+	usingThinking := think != nil && supportsThinking
+
 	// Create tool registry only if model supports tools
 	var toolRegistry *tools.Registry
 	if supportsTools {
@@ -757,30 +833,44 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 		if sb.Len() > 0 {
 			newMessage := api.Message{Role: "user", Content: sb.String()}
 			messages = append(messages, newMessage)
 
-			opts := RunOptions{
-				Model:                   modelName,
-				Messages:                messages,
-				WordWrap:                wordWrap,
-				Options:                 options,
-				Think:                   think,
-				HideThinking:            hideThinking,
-				KeepAlive:               keepAlive,
-				Tools:                   toolRegistry,
-				Approval:                approval,
-				YoloMode:                yoloMode,
-				LastToolOutput:          &lastToolOutput,
-				LastToolOutputTruncated: &lastToolOutputTruncated,
-			}
 			// Reset expanded state for new tool execution
 			toolOutputExpanded = false
 
-			assistant, err := Chat(cmd.Context(), opts)
-			if err != nil {
-				return err
-			}
-			if assistant != nil {
-				messages = append(messages, *assistant)
+		retryChat:
+			for {
+				opts := RunOptions{
+					Model:                   modelName,
+					Messages:                messages,
+					WordWrap:                wordWrap,
+					Options:                 options,
+					Think:                   think,
+					HideThinking:            hideThinking,
+					KeepAlive:               keepAlive,
+					Tools:                   toolRegistry,
+					Approval:                approval,
+					YoloMode:                yoloMode,
+					LastToolOutput:          &lastToolOutput,
+					LastToolOutputTruncated: &lastToolOutputTruncated,
+					ActiveModel:             &modelName,
+				}
+
+				assistant, err := Chat(cmd.Context(), opts)
+				if err != nil {
+					var switchReq *CloudModelSwitchRequest
+					if errors.As(err, &switchReq) {
+						newModel := switchReq.Model
+						if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
+							fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
+							fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
+						}
+						continue retryChat
+					}
+					return err
+				}
+
+				if assistant != nil {
+					messages = append(messages, *assistant)
+				}
+				break retryChat
 			}
 
 			sb.Reset()
@@ -788,6 +878,52 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
 	}
 }
 
+// switchToModel handles model switching with capability checks and UI updates.
+func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return fmt.Errorf("could not create client: %w", err)
+	}
+
+	newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
+	if capErr != nil {
+		return fmt.Errorf("could not check model capabilities: %w", capErr)
+	}
+
+	// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
+	if usingThinking && !newSupportsThinking {
+		return fmt.Errorf("%s does not support thinking mode", newModel)
+	}
+
+	// Show "Connecting to X on ollama.com" for cloud models
+	info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
+	if err == nil && info.RemoteHost != "" {
+		if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
+			fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
+		} else {
+			fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
+		}
+	}
+
+	*modelName = newModel
+	*supportsTools = newSupportsTools
+	*supportsThinking = newSupportsThinking
+
+	if *supportsTools {
+		if *toolRegistry == nil {
+			*toolRegistry = tools.DefaultRegistry()
+		}
+		if (*toolRegistry).Count() > 0 {
+			fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
+		}
+	} else {
+		*toolRegistry = nil
+		fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
+	}
+
+	return nil
+}
+
 // showToolsStatus displays the current tools and approval status.
 func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
 	if !supportsTools || registry == nil {