mirror of
https://github.com/ollama/ollama.git
synced 2026-01-29 01:33:06 -05:00
Compare commits
4 Commits
llama-upda
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0496e6125 | ||
|
|
2d57bcbc64 | ||
|
|
060f9341c0 | ||
|
|
7b62c41060 |
13
api/types.go
13
api/types.go
@@ -912,6 +912,19 @@ type UserResponse struct {
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
type UsageResponse struct {
|
||||
// Start is the time the server started tracking usage (UTC, RFC 3339).
|
||||
Start time.Time `json:"start"`
|
||||
Usage []ModelUsageData `json:"usage"`
|
||||
}
|
||||
|
||||
type ModelUsageData struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
CompletionTokens int64 `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
@@ -50,7 +52,7 @@ func (c *Claude) Run(model string) error {
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Clawdbot struct{}
|
||||
@@ -90,7 +92,7 @@ func (c *Clawdbot) Edit(models []string) error {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = "http://127.0.0.1:11434/v1"
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
@@ -117,7 +119,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
|
||||
@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
|
||||
{
|
||||
"model": "existing-ollama-model",
|
||||
"displayName": "existing-ollama-model",
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"apiKey": "ollama",
|
||||
"provider": "generic-chat-completion-api",
|
||||
"maxOutputTokens": 64000,
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
@@ -88,7 +90,7 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama (local)",
|
||||
"options": map[string]any{
|
||||
"baseURL": "http://localhost:11434/v1",
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
48
docs/api.md
48
docs/api.md
@@ -15,6 +15,7 @@
|
||||
- [Push a Model](#push-a-model)
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Usage](#usage)
|
||||
- [Version](#version)
|
||||
- [Experimental: Image Generation](#image-generation-experimental)
|
||||
|
||||
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
GET /api/usage
|
||||
```
|
||||
|
||||
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/usage
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"start": "2025-01-27T20:00:00Z",
|
||||
"usage": [
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"requests": 5,
|
||||
"prompt_tokens": 130,
|
||||
"completion_tokens": 890
|
||||
},
|
||||
{
|
||||
"model": "deepseek-r1",
|
||||
"requests": 2,
|
||||
"prompt_tokens": 48,
|
||||
"completion_tokens": 312
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Response fields
|
||||
|
||||
- `start`: when the server started tracking usage (UTC, RFC 3339)
|
||||
- `usage`: list of per-model usage statistics
|
||||
- `model`: model name
|
||||
- `requests`: total number of completed requests
|
||||
- `prompt_tokens`: total prompt tokens evaluated
|
||||
- `completion_tokens`: total completion tokens generated
|
||||
|
||||
## Version
|
||||
|
||||
```
|
||||
|
||||
@@ -85,6 +85,7 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
usage *UsageTracker
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -273,6 +274,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -579,6 +584,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
res.Context = tokens
|
||||
}
|
||||
|
||||
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
@@ -1590,6 +1597,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
|
||||
r.GET("/api/usage", s.UsageHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
@@ -1658,7 +1667,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
@@ -1875,6 +1884,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (s *Server) UsageHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, s.usage.Stats())
|
||||
}
|
||||
|
||||
func (s *Server) PsHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
|
||||
@@ -2033,6 +2046,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.ChatResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -2253,6 +2270,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
|
||||
@@ -29,6 +29,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -222,6 +223,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -34,6 +34,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -218,6 +219,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/api/chat":
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 20,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case "/api/generate":
|
||||
resp := api.GenerateResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 5,
|
||||
EvalCount: 15,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected path %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
s := Server{}
|
||||
s := Server{usage: NewUsageTracker()}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud",
|
||||
RemoteHost: rs.URL,
|
||||
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote chat usage tracking", func(t *testing.T) {
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 10 {
|
||||
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 20 {
|
||||
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote generate usage tracking", func(t *testing.T) {
|
||||
// Reset the tracker for a clean test
|
||||
s.usage = NewUsageTracker()
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-cloud",
|
||||
Prompt: "hello",
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 5 {
|
||||
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 15 {
|
||||
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateChat(t *testing.T) {
|
||||
@@ -176,6 +251,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -892,6 +968,7 @@ func TestGenerate(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1376,6 +1453,7 @@ func TestGenerateLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1556,6 +1634,7 @@ func TestChatLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1666,6 +1745,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2112,6 +2192,7 @@ func TestGenerateUnload(t *testing.T) {
|
||||
var loadFnCalled bool
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2213,6 +2294,7 @@ func TestGenerateWithImages(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2370,6 +2452,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
62
server/usage.go
Normal file
62
server/usage.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ModelUsage struct {
|
||||
Requests int64
|
||||
PromptTokens int64
|
||||
CompletionTokens int64
|
||||
}
|
||||
|
||||
type UsageTracker struct {
|
||||
mu sync.Mutex
|
||||
start time.Time
|
||||
models map[string]*ModelUsage
|
||||
}
|
||||
|
||||
func NewUsageTracker() *UsageTracker {
|
||||
return &UsageTracker{
|
||||
start: time.Now().UTC(),
|
||||
models: make(map[string]*ModelUsage),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
m, ok := u.models[model]
|
||||
if !ok {
|
||||
m = &ModelUsage{}
|
||||
u.models[model] = m
|
||||
}
|
||||
|
||||
m.Requests++
|
||||
m.PromptTokens += int64(promptTokens)
|
||||
m.CompletionTokens += int64(completionTokens)
|
||||
}
|
||||
|
||||
func (u *UsageTracker) Stats() api.UsageResponse {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
byModel := make([]api.ModelUsageData, 0, len(u.models))
|
||||
for model, usage := range u.models {
|
||||
byModel = append(byModel, api.ModelUsageData{
|
||||
Model: model,
|
||||
Requests: usage.Requests,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return api.UsageResponse{
|
||||
Start: u.start,
|
||||
Usage: byModel,
|
||||
}
|
||||
}
|
||||
136
server/usage_test.go
Normal file
136
server/usage_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestUsageTrackerRecord(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
tracker.Record("model-a", 10, 20)
|
||||
tracker.Record("model-a", 5, 15)
|
||||
tracker.Record("model-b", 100, 200)
|
||||
|
||||
stats := tracker.Stats()
|
||||
|
||||
if len(stats.Usage) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
lookup := make(map[string]api.ModelUsageData)
|
||||
for _, m := range stats.Usage {
|
||||
lookup[m.Model] = m
|
||||
}
|
||||
|
||||
a := lookup["model-a"]
|
||||
if a.Requests != 2 {
|
||||
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
|
||||
}
|
||||
if a.PromptTokens != 15 {
|
||||
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
|
||||
}
|
||||
if a.CompletionTokens != 35 {
|
||||
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
|
||||
}
|
||||
|
||||
b := lookup["model-b"]
|
||||
if b.Requests != 1 {
|
||||
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
|
||||
}
|
||||
if b.PromptTokens != 100 {
|
||||
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
|
||||
}
|
||||
if b.CompletionTokens != 200 {
|
||||
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerConcurrent(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Record("model-a", 1, 2)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if len(stats.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
m := stats.Usage[0]
|
||||
if m.Requests != 100 {
|
||||
t.Errorf("requests: expected 100, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 100 {
|
||||
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 200 {
|
||||
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerStart(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if stats.Start.IsZero() {
|
||||
t.Error("expected non-zero start time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
}
|
||||
|
||||
s.usage.Record("llama3", 50, 100)
|
||||
s.usage.Record("llama3", 25, 50)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
|
||||
|
||||
s.UsageHandler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.UsageResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
|
||||
}
|
||||
|
||||
m := resp.Usage[0]
|
||||
if m.Model != "llama3" {
|
||||
t.Errorf("expected model llama3, got %s", m.Model)
|
||||
}
|
||||
if m.Requests != 2 {
|
||||
t.Errorf("expected 2 requests, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 75 {
|
||||
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 150 {
|
||||
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user