Compare commits


3 Commits

Author SHA1 Message Date
jmorganca
d132315276 api: expose usage data 2026-01-16 00:24:07 -08:00
Parth Sareen
12e2b3514a x: agent loop ux improvements (#13635) 2026-01-07 01:27:15 -08:00
Devon Rifkin
626af2d809 template: fix args-as-json rendering (#13636)
In #13525, I accidentally broke templates' ability to automatically
render tool call function arguments as JSON.

We do need these to be proper maps because we need templates to be able
to call range, which can't be done on custom types.
2026-01-06 18:33:57 -08:00
16 changed files with 1055 additions and 361 deletions

View File

@@ -377,6 +377,15 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
return &lr, nil
}
// Usage returns usage information, currently a per-device memory breakdown.
func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
var ur UsageResponse
if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
return nil, err
}
return &ur, nil
}
// Copy copies a model - creating a model with another name from an existing
// model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
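For context, a minimal sketch of exercising the new client method (a sketch only, assuming a locally running server reachable via ClientFromEnvironment; UsageResponse and GPUUsage are the types added in the next file):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	resp, err := client.Usage(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, gpu := range resp.GPUs {
		// Per UsageHandler later in this diff, Used is clamped so that
		// Used + Other == Total - Free for each device.
		fmt.Printf("%s (%s): total=%d free=%d ollama=%d other=%d\n",
			gpu.Name, gpu.Backend, gpu.Total, gpu.Free, gpu.Used, gpu.Other)
	}
}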

View File

@@ -792,6 +792,33 @@ type ProcessResponse struct {
Models []ProcessModelResponse `json:"models"`
}
// UsageResponse is the response from [Client.Usage].
type UsageResponse struct {
GPUs []GPUUsage `json:"gpus,omitempty"`
}
// GPUUsage contains GPU/device memory usage breakdown.
type GPUUsage struct {
Name string `json:"name"` // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
Total uint64 `json:"total"`
Free uint64 `json:"free"`
Used uint64 `json:"used"` // Memory used by Ollama
Other uint64 `json:"other"` // Memory used by other processes
}
// UsageStats contains usage statistics.
type UsageStats struct {
Requests int64 `json:"requests"`
TokensInput int64 `json:"tokens_input"`
TokensOutput int64 `json:"tokens_output"`
TotalTokens int64 `json:"total_tokens"`
Models map[string]int64 `json:"models,omitempty"`
Sources map[string]int64 `json:"sources,omitempty"`
ToolCalls int64 `json:"tool_calls,omitempty"`
StructuredOutput int64 `json:"structured_output,omitempty"`
}
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"`

View File

@@ -1833,6 +1833,7 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler,
}
copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION",
Short: "Copy a model",

View File

@@ -206,6 +206,8 @@ var (
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
// Usage enables usage statistics reporting
Usage = Bool("OLLAMA_USAGE")
)
func String(s string) func() string {
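Note: reporting stays opt-in. Starting the server with OLLAMA_USAGE set to a truthy value (for example, OLLAMA_USAGE=1 ollama serve) enables it, presumably with the same parsing as the other OLLAMA_* booleans declared above.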

View File

@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"os/signal"
"runtime"
"slices"
"strings"
"sync/atomic"
@@ -44,6 +45,7 @@ import (
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/server/usage"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/tools"
@@ -82,6 +84,7 @@ type Server struct {
addr net.Addr
sched *Scheduler
lowVRAM bool
stats *usage.Stats
}
func init() {
@@ -104,6 +107,30 @@ var (
errBadTemplate = errors.New("template error")
)
// usage records a request to usage stats if enabled.
func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
if s.stats == nil {
return
}
s.stats.Record(&usage.Request{
Endpoint: endpoint,
Model: model,
Architecture: architecture,
APIType: usage.ClassifyAPIType(c.Request.URL.Path),
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
UsedTools: usedTools,
})
}
// usageError records a failed request to usage stats if enabled.
func (s *Server) usageError() {
if s.stats == nil {
return
}
s.stats.RecordError()
}
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -374,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -561,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
@@ -680,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -790,6 +818,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
}
s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
c.JSON(http.StatusOK, resp)
}
@@ -827,7 +856,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -1531,6 +1560,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference
r.GET("/api/ps", s.PsHandler)
r.GET("/api/usage", s.UsageHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
@@ -1593,6 +1623,13 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
// Initialize usage stats if enabled
if envconfig.Usage() {
s.stats = usage.New()
s.stats.Start()
slog.Info("usage stats enabled")
}
var rc *ollama.Registry
if useClient2 {
var err error
@@ -1632,6 +1669,9 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if s.stats != nil {
s.stats.Stop()
}
srvr.Close()
schedDone()
sched.unloadAllRunners()
@@ -1649,6 +1689,24 @@ func Serve(ln net.Listener) error {
gpus := discover.GPUDevices(ctx, nil)
discover.LogDetails(gpus)
// Set GPU info for usage reporting
if s.stats != nil {
usage.GPUInfoFunc = func() []usage.GPU {
var result []usage.GPU
for _, gpu := range gpus {
result = append(result, usage.GPU{
Name: gpu.Name,
VRAMBytes: gpu.TotalMemory,
ComputeMajor: gpu.ComputeMajor,
ComputeMinor: gpu.ComputeMinor,
DriverMajor: gpu.DriverMajor,
DriverMinor: gpu.DriverMinor,
})
}
return result
}
}
var totalVRAM uint64
for _, gpu := range gpus {
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
@@ -1852,6 +1910,63 @@ func (s *Server) PsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
func (s *Server) UsageHandler(c *gin.Context) {
// Get total VRAM used by Ollama
s.sched.loadedMu.Lock()
var totalOllamaVRAM uint64
for _, runner := range s.sched.loaded {
totalOllamaVRAM += runner.vramSize
}
s.sched.loadedMu.Unlock()
var resp api.UsageResponse
// Get GPU/device info
gpus := discover.GPUDevices(c.Request.Context(), nil)
// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
// because unified memory means GPU and CPU share the same physical RAM pool
var sysTotal, sysFree uint64
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
sysInfo := discover.GetSystemInfo()
sysTotal = sysInfo.TotalMemory
sysFree = sysInfo.FreeMemory
}
for _, gpu := range gpus {
total := gpu.TotalMemory
free := gpu.FreeMemory
// On Apple Silicon, override with system memory values
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
total = sysTotal
free = sysFree
}
used := total - free
ollamaUsed := min(totalOllamaVRAM, used)
otherUsed := used - ollamaUsed
// Use Description for Name (actual device name like "Apple M2 Max")
// Fall back to backend name if Description is empty
name := gpu.Description
if name == "" {
name = gpu.Name
}
resp.GPUs = append(resp.GPUs, api.GPUUsage{
Name: name,
Backend: gpu.Library,
Total: total,
Free: free,
Used: ollamaUsed,
Other: otherUsed,
})
}
c.JSON(http.StatusOK, resp)
}
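// For illustration only (not part of the change): a response from this
// handler might look like the following, with hypothetical values chosen
// so that used+other == total-free:
//
//	{"gpus":[{"name":"Apple M2 Max","backend":"Metal","total":68719476736,
//	  "free":40000000000,"used":12000000000,"other":16719476736}]}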
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -2032,7 +2147,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -2180,6 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
}
if builtinParser != nil {
@@ -2355,6 +2471,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
resp.Message.ToolCalls = toolCalls
}
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
c.JSON(http.StatusOK, resp)
return
}
@@ -2362,7 +2479,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, name string, err error) {
func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
s.usageError()
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -0,0 +1,60 @@
package server
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestUsageHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("empty server", func(t *testing.T) {
s := Server{
sched: &Scheduler{
loaded: make(map[string]*runnerRef),
},
}
w := createRequest(t, s.UsageHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.UsageResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
// GPUs may or may not be present depending on system
// Just verify we can decode the response
})
t.Run("response structure", func(t *testing.T) {
s := Server{
sched: &Scheduler{
loaded: make(map[string]*runnerRef),
},
}
w := createRequest(t, s.UsageHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
// Verify we can decode the response as valid JSON
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
// The response should be a valid object (not null)
if resp == nil {
t.Error("expected non-nil response")
}
})
}

server/usage/reporter.go (new file, 65 lines)
View File

@@ -0,0 +1,65 @@
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/ollama/ollama/version"
)
const (
reportTimeout = 10 * time.Second
usageURL = "https://ollama.com/api/usage"
)
// HeartbeatResponse is the response from the heartbeat endpoint.
type HeartbeatResponse struct {
UpdateVersion string `json:"update_version,omitempty"`
}
// UpdateAvailable returns the available update version, if any.
func (t *Stats) UpdateAvailable() string {
if v := t.updateAvailable.Load(); v != nil {
return v.(string)
}
return ""
}
// sendHeartbeat sends usage stats and checks for updates.
func (t *Stats) sendHeartbeat(payload *Payload) {
data, err := json.Marshal(payload)
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return
}
var heartbeat HeartbeatResponse
if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
return
}
t.updateAvailable.Store(heartbeat.UpdateVersion)
}
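Note that sendHeartbeat fails silently at every step (marshal, request construction, transport, non-200 status, decode): reporting problems never surface into request handling, and the only observable effect of a successful heartbeat is updateAvailable being refreshed.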

server/usage/source.go (new file, 23 lines)
View File

@@ -0,0 +1,23 @@
package usage
import (
"strings"
)
// API type constants
const (
APITypeOllama = "ollama"
APITypeOpenAI = "openai"
APITypeAnthropic = "anthropic"
)
// ClassifyAPIType determines the API type from the request path.
func ClassifyAPIType(path string) string {
if strings.HasPrefix(path, "/v1/messages") {
return APITypeAnthropic
}
if strings.HasPrefix(path, "/v1/") {
return APITypeOpenAI
}
return APITypeOllama
}
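For illustration, a small sketch of how request paths classify (example paths; only the prefixes matter):

package main

import (
	"fmt"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	fmt.Println(usage.ClassifyAPIType("/api/chat"))            // ollama
	fmt.Println(usage.ClassifyAPIType("/v1/chat/completions")) // openai
	fmt.Println(usage.ClassifyAPIType("/v1/messages"))         // anthropic
}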

server/usage/usage.go (new file, 324 lines)
View File

@@ -0,0 +1,324 @@
// Package usage provides in-memory usage statistics collection and reporting.
package usage
import (
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/version"
)
// Stats collects usage statistics in memory and reports them periodically.
type Stats struct {
mu sync.RWMutex
// Atomic counters for hot path
requestsTotal atomic.Int64
tokensPrompt atomic.Int64
tokensCompletion atomic.Int64
errorsTotal atomic.Int64
// Map-based counters (require lock)
endpoints map[string]int64
architectures map[string]int64
apis map[string]int64
models map[string]*ModelStats // per-model stats
// Feature usage
toolCalls atomic.Int64
structuredOutput atomic.Int64
// Update info (set by reporter after pinging update endpoint)
updateAvailable atomic.Value // string
// Reporter
stopCh chan struct{}
doneCh chan struct{}
interval time.Duration
endpoint string
}
// ModelStats tracks per-model usage statistics.
type ModelStats struct {
Requests int64
TokensInput int64
TokensOutput int64
}
// Request contains the data to record for a single request.
type Request struct {
Endpoint string // "chat", "generate", "embed"
Model string // model name (e.g., "llama3.2:3b")
Architecture string // model architecture (e.g., "llama", "qwen2")
APIType string // "ollama", "openai", or "anthropic" (see ClassifyAPIType)
PromptTokens int
CompletionTokens int
UsedTools bool
StructuredOutput bool
}
// SystemInfo contains hardware information to report.
type SystemInfo struct {
OS string `json:"os"`
Arch string `json:"arch"`
CPUCores int `json:"cpu_cores"`
RAMBytes uint64 `json:"ram_bytes"`
GPUs []GPU `json:"gpus,omitempty"`
}
// GPU contains information about a GPU.
type GPU struct {
Name string `json:"name"`
VRAMBytes uint64 `json:"vram_bytes"`
ComputeMajor int `json:"compute_major,omitempty"`
ComputeMinor int `json:"compute_minor,omitempty"`
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
}
// Payload is the data sent to the heartbeat endpoint.
type Payload struct {
Version string `json:"version"`
Time time.Time `json:"time"`
System SystemInfo `json:"system"`
Totals struct {
Requests int64 `json:"requests"`
Errors int64 `json:"errors"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
} `json:"totals"`
Endpoints map[string]int64 `json:"endpoints"`
Architectures map[string]int64 `json:"architectures"`
APIs map[string]int64 `json:"apis"`
Features struct {
ToolCalls int64 `json:"tool_calls"`
StructuredOutput int64 `json:"structured_output"`
} `json:"features"`
}
const (
defaultInterval = 1 * time.Hour
)
// New creates a new Stats instance.
func New(opts ...Option) *Stats {
t := &Stats{
endpoints: make(map[string]int64),
architectures: make(map[string]int64),
apis: make(map[string]int64),
models: make(map[string]*ModelStats),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
interval: defaultInterval,
}
for _, opt := range opts {
opt(t)
}
return t
}
// Option configures the Stats instance.
type Option func(*Stats)
// WithInterval sets the reporting interval.
func WithInterval(d time.Duration) Option {
return func(t *Stats) {
t.interval = d
}
}
// Record records a request. This is the hot path and should be fast.
func (t *Stats) Record(r *Request) {
t.requestsTotal.Add(1)
t.tokensPrompt.Add(int64(r.PromptTokens))
t.tokensCompletion.Add(int64(r.CompletionTokens))
if r.UsedTools {
t.toolCalls.Add(1)
}
if r.StructuredOutput {
t.structuredOutput.Add(1)
}
t.mu.Lock()
t.endpoints[r.Endpoint]++
t.architectures[r.Architecture]++
t.apis[r.APIType]++
// Track per-model stats
if r.Model != "" {
if t.models[r.Model] == nil {
t.models[r.Model] = &ModelStats{}
}
t.models[r.Model].Requests++
t.models[r.Model].TokensInput += int64(r.PromptTokens)
t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
}
t.mu.Unlock()
}
// RecordError records a failed request.
func (t *Stats) RecordError() {
t.errorsTotal.Add(1)
}
// GetModelStats returns a copy of per-model statistics.
func (t *Stats) GetModelStats() map[string]*ModelStats {
t.mu.RLock()
defer t.mu.RUnlock()
result := make(map[string]*ModelStats, len(t.models))
for k, v := range t.models {
result[k] = &ModelStats{
Requests: v.Requests,
TokensInput: v.TokensInput,
TokensOutput: v.TokensOutput,
}
}
return result
}
// View returns current stats without resetting counters.
func (t *Stats) View() *Payload {
t.mu.RLock()
defer t.mu.RUnlock()
now := time.Now()
// Copy maps
endpoints := make(map[string]int64, len(t.endpoints))
for k, v := range t.endpoints {
endpoints[k] = v
}
architectures := make(map[string]int64, len(t.architectures))
for k, v := range t.architectures {
architectures[k] = v
}
apis := make(map[string]int64, len(t.apis))
for k, v := range t.apis {
apis[k] = v
}
p := &Payload{
Version: version.Version,
Time: now,
System: getSystemInfo(),
Endpoints: endpoints,
Architectures: architectures,
APIs: apis,
}
p.Totals.Requests = t.requestsTotal.Load()
p.Totals.Errors = t.errorsTotal.Load()
p.Totals.InputTokens = t.tokensPrompt.Load()
p.Totals.OutputTokens = t.tokensCompletion.Load()
p.Features.ToolCalls = t.toolCalls.Load()
p.Features.StructuredOutput = t.structuredOutput.Load()
return p
}
// Snapshot returns current stats and resets counters.
func (t *Stats) Snapshot() *Payload {
t.mu.Lock()
defer t.mu.Unlock()
now := time.Now()
p := &Payload{
Version: version.Version,
Time: now,
System: getSystemInfo(),
Endpoints: t.endpoints,
Architectures: t.architectures,
APIs: t.apis,
}
p.Totals.Requests = t.requestsTotal.Swap(0)
p.Totals.Errors = t.errorsTotal.Swap(0)
p.Totals.InputTokens = t.tokensPrompt.Swap(0)
p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
p.Features.ToolCalls = t.toolCalls.Swap(0)
p.Features.StructuredOutput = t.structuredOutput.Swap(0)
// Reset maps
t.endpoints = make(map[string]int64)
t.architectures = make(map[string]int64)
t.apis = make(map[string]int64)
return p
}
// getSystemInfo collects hardware information.
func getSystemInfo() SystemInfo {
info := SystemInfo{
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
// Get CPU and memory info
sysInfo := discover.GetSystemInfo()
info.CPUCores = sysInfo.ThreadCount
info.RAMBytes = sysInfo.TotalMemory
// Get GPU info
gpus := getGPUInfo()
info.GPUs = gpus
return info
}
// GPUInfoFunc is a function that returns GPU information.
// It's set by the server package after GPU discovery.
var GPUInfoFunc func() []GPU
// getGPUInfo collects GPU information.
func getGPUInfo() []GPU {
if GPUInfoFunc != nil {
return GPUInfoFunc()
}
return nil
}
// Start begins the periodic reporting goroutine.
func (t *Stats) Start() {
go t.reportLoop()
}
// Stop stops reporting and waits for the final report.
func (t *Stats) Stop() {
close(t.stopCh)
<-t.doneCh
}
// reportLoop runs the periodic reporting.
func (t *Stats) reportLoop() {
defer close(t.doneCh)
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.report()
case <-t.stopCh:
// Send final report before stopping
t.report()
return
}
}
}
// report sends usage stats and checks for updates.
func (t *Stats) report() {
payload := t.Snapshot()
t.sendHeartbeat(payload)
}
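A minimal wiring sketch for the package API above; in the server this is gated behind OLLAMA_USAGE, and note that Stop flushes a final heartbeat POST to ollama.com:

package main

import (
	"time"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	// Shorter than the 1h default; useful when testing locally.
	stats := usage.New(usage.WithInterval(10 * time.Minute))
	stats.Start()
	defer stats.Stop() // sends a final report before returning

	stats.Record(&usage.Request{
		Endpoint:         "chat",
		Model:            "llama3.2:3b",
		Architecture:     "llama",
		APIType:          usage.APITypeOllama,
		PromptTokens:     128,
		CompletionTokens: 64,
	})
}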

server/usage/usage_test.go (new file, 194 lines)
View File

@@ -0,0 +1,194 @@
package usage
import (
"testing"
)
func TestNew(t *testing.T) {
stats := New()
if stats == nil {
t.Fatal("New() returned nil")
}
}
func TestRecord(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
Architecture: "llama",
APIType: "native",
PromptTokens: 100,
CompletionTokens: 50,
UsedTools: true,
StructuredOutput: false,
})
// Check totals
payload := stats.View()
if payload.Totals.Requests != 1 {
t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
}
if payload.Totals.InputTokens != 100 {
t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
}
if payload.Totals.OutputTokens != 50 {
t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
}
if payload.Features.ToolCalls != 1 {
t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
}
if payload.Features.StructuredOutput != 0 {
t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
}
}
func TestGetModelStats(t *testing.T) {
stats := New()
// Record requests for multiple models
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 100,
CompletionTokens: 50,
})
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 200,
CompletionTokens: 100,
})
stats.Record(&Request{
Model: "mistral:7b",
PromptTokens: 50,
CompletionTokens: 25,
})
modelStats := stats.GetModelStats()
// Check llama3:8b stats
llama := modelStats["llama3:8b"]
if llama == nil {
t.Fatal("expected llama3:8b stats")
}
if llama.Requests != 2 {
t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
}
if llama.TokensInput != 300 {
t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
}
if llama.TokensOutput != 150 {
t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
}
// Check mistral:7b stats
mistral := modelStats["mistral:7b"]
if mistral == nil {
t.Fatal("expected mistral:7b stats")
}
if mistral.Requests != 1 {
t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
}
if mistral.TokensInput != 50 {
t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
}
if mistral.TokensOutput != 25 {
t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
}
}
func TestRecordError(t *testing.T) {
stats := New()
stats.RecordError()
stats.RecordError()
payload := stats.View()
if payload.Totals.Errors != 2 {
t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
}
}
func TestView(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
Architecture: "llama",
APIType: "native",
})
// First view
_ = stats.View()
// View should not reset counters
payload := stats.View()
if payload.Totals.Requests != 1 {
t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
}
}
func TestSnapshot(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
PromptTokens: 100,
CompletionTokens: 50,
})
// Snapshot should return data and reset counters
snapshot := stats.Snapshot()
if snapshot.Totals.Requests != 1 {
t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
}
// After snapshot, counters should be reset
payload2 := stats.View()
if payload2.Totals.Requests != 0 {
t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
}
}
func TestConcurrentAccess(t *testing.T) {
stats := New()
done := make(chan bool)
// Concurrent writes
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 100; j++ {
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 10,
CompletionTokens: 5,
})
}
done <- true
}()
}
// Concurrent reads
for i := 0; i < 5; i++ {
go func() {
for j := 0; j < 100; j++ {
_ = stats.View()
_ = stats.GetModelStats()
}
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 15; i++ {
<-done
}
payload := stats.View()
if payload.Totals.Requests != 1000 {
t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
}
}

View File

@@ -381,6 +381,28 @@ func (t templateTools) String() string {
return string(bts)
}
// templateArgs is a map type with JSON string output for templates.
type templateArgs map[string]any
func (t templateArgs) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateProperties is a map type with JSON string output for templates.
type templateProperties map[string]api.ToolProperty
func (t templateProperties) String() string {
if t == nil {
return "{}"
}
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
@@ -396,11 +418,11 @@ type templateToolFunction struct {
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties templateProperties `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
@@ -413,7 +435,7 @@ type templateToolCall struct {
type templateToolCallFunction struct {
Index int
Name string
Arguments map[string]any
Arguments templateArgs
}
// templateMessage is a template-compatible representation of api.Message
@@ -446,7 +468,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: tool.Function.Parameters.Properties.ToMap(),
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
},
},
}
@@ -468,7 +490,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments.ToMap(),
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
},
})
}
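To see why the named map types fix rendering without breaking range, here is a self-contained sketch (the args type is illustrative, mirroring templateArgs above): text/template prints values through fmt, so a fmt.Stringer controls direct output, while range still works because the type remains a map kind.

package main

import (
	"encoding/json"
	"os"
	"text/template"
)

// args mirrors templateArgs: a named map type with JSON String output.
type args map[string]any

func (a args) String() string {
	b, _ := json.Marshal(a)
	return string(b)
}

func main() {
	tmpl := template.Must(template.New("t").Parse(
		"direct: {{ . }}\nranged:{{ range $k, $v := . }} {{ $k }}={{ $v }}{{ end }}\n"))
	_ = tmpl.Execute(os.Stdout, args{"location": "Tokyo"})
	// Output:
	// direct: {"location":"Tokyo"}
	// ranged: location=Tokyo
}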

View File

@@ -613,3 +613,159 @@ func TestCollate(t *testing.T) {
})
}
}
func TestTemplateArgumentsJSON(t *testing.T) {
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("location", "Tokyo")
args.Set("unit", "celsius")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Arguments output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplatePropertiesJSON(t *testing.T) {
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
// Should be valid JSON, not "map[location:{...}]"
if strings.HasPrefix(got, "map[") {
t.Errorf("Properties output as Go map format: %s", got)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
}
}
func TestTemplateArgumentsRange(t *testing.T) {
// Test that we can range over Arguments in templates
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
args := api.NewToolCallFunctionArguments()
args.Set("city", "Tokyo")
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{
Role: "assistant",
ToolCalls: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
}},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "city=Tokyo;" {
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
}
}
func TestTemplatePropertiesRange(t *testing.T) {
// Test that we can range over Properties in templates
// Note: template must reference .Messages to trigger the modern code path that converts Tools
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
template, err := Parse(tmpl)
if err != nil {
t.Fatal(err)
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
var buf bytes.Buffer
err = template.Execute(&buf, Values{
Messages: []api.Message{{Role: "user", Content: "test"}},
Tools: api.Tools{{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
},
},
}},
})
if err != nil {
t.Fatal(err)
}
got := buf.String()
if got != "location:string;" {
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
}
}

View File

@@ -1104,108 +1104,3 @@ func PromptYesNo(question string) (bool, error) {
}
}
}
// CloudModelOption represents a suggested cloud model for the selection prompt.
type CloudModelOption struct {
Name string
Description string
}
// PromptModelChoice displays a model selection prompt with multiple options.
// Returns the selected model name, or empty string if user declined or cancelled.
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return "", err
}
defer term.Restore(fd, oldState)
// Build options: models + "No thanks, continue"
optionCount := len(models) + 1
selected := 0
// Total lines: question + models + "no thanks" + hint = optionCount + 2
totalLines := optionCount + 2
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
firstRender := true
render := func() {
if !firstRender {
fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
}
firstRender = false
// \r\n needed in raw mode for proper line breaks
fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)
for i, model := range models {
fmt.Fprintf(os.Stderr, "\033[K")
if i == selected {
fmt.Fprintf(os.Stderr, " \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
} else {
fmt.Fprintf(os.Stderr, " \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
}
}
fmt.Fprintf(os.Stderr, "\033[K")
if selected == len(models) {
fmt.Fprintf(os.Stderr, " \033[1;32m> No thanks, continue\033[0m\r\n")
} else {
fmt.Fprintf(os.Stderr, " \033[90mNo thanks, continue\033[0m\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
}
render()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return "", err
}
if n == 1 {
switch buf[0] {
case 'j', 'J':
if selected < optionCount-1 {
selected++
}
render()
case 'k', 'K':
if selected > 0 {
selected--
}
render()
case '\r', '\n':
fmt.Fprintf(os.Stderr, "\n")
if selected < len(models) {
return models[selected].Name, nil
}
return "", nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\n")
return "", nil
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
switch buf[2] {
case 'A': // Up
if selected > 0 {
selected--
}
render()
case 'B': // Down
if selected < optionCount-1 {
selected++
}
render()
}
}
}
}

View File

@@ -1,25 +0,0 @@
package agent
import (
"testing"
)
func TestCloudModelOptionStruct(t *testing.T) {
// Test that the struct is defined correctly
models := []CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
if len(models) != 2 {
t.Errorf("expected 2 models, got %d", len(models))
}
if models[0].Name != "glm-4.7:cloud" {
t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
}
if models[1].Description != "Qwen3 Coder 480B" {
t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
}
}

View File

@@ -1,41 +0,0 @@
package cmd
import (
"errors"
"testing"
)
func TestCloudModelSwitchRequest(t *testing.T) {
// Test the error type
req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}
// Test Error() method
errMsg := req.Error()
expected := "switch to model: glm-4.7:cloud"
if errMsg != expected {
t.Errorf("expected %q, got %q", expected, errMsg)
}
// Test errors.As
var err error = req
var switchReq *CloudModelSwitchRequest
if !errors.As(err, &switchReq) {
t.Error("errors.As should return true for CloudModelSwitchRequest")
}
if switchReq.Model != "glm-4.7:cloud" {
t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
}
}
func TestSuggestedCloudModels(t *testing.T) {
// Verify the suggested models are defined
if len(suggestedCloudModels) == 0 {
t.Error("suggestedCloudModels should not be empty")
}
// Check first model
if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
}
}

View File

@@ -37,22 +37,6 @@ const (
charsPerToken = 4
)
// suggestedCloudModels are the models suggested to users after signing in.
// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
var suggestedCloudModels = []agent.CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
// CloudModelSwitchRequest signals that the user wants to switch to a different model.
type CloudModelSwitchRequest struct {
Model string
}
func (c *CloudModelSwitchRequest) Error() string {
return fmt.Sprintf("switch to model: %s", c.Model)
}
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
@@ -135,21 +119,6 @@ func waitForOllamaSignin(ctx context.Context) error {
return nil
}
// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
// Returns the selected model name, or empty string if user declines.
func promptCloudModelSuggestion() string {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
fmt.Fprintf(os.Stderr, "\n")
selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
if err != nil || selectedModel == "" {
return ""
}
return selectedModel
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@@ -175,40 +144,6 @@ type RunOptions struct {
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
// ActiveModel points to the current model name - can be updated mid-turn
// for model switching. If nil, opts.Model is used.
ActiveModel *string
}
// getActiveModel returns the current model name, checking ActiveModel pointer first.
func getActiveModel(opts *RunOptions) string {
if opts.ActiveModel != nil && *opts.ActiveModel != "" {
return *opts.ActiveModel
}
return opts.Model
}
// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
func showModelConnection(ctx context.Context, modelName string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return err
}
if info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
// Chat runs an agent chat loop with tool support.
@@ -308,7 +243,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: getActiveModel(&opts),
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
@@ -332,6 +267,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
@@ -339,13 +275,9 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" {
return nil, &CloudModelSwitchRequest{Model: suggestedModel}
}
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue
continue // Retry the loop
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
@@ -483,20 +415,19 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" && opts.ActiveModel != nil {
*opts.ActiveModel = suggestedModel
showModelConnection(ctx, suggestedModel)
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
@@ -535,7 +466,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResults = append(toolResults, api.Message{
Role: "tool",
@@ -694,28 +625,25 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return out
}
// checkModelCapabilities checks if the model supports tools and thinking.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
// checkModelCapabilities checks if the model supports tools.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return false, false, err
return false, err
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return false, false, err
return false, err
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityTools {
supportsTools = true
}
if cap == model.CapabilityThinking {
supportsThinking = true
return true, nil
}
}
return supportsTools, supportsThinking, nil
return false, nil
}
// GenerateInteractive runs an interactive agent session.
@@ -735,17 +663,13 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
// Check if model supports tools and thinking
supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
supportsThinking = false
}
// Track if session is using thinking mode
usingThinking := think != nil && supportsThinking
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
@@ -833,44 +757,30 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
}
// Reset expanded state for new tool execution
toolOutputExpanded = false
retryChat:
for {
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
ActiveModel: &modelName,
}
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
var switchReq *CloudModelSwitchRequest
if errors.As(err, &switchReq) {
newModel := switchReq.Model
if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
}
continue retryChat
}
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
break retryChat
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
sb.Reset()
@@ -878,52 +788,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
}
// switchToModel handles model switching with capability checks and UI updates.
func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("could not create client: %w", err)
}
newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
if capErr != nil {
return fmt.Errorf("could not check model capabilities: %w", capErr)
}
// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
if usingThinking && !newSupportsThinking {
return fmt.Errorf("%s does not support thinking mode", newModel)
}
// Show "Connecting to X on ollama.com" for cloud models
info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
if err == nil && info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
*modelName = newModel
*supportsTools = newSupportsTools
*supportsThinking = newSupportsThinking
if *supportsTools {
if *toolRegistry == nil {
*toolRegistry = tools.DefaultRegistry()
}
if (*toolRegistry).Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
}
} else {
*toolRegistry = nil
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
return nil
}
// showToolsStatus displays the current tools and approval status.
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
if !supportsTools || registry == nil {