Compare commits

..

1 Commits

Author SHA1 Message Date
ParthSareen
26b0a198b7 cmd: ollama launch allow extra params 2026-01-28 23:13:51 -08:00
18 changed files with 139 additions and 421 deletions

View File

@@ -912,19 +912,6 @@ type UserResponse struct {
Plan string `json:"plan,omitempty"`
}
type UsageResponse struct {
// Start is the time the server started tracking usage (UTC, RFC 3339).
Start time.Time `json:"start"`
Usage []ModelUsageData `json:"usage"`
}
type ModelUsageData struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`

View File

@@ -15,11 +15,15 @@ type Claude struct{}
func (c *Claude) String() string { return "Claude Code" }
func (c *Claude) args(model string) []string {
func (c *Claude) args(model string, extraArgs []string) []string {
var args []string
if model != "" {
return []string{"--model", model}
args = append(args, "--model", model)
}
return nil
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
return args
}
func (c *Claude) findPath() (string, error) {
@@ -41,13 +45,13 @@ func (c *Claude) findPath() (string, error) {
return fallback, nil
}
func (c *Claude) Run(model string) error {
func (c *Claude) Run(model string, extraArgs []string) error {
claudePath, err := c.findPath()
if err != nil {
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
}
cmd := exec.Command(claudePath, c.args(model)...)
cmd := exec.Command(claudePath, c.args(model, extraArgs)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -82,19 +82,23 @@ func TestClaudeArgs(t *testing.T) {
c := &Claude{}
tests := []struct {
name string
model string
want []string
name string
model string
extraArgs []string
want []string
}{
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
{"empty model", "", nil},
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra args", "llama3.2", []string{"--yolo", "--hi"}, []string{"--model", "llama3.2", "--yolo", "--hi"}},
{"empty model with extra args", "", []string{"--help"}, []string{"--help"}},
{"multiple extra args", "llama3.2", []string{"--flag1", "--flag2", "value"}, []string{"--model", "llama3.2", "--flag1", "--flag2", "value"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.extraArgs)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
}
})
}

View File

@@ -19,7 +19,7 @@ func (c *Clawdbot) String() string { return "Clawdbot" }
const ansiGreen = "\033[32m"
func (c *Clawdbot) Run(model string) error {
func (c *Clawdbot) Run(model string, extraArgs []string) error {
if _, err := exec.LookPath("clawdbot"); err != nil {
return fmt.Errorf("clawdbot is not installed, install from https://docs.clawd.bot")
}
@@ -32,7 +32,13 @@ func (c *Clawdbot) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("clawdbot", "gateway")
// Build args: "gateway" first, then any extra args
args := []string{"gateway"}
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
cmd := exec.Command("clawdbot", args...)
cmd.Stdin = os.Stdin
// Capture output to detect "already running" message

View File

@@ -14,20 +14,23 @@ type Codex struct{}
func (c *Codex) String() string { return "Codex" }
func (c *Codex) args(model string) []string {
func (c *Codex) args(model string, extraArgs []string) []string {
args := []string{"--oss"}
if model != "" {
args = append(args, "-m", model)
}
if len(extraArgs) > 0 {
args = append(args, extraArgs...)
}
return args
}
func (c *Codex) Run(model string) error {
func (c *Codex) Run(model string, extraArgs []string) error {
if err := checkCodexVersion(); err != nil {
return err
}
cmd := exec.Command("codex", c.args(model)...)
cmd := exec.Command("codex", c.args(model, extraArgs)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -9,19 +9,22 @@ func TestCodexArgs(t *testing.T) {
c := &Codex{}
tests := []struct {
name string
model string
want []string
name string
model string
extraArgs []string
want []string
}{
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", []string{"--oss"}},
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
{"empty model", "", nil, []string{"--oss"}},
{"with model and extra args", "qwen3-coder", []string{"--yolo"}, []string{"--oss", "-m", "qwen3-coder", "--yolo"}},
{"empty model with extra args", "", []string{"--help"}, []string{"--oss", "--help"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model)
got := c.args(tt.model, tt.extraArgs)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.extraArgs, got, tt.want)
}
})
}

View File

@@ -39,7 +39,7 @@ type modelEntry struct {
func (d *Droid) String() string { return "Droid" }
func (d *Droid) Run(model string) error {
func (d *Droid) Run(model string, extraArgs []string) error {
if _, err := exec.LookPath("droid"); err != nil {
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
}
@@ -53,7 +53,7 @@ func (d *Droid) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("droid")
cmd := exec.Command("droid", extraArgs...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -22,7 +22,7 @@ import (
// Runner can run an integration with a model.
type Runner interface {
Run(model string) error
Run(model string, extraArgs []string) error
// String returns the human-readable name of the integration
String() string
}
@@ -222,13 +222,13 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
return selected, nil
}
func runIntegration(name, modelName string) error {
func runIntegration(name, modelName string, extraArgs []string) error {
r, ok := integrations[name]
if !ok {
return fmt.Errorf("unknown integration: %s", name)
}
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
return r.Run(modelName)
return r.Run(modelName, extraArgs)
}
// LaunchCmd returns the cobra command for launching integrations.
@@ -237,7 +237,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
var configFlag bool
cmd := &cobra.Command{
Use: "launch [INTEGRATION]",
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
Short: "Launch an integration with Ollama",
Long: `Launch an integration configured with Ollama models.
@@ -252,13 +252,17 @@ Examples:
ollama launch
ollama launch claude
ollama launch claude --model <model>
ollama launch droid --config (does not auto-launch)`,
Args: cobra.MaximumNArgs(1),
ollama launch droid --config (does not auto-launch)
ollama launch claude -- --yolo --hi (pass extra args to integration)`,
Args: cobra.ArbitraryArgs,
PreRunE: checkServerHeartbeat,
RunE: func(cmd *cobra.Command, args []string) error {
// Extract integration name and pass through remaining args
var name string
var extraArgs []string
if len(args) > 0 {
name = args[0]
extraArgs = args[1:]
} else {
var err error
name, err = selectIntegration()
@@ -278,7 +282,7 @@ Examples:
// If launching without --model, use saved config if available
if !configFlag && modelFlag == "" {
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
return runIntegration(name, config.Models[0])
return runIntegration(name, config.Models[0], extraArgs)
}
}
@@ -339,13 +343,13 @@ Examples:
if configFlag {
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
return runIntegration(name, models[0])
return runIntegration(name, models[0], extraArgs)
}
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
return nil
}
return runIntegration(name, models[0])
return runIntegration(name, models[0], extraArgs)
},
}

View File

@@ -90,8 +90,8 @@ func TestLaunchCmd(t *testing.T) {
cmd := LaunchCmd(mockCheck)
t.Run("command structure", func(t *testing.T) {
if cmd.Use != "launch [INTEGRATION]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
}
if cmd.Short == "" {
t.Error("Short description should not be empty")
@@ -121,7 +121,7 @@ func TestLaunchCmd(t *testing.T) {
}
func TestRunIntegration_UnknownIntegration(t *testing.T) {
err := runIntegration("unknown-integration", "model")
err := runIntegration("unknown-integration", "model", nil)
if err == nil {
t.Error("expected error for unknown integration, got nil")
}
@@ -182,7 +182,69 @@ func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
// Test Run() exists (we can't call it without actually running the command)
// Just verify the method is available
var _ func(string) error = r.Run
var _ func(string, []string) error = r.Run
})
}
}
func TestParseExtraArgs(t *testing.T) {
tests := []struct {
name string
args []string
wantArgs []string
wantExtraArgs []string
}{
{
name: "no extra args",
args: []string{"claude"},
wantArgs: []string{"claude"},
wantExtraArgs: nil,
},
{
name: "with extra args after --",
args: []string{"claude", "--", "--yolo", "--hi"},
wantArgs: []string{"claude"},
wantExtraArgs: []string{"--yolo", "--hi"},
},
{
name: "extra args only after --",
args: []string{"codex", "--", "--help"},
wantArgs: []string{"codex"},
wantExtraArgs: []string{"--help"},
},
{
name: "-- at end with no args after",
args: []string{"claude", "--"},
wantArgs: []string{"claude", "--"},
wantExtraArgs: nil,
},
{
name: "multiple args after --",
args: []string{"claude", "--", "--flag1", "--flag2", "value", "--flag3"},
wantArgs: []string{"claude"},
wantExtraArgs: []string{"--flag1", "--flag2", "value", "--flag3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the parsing logic from LaunchCmd
args := tt.args
var extraArgs []string
for i, arg := range args {
if arg == "--" && i < len(args)-1 {
extraArgs = args[i+1:]
args = args[:i]
break
}
}
if !slices.Equal(args, tt.wantArgs) {
t.Errorf("args = %v, want %v", args, tt.wantArgs)
}
if !slices.Equal(extraArgs, tt.wantExtraArgs) {
t.Errorf("extraArgs = %v, want %v", extraArgs, tt.wantExtraArgs)
}
})
}
}

View File

@@ -18,7 +18,7 @@ type OpenCode struct{}
func (o *OpenCode) String() string { return "OpenCode" }
func (o *OpenCode) Run(model string) error {
func (o *OpenCode) Run(model string, extraArgs []string) error {
if _, err := exec.LookPath("opencode"); err != nil {
return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
}
@@ -32,7 +32,7 @@ func (o *OpenCode) Run(model string) error {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("opencode")
cmd := exec.Command("opencode", extraArgs...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

View File

@@ -15,7 +15,6 @@
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Usage](#usage)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
@@ -1855,53 +1854,6 @@ curl http://localhost:11434/api/embeddings -d '{
}
```
## Usage
```
GET /api/usage
```
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
### Examples
#### Request
```shell
curl http://localhost:11434/api/usage
```
#### Response
```json
{
"start": "2025-01-27T20:00:00Z",
"usage": [
{
"model": "llama3.2",
"requests": 5,
"prompt_tokens": 130,
"completion_tokens": 890
},
{
"model": "deepseek-r1",
"requests": 2,
"prompt_tokens": 48,
"completion_tokens": 312
}
]
}
```
#### Response fields
- `start`: when the server started tracking usage (UTC, RFC 3339)
- `usage`: list of per-model usage statistics
- `model`: model name
- `requests`: total number of completed requests
- `prompt_tokens`: total prompt tokens evaluated
- `completion_tokens`: total completion tokens generated
## Version
```

View File

@@ -85,7 +85,6 @@ type Server struct {
addr net.Addr
sched *Scheduler
lowVRAM bool
usage *UsageTracker
}
func init() {
@@ -274,10 +273,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.Header("Content-Type", contentType)
fn := func(resp api.GenerateResponse) error {
if resp.Done {
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
}
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
@@ -584,8 +579,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
res.Context = tokens
}
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
}
if builtinParser != nil {
@@ -1597,8 +1590,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.GET("/api/usage", s.UsageHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
@@ -1667,7 +1658,7 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
s := &Server{addr: ln.Addr()}
var rc *ollama.Registry
if useClient2 {
@@ -1884,10 +1875,6 @@ func (s *Server) SignoutHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
func (s *Server) UsageHandler(c *gin.Context) {
c.JSON(http.StatusOK, s.usage.Stats())
}
func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{}
@@ -2046,10 +2033,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.Header("Content-Type", contentType)
fn := func(resp api.ChatResponse) error {
if resp.Done {
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
}
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
@@ -2270,8 +2253,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
}
if builtinParser != nil {

View File

@@ -29,7 +29,6 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -223,7 +222,6 @@ func TestChatDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -34,7 +34,6 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -219,7 +218,6 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -88,39 +88,19 @@ func TestGenerateChatRemote(t *testing.T) {
if r.Method != http.MethodPost {
t.Errorf("Expected POST request, got %s", r.Method)
}
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
}
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/chat":
resp := api.ChatResponse{
Model: "test",
Done: true,
DoneReason: "load",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 20,
},
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
case "/api/generate":
resp := api.GenerateResponse{
Model: "test",
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 5,
EvalCount: 15,
},
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
default:
t.Errorf("unexpected path %s", r.URL.Path)
resp := api.ChatResponse{
Model: "test",
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()
@@ -131,7 +111,7 @@ func TestGenerateChatRemote(t *testing.T) {
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{usage: NewUsageTracker()}
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud",
RemoteHost: rs.URL,
@@ -179,61 +159,6 @@ func TestGenerateChatRemote(t *testing.T) {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
t.Run("remote chat usage tracking", func(t *testing.T) {
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 10 {
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 20 {
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
t.Run("remote generate usage tracking", func(t *testing.T) {
// Reset the tracker for a clean test
s.usage = NewUsageTracker()
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-cloud",
Prompt: "hello",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 5 {
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 15 {
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
}
func TestGenerateChat(t *testing.T) {
@@ -251,7 +176,6 @@ func TestGenerateChat(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -968,7 +892,6 @@ func TestGenerate(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1453,7 +1376,6 @@ func TestGenerateLogprobs(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1634,7 +1556,6 @@ func TestChatLogprobs(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1745,7 +1666,6 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2192,7 +2112,6 @@ func TestGenerateUnload(t *testing.T) {
var loadFnCalled bool
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2294,7 +2213,6 @@ func TestGenerateWithImages(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2452,7 +2370,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
opts := api.DefaultOptions()
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -255,7 +255,6 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -407,7 +406,6 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -590,7 +588,6 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -1,62 +0,0 @@
package server
import (
"sync"
"time"
"github.com/ollama/ollama/api"
)
type ModelUsage struct {
Requests int64
PromptTokens int64
CompletionTokens int64
}
type UsageTracker struct {
mu sync.Mutex
start time.Time
models map[string]*ModelUsage
}
func NewUsageTracker() *UsageTracker {
return &UsageTracker{
start: time.Now().UTC(),
models: make(map[string]*ModelUsage),
}
}
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
u.mu.Lock()
defer u.mu.Unlock()
m, ok := u.models[model]
if !ok {
m = &ModelUsage{}
u.models[model] = m
}
m.Requests++
m.PromptTokens += int64(promptTokens)
m.CompletionTokens += int64(completionTokens)
}
func (u *UsageTracker) Stats() api.UsageResponse {
u.mu.Lock()
defer u.mu.Unlock()
byModel := make([]api.ModelUsageData, 0, len(u.models))
for model, usage := range u.models {
byModel = append(byModel, api.ModelUsageData{
Model: model,
Requests: usage.Requests,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
})
}
return api.UsageResponse{
Start: u.start,
Usage: byModel,
}
}

View File

@@ -1,136 +0,0 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestUsageTrackerRecord(t *testing.T) {
tracker := NewUsageTracker()
tracker.Record("model-a", 10, 20)
tracker.Record("model-a", 5, 15)
tracker.Record("model-b", 100, 200)
stats := tracker.Stats()
if len(stats.Usage) != 2 {
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
}
lookup := make(map[string]api.ModelUsageData)
for _, m := range stats.Usage {
lookup[m.Model] = m
}
a := lookup["model-a"]
if a.Requests != 2 {
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
}
if a.PromptTokens != 15 {
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
}
if a.CompletionTokens != 35 {
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
}
b := lookup["model-b"]
if b.Requests != 1 {
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
}
if b.PromptTokens != 100 {
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
}
if b.CompletionTokens != 200 {
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
}
}
func TestUsageTrackerConcurrent(t *testing.T) {
tracker := NewUsageTracker()
var wg sync.WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Record("model-a", 1, 2)
}()
}
wg.Wait()
stats := tracker.Stats()
if len(stats.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
}
m := stats.Usage[0]
if m.Requests != 100 {
t.Errorf("requests: expected 100, got %d", m.Requests)
}
if m.PromptTokens != 100 {
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
}
if m.CompletionTokens != 200 {
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
}
}
func TestUsageTrackerStart(t *testing.T) {
tracker := NewUsageTracker()
stats := tracker.Stats()
if stats.Start.IsZero() {
t.Error("expected non-zero start time")
}
}
func TestUsageHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
s := &Server{
usage: NewUsageTracker(),
}
s.usage.Record("llama3", 50, 100)
s.usage.Record("llama3", 25, 50)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
s.UsageHandler(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.UsageResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if len(resp.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
}
m := resp.Usage[0]
if m.Model != "llama3" {
t.Errorf("expected model llama3, got %s", m.Model)
}
if m.Requests != 2 {
t.Errorf("expected 2 requests, got %d", m.Requests)
}
if m.PromptTokens != 75 {
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 150 {
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
}
}