Mirror of https://github.com/ollama/ollama.git (synced 2026-01-23 06:53:03 -05:00)

Compare commits: 1 commit on branch parth/decr

Commit d132315276: usage
api/client.go

@@ -377,6 +377,15 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
 	return &lr, nil
 }

+// Usage returns usage statistics and system info.
+func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
+	var ur UsageResponse
+	if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
+		return nil, err
+	}
+	return &ur, nil
+}
+
 // Copy copies a model - creating a model with another name from an existing
 // model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
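For orientation, a minimal sketch of calling the new method from Go. api.ClientFromEnvironment is the api package's existing constructor; everything else uses only what this commit adds:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	usage, err := client.Usage(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, gpu := range usage.GPUs {
		// Used is memory attributed to Ollama; Other is everything else on the device.
		fmt.Printf("%s (%s): %d of %d bytes in use by Ollama\n",
			gpu.Name, gpu.Backend, gpu.Used, gpu.Total)
	}
}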
api/types.go (27 changed lines)

@@ -792,6 +792,33 @@ type ProcessResponse struct {
 	Models []ProcessModelResponse `json:"models"`
 }

+// UsageResponse is the response from [Client.Usage].
+type UsageResponse struct {
+	GPUs []GPUUsage `json:"gpus,omitempty"`
+}
+
+// GPUUsage contains GPU/device memory usage breakdown.
+type GPUUsage struct {
+	Name    string `json:"name"`    // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
+	Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
+	Total   uint64 `json:"total"`
+	Free    uint64 `json:"free"`
+	Used    uint64 `json:"used"`  // Memory used by Ollama
+	Other   uint64 `json:"other"` // Memory used by other processes
+}
+
+// UsageStats contains usage statistics.
+type UsageStats struct {
+	Requests         int64            `json:"requests"`
+	TokensInput      int64            `json:"tokens_input"`
+	TokensOutput     int64            `json:"tokens_output"`
+	TotalTokens      int64            `json:"total_tokens"`
+	Models           map[string]int64 `json:"models,omitempty"`
+	Sources          map[string]int64 `json:"sources,omitempty"`
+	ToolCalls        int64            `json:"tool_calls,omitempty"`
+	StructuredOutput int64            `json:"structured_output,omitempty"`
+}
+
 // ListModelResponse is a single model description in [ListResponse].
 type ListModelResponse struct {
 	Name    string `json:"name"`
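Given those struct tags, a GET /api/usage response takes roughly this shape (device name and byte counts below are illustrative, not from the diff):

{
  "gpus": [
    {
      "name": "NVIDIA GeForce RTX 4090",
      "backend": "CUDA",
      "total": 25757220864,
      "free": 17179869184,
      "used": 6442450944,
      "other": 2134900736
    }
  ]
}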
cmd/cmd.go

@@ -1833,6 +1833,7 @@ func NewCLI() *cobra.Command {
 		PreRunE: checkServerHeartbeat,
 		RunE:    ListRunningHandler,
 	}

 	copyCmd := &cobra.Command{
 		Use:     "cp SOURCE DESTINATION",
 		Short:   "Copy a model",
envconfig/config.go

@@ -206,6 +206,8 @@ var (
 	UseAuth = Bool("OLLAMA_AUTH")
 	// Enable Vulkan backend
 	EnableVulkan = Bool("OLLAMA_VULKAN")
+	// Usage enables usage statistics reporting
+	Usage = Bool("OLLAMA_USAGE")
 )

 func String(s string) func() string {
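Because the flag goes through the same Bool helper as the other OLLAMA_* toggles, reporting stays off unless opted into from the environment, e.g.:

OLLAMA_USAGE=1 ollama serve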
server/routes.go (128 changed lines)

@@ -20,6 +20,7 @@ import (
 	"net/url"
 	"os"
 	"os/signal"
+	"runtime"
 	"slices"
 	"strings"
 	"sync/atomic"
@@ -44,6 +45,7 @@ import (
 	"github.com/ollama/ollama/model/renderers"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
+	"github.com/ollama/ollama/server/usage"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/thinking"
 	"github.com/ollama/ollama/tools"
@@ -82,6 +84,7 @@ type Server struct {
 	addr    net.Addr
 	sched   *Scheduler
 	lowVRAM bool
+	stats   *usage.Stats
 }

 func init() {
@@ -104,6 +107,30 @@ var (
 	errBadTemplate = errors.New("template error")
 )

+// usage records a request to usage stats if enabled.
+func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
+	if s.stats == nil {
+		return
+	}
+	s.stats.Record(&usage.Request{
+		Endpoint:         endpoint,
+		Model:            model,
+		Architecture:     architecture,
+		APIType:          usage.ClassifyAPIType(c.Request.URL.Path),
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		UsedTools:        usedTools,
+	})
+}
+
+// usageError records a failed request to usage stats if enabled.
+func (s *Server) usageError() {
+	if s.stats == nil {
+		return
+	}
+	s.stats.RecordError()
+}
+
 func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -374,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}

@@ -561,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		res.DoneReason = cr.DoneReason.String()
 		res.TotalDuration = time.Since(checkpointStart)
 		res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+		s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)

 		if !req.Raw {
 			tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
@@ -680,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {

 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}

@@ -790,6 +818,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
 		PromptEvalCount: int(totalTokens),
 	}
+	s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
 	c.JSON(http.StatusOK, resp)
 }

@@ -827,7 +856,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {

 	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}

@@ -1531,6 +1560,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {

 	// Inference
 	r.GET("/api/ps", s.PsHandler)
+	r.GET("/api/usage", s.UsageHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
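With the route registered alongside /api/ps, the endpoint can be exercised directly against a local server (default port shown):

curl http://localhost:11434/api/usage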
@@ -1593,6 +1623,13 @@ func Serve(ln net.Listener) error {

 	s := &Server{addr: ln.Addr()}

+	// Initialize usage stats if enabled
+	if envconfig.Usage() {
+		s.stats = usage.New()
+		s.stats.Start()
+		slog.Info("usage stats enabled")
+	}
+
 	var rc *ollama.Registry
 	if useClient2 {
 		var err error
@@ -1632,6 +1669,9 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
+		if s.stats != nil {
+			s.stats.Stop()
+		}
 		srvr.Close()
 		schedDone()
 		sched.unloadAllRunners()
@@ -1649,6 +1689,24 @@ func Serve(ln net.Listener) error {
 	gpus := discover.GPUDevices(ctx, nil)
 	discover.LogDetails(gpus)

+	// Set GPU info for usage reporting
+	if s.stats != nil {
+		usage.GPUInfoFunc = func() []usage.GPU {
+			var result []usage.GPU
+			for _, gpu := range gpus {
+				result = append(result, usage.GPU{
+					Name:         gpu.Name,
+					VRAMBytes:    gpu.TotalMemory,
+					ComputeMajor: gpu.ComputeMajor,
+					ComputeMinor: gpu.ComputeMinor,
+					DriverMajor:  gpu.DriverMajor,
+					DriverMinor:  gpu.DriverMinor,
+				})
+			}
+			return result
+		}
+	}
+
 	var totalVRAM uint64
 	for _, gpu := range gpus {
 		totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
@@ -1852,6 +1910,63 @@ func (s *Server) PsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }

+func (s *Server) UsageHandler(c *gin.Context) {
+	// Get total VRAM used by Ollama
+	s.sched.loadedMu.Lock()
+	var totalOllamaVRAM uint64
+	for _, runner := range s.sched.loaded {
+		totalOllamaVRAM += runner.vramSize
+	}
+	s.sched.loadedMu.Unlock()
+
+	var resp api.UsageResponse
+
+	// Get GPU/device info
+	gpus := discover.GPUDevices(c.Request.Context(), nil)
+
+	// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
+	// because unified memory means GPU and CPU share the same physical RAM pool
+	var sysTotal, sysFree uint64
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		sysInfo := discover.GetSystemInfo()
+		sysTotal = sysInfo.TotalMemory
+		sysFree = sysInfo.FreeMemory
+	}
+
+	for _, gpu := range gpus {
+		total := gpu.TotalMemory
+		free := gpu.FreeMemory
+
+		// On Apple Silicon, override with system memory values
+		if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
+			total = sysTotal
+			free = sysFree
+		}
+
+		used := total - free
+		ollamaUsed := min(totalOllamaVRAM, used)
+		otherUsed := used - ollamaUsed
+
+		// Use Description for Name (actual device name like "Apple M2 Max")
+		// Fall back to backend name if Description is empty
+		name := gpu.Description
+		if name == "" {
+			name = gpu.Name
+		}
+
+		resp.GPUs = append(resp.GPUs, api.GPUUsage{
+			Name:    name,
+			Backend: gpu.Library,
+			Total:   total,
+			Free:    free,
+			Used:    ollamaUsed,
+			Other:   otherUsed,
+		})
+	}

+	c.JSON(http.StatusOK, resp)
+}
+
 func toolCallId() string {
 	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
 	b := make([]byte, 8)
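To make the handler's accounting concrete with illustrative numbers: on a 24 GiB device reporting 16 GiB free while Ollama's loaded runners account for 10 GiB, used = 24 - 16 = 8 GiB, ollamaUsed = min(10, 8) = 8 GiB (the clamp matters because scheduler bookkeeping can exceed what the driver reports as consumed), and otherUsed = 8 - 8 = 0.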
@@ -2032,7 +2147,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, req.Model, err)
+		s.handleScheduleError(c, req.Model, err)
 		return
 	}

@@ -2180,6 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
 		}

 		if builtinParser != nil {
@@ -2355,6 +2471,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			resp.Message.ToolCalls = toolCalls
 		}

+		s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
 		c.JSON(http.StatusOK, resp)
 		return
 	}
@@ -2362,7 +2479,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

-func handleScheduleError(c *gin.Context, name string, err error) {
+func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
+	s.usageError()
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
server/routes_usage_test.go (60 lines, new file)

@@ -0,0 +1,60 @@
+package server
+
+import (
+	"encoding/json"
+	"net/http"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestUsageHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	t.Run("empty server", func(t *testing.T) {
+		s := Server{
+			sched: &Scheduler{
+				loaded: make(map[string]*runnerRef),
+			},
+		}
+
+		w := createRequest(t, s.UsageHandler, nil)
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		var resp api.UsageResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		// GPUs may or may not be present depending on system
+		// Just verify we can decode the response
+	})
+
+	t.Run("response structure", func(t *testing.T) {
+		s := Server{
+			sched: &Scheduler{
+				loaded: make(map[string]*runnerRef),
+			},
+		}
+
+		w := createRequest(t, s.UsageHandler, nil)
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		// Verify we can decode the response as valid JSON
+		var resp map[string]any
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		// The response should be a valid object (not null)
+		if resp == nil {
+			t.Error("expected non-nil response")
+		}
+	})
+}
server/usage/reporter.go (65 lines, new file)

@@ -0,0 +1,65 @@
+package usage
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/ollama/ollama/version"
+)
+
+const (
+	reportTimeout = 10 * time.Second
+	usageURL      = "https://ollama.com/api/usage"
+)
+
+// HeartbeatResponse is the response from the heartbeat endpoint.
+type HeartbeatResponse struct {
+	UpdateVersion string `json:"update_version,omitempty"`
+}
+
+// UpdateAvailable returns the available update version, if any.
+func (t *Stats) UpdateAvailable() string {
+	if v := t.updateAvailable.Load(); v != nil {
+		return v.(string)
+	}
+	return ""
+}
+
+// sendHeartbeat sends usage stats and checks for updates.
+func (t *Stats) sendHeartbeat(payload *Payload) {
+	data, err := json.Marshal(payload)
+	if err != nil {
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
+	if err != nil {
+		return
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return
+	}
+
+	var heartbeat HeartbeatResponse
+	if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
+		return
+	}
+
+	t.updateAvailable.Store(heartbeat.UpdateVersion)
+}
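The heartbeat is a plain JSON POST to usageURL. A server wanting to advertise an update would answer with a body like the following (version string illustrative), which sendHeartbeat stores and UpdateAvailable later surfaces:

{"update_version": "0.6.0"}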
server/usage/source.go (23 lines, new file)

@@ -0,0 +1,23 @@
+package usage
+
+import (
+	"strings"
+)
+
+// API type constants
+const (
+	APITypeOllama    = "ollama"
+	APITypeOpenAI    = "openai"
+	APITypeAnthropic = "anthropic"
+)
+
+// ClassifyAPIType determines the API type from the request path.
+func ClassifyAPIType(path string) string {
+	if strings.HasPrefix(path, "/v1/messages") {
+		return APITypeAnthropic
+	}
+	if strings.HasPrefix(path, "/v1/") {
+		return APITypeOpenAI
+	}
+	return APITypeOllama
+}
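Because the "/v1/messages" check runs before the generic "/v1/" prefix, classification resolves as:

usage.ClassifyAPIType("/v1/messages")         // "anthropic"
usage.ClassifyAPIType("/v1/chat/completions") // "openai"
usage.ClassifyAPIType("/api/chat")            // "ollama"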
server/usage/usage.go (324 lines, new file)

@@ -0,0 +1,324 @@
+// Package usage provides in-memory usage statistics collection and reporting.
+package usage
+
+import (
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/version"
+)
+
+// Stats collects usage statistics in memory and reports them periodically.
+type Stats struct {
+	mu sync.RWMutex
+
+	// Atomic counters for hot path
+	requestsTotal    atomic.Int64
+	tokensPrompt     atomic.Int64
+	tokensCompletion atomic.Int64
+	errorsTotal      atomic.Int64
+
+	// Map-based counters (require lock)
+	endpoints     map[string]int64
+	architectures map[string]int64
+	apis          map[string]int64
+	models        map[string]*ModelStats // per-model stats
+
+	// Feature usage
+	toolCalls        atomic.Int64
+	structuredOutput atomic.Int64
+
+	// Update info (set by reporter after pinging update endpoint)
+	updateAvailable atomic.Value // string
+
+	// Reporter
+	stopCh   chan struct{}
+	doneCh   chan struct{}
+	interval time.Duration
+	endpoint string
+}
+
+// ModelStats tracks per-model usage statistics.
+type ModelStats struct {
+	Requests     int64
+	TokensInput  int64
+	TokensOutput int64
+}
+
+// Request contains the data to record for a single request.
+type Request struct {
+	Endpoint         string // "chat", "generate", "embed"
+	Model            string // model name (e.g., "llama3.2:3b")
+	Architecture     string // model architecture (e.g., "llama", "qwen2")
+	APIType          string // "ollama", "openai", or "anthropic" (see ClassifyAPIType)
+	PromptTokens     int
+	CompletionTokens int
+	UsedTools        bool
+	StructuredOutput bool
+}
+
+// SystemInfo contains hardware information to report.
+type SystemInfo struct {
+	OS       string `json:"os"`
+	Arch     string `json:"arch"`
+	CPUCores int    `json:"cpu_cores"`
+	RAMBytes uint64 `json:"ram_bytes"`
+	GPUs     []GPU  `json:"gpus,omitempty"`
+}
+
+// GPU contains information about a GPU.
+type GPU struct {
+	Name         string `json:"name"`
+	VRAMBytes    uint64 `json:"vram_bytes"`
+	ComputeMajor int    `json:"compute_major,omitempty"`
+	ComputeMinor int    `json:"compute_minor,omitempty"`
+	DriverMajor  int    `json:"driver_major,omitempty"`
+	DriverMinor  int    `json:"driver_minor,omitempty"`
+}
+
+// Payload is the data sent to the heartbeat endpoint.
+type Payload struct {
+	Version string     `json:"version"`
+	Time    time.Time  `json:"time"`
+	System  SystemInfo `json:"system"`
+
+	Totals struct {
+		Requests     int64 `json:"requests"`
+		Errors       int64 `json:"errors"`
+		InputTokens  int64 `json:"input_tokens"`
+		OutputTokens int64 `json:"output_tokens"`
+	} `json:"totals"`
+
+	Endpoints     map[string]int64 `json:"endpoints"`
+	Architectures map[string]int64 `json:"architectures"`
+	APIs          map[string]int64 `json:"apis"`
+
+	Features struct {
+		ToolCalls        int64 `json:"tool_calls"`
+		StructuredOutput int64 `json:"structured_output"`
+	} `json:"features"`
+}
+
+const (
+	defaultInterval = 1 * time.Hour
+)
+
+// New creates a new Stats instance.
+func New(opts ...Option) *Stats {
+	t := &Stats{
+		endpoints:     make(map[string]int64),
+		architectures: make(map[string]int64),
+		apis:          make(map[string]int64),
+		models:        make(map[string]*ModelStats),
+		stopCh:        make(chan struct{}),
+		doneCh:        make(chan struct{}),
+		interval:      defaultInterval,
+	}
+
+	for _, opt := range opts {
+		opt(t)
+	}
+
+	return t
+}
+
+// Option configures the Stats instance.
+type Option func(*Stats)
+
+// WithInterval sets the reporting interval.
+func WithInterval(d time.Duration) Option {
+	return func(t *Stats) {
+		t.interval = d
+	}
+}
+
+// Record records a request. This is the hot path and should be fast.
+func (t *Stats) Record(r *Request) {
+	t.requestsTotal.Add(1)
+	t.tokensPrompt.Add(int64(r.PromptTokens))
+	t.tokensCompletion.Add(int64(r.CompletionTokens))
+
+	if r.UsedTools {
+		t.toolCalls.Add(1)
+	}
+	if r.StructuredOutput {
+		t.structuredOutput.Add(1)
+	}
+
+	t.mu.Lock()
+	t.endpoints[r.Endpoint]++
+	t.architectures[r.Architecture]++
+	t.apis[r.APIType]++
+
+	// Track per-model stats
+	if r.Model != "" {
+		if t.models[r.Model] == nil {
+			t.models[r.Model] = &ModelStats{}
+		}
+		t.models[r.Model].Requests++
+		t.models[r.Model].TokensInput += int64(r.PromptTokens)
+		t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
+	}
+	t.mu.Unlock()
+}
+
+// RecordError records a failed request.
+func (t *Stats) RecordError() {
+	t.errorsTotal.Add(1)
+}
+
+// GetModelStats returns a copy of per-model statistics.
+func (t *Stats) GetModelStats() map[string]*ModelStats {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	result := make(map[string]*ModelStats, len(t.models))
+	for k, v := range t.models {
+		result[k] = &ModelStats{
+			Requests:     v.Requests,
+			TokensInput:  v.TokensInput,
+			TokensOutput: v.TokensOutput,
+		}
+	}
+	return result
+}
+
+// View returns current stats without resetting counters.
+func (t *Stats) View() *Payload {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	now := time.Now()
+
+	// Copy maps
+	endpoints := make(map[string]int64, len(t.endpoints))
+	for k, v := range t.endpoints {
+		endpoints[k] = v
+	}
+	architectures := make(map[string]int64, len(t.architectures))
+	for k, v := range t.architectures {
+		architectures[k] = v
+	}
+	apis := make(map[string]int64, len(t.apis))
+	for k, v := range t.apis {
+		apis[k] = v
+	}
+
+	p := &Payload{
+		Version:       version.Version,
+		Time:          now,
+		System:        getSystemInfo(),
+		Endpoints:     endpoints,
+		Architectures: architectures,
+		APIs:          apis,
+	}
+
+	p.Totals.Requests = t.requestsTotal.Load()
+	p.Totals.Errors = t.errorsTotal.Load()
+	p.Totals.InputTokens = t.tokensPrompt.Load()
+	p.Totals.OutputTokens = t.tokensCompletion.Load()
+	p.Features.ToolCalls = t.toolCalls.Load()
+	p.Features.StructuredOutput = t.structuredOutput.Load()
+
+	return p
+}
+
+// Snapshot returns current stats and resets counters.
+func (t *Stats) Snapshot() *Payload {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	now := time.Now()
+	p := &Payload{
+		Version:       version.Version,
+		Time:          now,
+		System:        getSystemInfo(),
+		Endpoints:     t.endpoints,
+		Architectures: t.architectures,
+		APIs:          t.apis,
+	}
+
+	p.Totals.Requests = t.requestsTotal.Swap(0)
+	p.Totals.Errors = t.errorsTotal.Swap(0)
+	p.Totals.InputTokens = t.tokensPrompt.Swap(0)
+	p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
+	p.Features.ToolCalls = t.toolCalls.Swap(0)
+	p.Features.StructuredOutput = t.structuredOutput.Swap(0)
+
+	// Reset maps
+	t.endpoints = make(map[string]int64)
+	t.architectures = make(map[string]int64)
+	t.apis = make(map[string]int64)
+
+	return p
+}
+
+// getSystemInfo collects hardware information.
+func getSystemInfo() SystemInfo {
+	info := SystemInfo{
+		OS:   runtime.GOOS,
+		Arch: runtime.GOARCH,
+	}
+
+	// Get CPU and memory info
+	sysInfo := discover.GetSystemInfo()
+	info.CPUCores = sysInfo.ThreadCount
+	info.RAMBytes = sysInfo.TotalMemory
+
+	// Get GPU info
+	gpus := getGPUInfo()
+	info.GPUs = gpus
+
+	return info
+}
+
+// GPUInfoFunc is a function that returns GPU information.
+// It's set by the server package after GPU discovery.
+var GPUInfoFunc func() []GPU
+
+// getGPUInfo collects GPU information.
+func getGPUInfo() []GPU {
+	if GPUInfoFunc != nil {
+		return GPUInfoFunc()
+	}
+	return nil
+}
+
+// Start begins the periodic reporting goroutine.
+func (t *Stats) Start() {
+	go t.reportLoop()
+}
+
+// Stop stops reporting and waits for the final report.
+func (t *Stats) Stop() {
+	close(t.stopCh)
+	<-t.doneCh
+}
+
+// reportLoop runs the periodic reporting.
+func (t *Stats) reportLoop() {
+	defer close(t.doneCh)
+
+	ticker := time.NewTicker(t.interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			t.report()
+		case <-t.stopCh:
+			// Send final report before stopping
+			t.report()
+			return
+		}
+	}
+}
+
+// report sends usage stats and checks for updates.
+func (t *Stats) report() {
+	payload := t.Snapshot()
+	t.sendHeartbeat(payload)
+}
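Outside the server, the collector can also be driven directly; a minimal sketch (the one-minute interval and the request values are illustrative, and note that Stop, and therefore the deferred call below, triggers one final heartbeat to ollama.com):

package main

import (
	"time"

	"github.com/ollama/ollama/server/usage"
)

func main() {
	st := usage.New(usage.WithInterval(time.Minute)) // illustrative interval
	st.Start()
	defer st.Stop() // sends a final report before returning

	st.Record(&usage.Request{
		Endpoint:         "chat",
		Model:            "llama3:8b",
		PromptTokens:     12,
		CompletionTokens: 7,
	})
}

The split between atomic counters and the mutex-guarded maps keeps Record cheap on the hot path; View copies the maps under a read lock, while Snapshot hands its maps to the payload and reallocates fresh ones, which is why it takes the write lock.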
server/usage/usage_test.go (194 lines, new file)

@@ -0,0 +1,194 @@
+package usage
+
+import (
+	"testing"
+)
+
+func TestNew(t *testing.T) {
+	stats := New()
+	if stats == nil {
+		t.Fatal("New() returned nil")
+	}
+}
+
+func TestRecord(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		Endpoint:         "chat",
+		Architecture:     "llama",
+		APIType:          "native",
+		PromptTokens:     100,
+		CompletionTokens: 50,
+		UsedTools:        true,
+		StructuredOutput: false,
+	})
+
+	// Check totals
+	payload := stats.View()
+	if payload.Totals.Requests != 1 {
+		t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
+	}
+	if payload.Totals.InputTokens != 100 {
+		t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
+	}
+	if payload.Totals.OutputTokens != 50 {
+		t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
+	}
+	if payload.Features.ToolCalls != 1 {
+		t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
+	}
+	if payload.Features.StructuredOutput != 0 {
+		t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
+	}
+}
+
+func TestGetModelStats(t *testing.T) {
+	stats := New()
+
+	// Record requests for multiple models
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		PromptTokens:     100,
+		CompletionTokens: 50,
+	})
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		PromptTokens:     200,
+		CompletionTokens: 100,
+	})
+	stats.Record(&Request{
+		Model:            "mistral:7b",
+		PromptTokens:     50,
+		CompletionTokens: 25,
+	})
+
+	modelStats := stats.GetModelStats()
+
+	// Check llama3:8b stats
+	llama := modelStats["llama3:8b"]
+	if llama == nil {
+		t.Fatal("expected llama3:8b stats")
+	}
+	if llama.Requests != 2 {
+		t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
+	}
+	if llama.TokensInput != 300 {
+		t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
+	}
+	if llama.TokensOutput != 150 {
+		t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
+	}
+
+	// Check mistral:7b stats
+	mistral := modelStats["mistral:7b"]
+	if mistral == nil {
+		t.Fatal("expected mistral:7b stats")
+	}
+	if mistral.Requests != 1 {
+		t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
+	}
+	if mistral.TokensInput != 50 {
+		t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
+	}
+	if mistral.TokensOutput != 25 {
+		t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
+	}
+}
+
+func TestRecordError(t *testing.T) {
+	stats := New()
+
+	stats.RecordError()
+	stats.RecordError()
+
+	payload := stats.View()
+	if payload.Totals.Errors != 2 {
+		t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
+	}
+}
+
+func TestView(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:        "llama3:8b",
+		Endpoint:     "chat",
+		Architecture: "llama",
+		APIType:      "native",
+	})
+
+	// First view
+	_ = stats.View()
+
+	// View should not reset counters
+	payload := stats.View()
+	if payload.Totals.Requests != 1 {
+		t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
+	}
+}
+
+func TestSnapshot(t *testing.T) {
+	stats := New()
+
+	stats.Record(&Request{
+		Model:            "llama3:8b",
+		Endpoint:         "chat",
+		PromptTokens:     100,
+		CompletionTokens: 50,
+	})
+
+	// Snapshot should return data and reset counters
+	snapshot := stats.Snapshot()
+	if snapshot.Totals.Requests != 1 {
+		t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
+	}
+
+	// After snapshot, counters should be reset
+	payload2 := stats.View()
+	if payload2.Totals.Requests != 0 {
+		t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
+	}
+}
+
+func TestConcurrentAccess(t *testing.T) {
+	stats := New()
+
+	done := make(chan bool)
+
+	// Concurrent writes
+	for i := 0; i < 10; i++ {
+		go func() {
+			for j := 0; j < 100; j++ {
+				stats.Record(&Request{
+					Model:            "llama3:8b",
+					PromptTokens:     10,
+					CompletionTokens: 5,
+				})
+			}
+			done <- true
+		}()
+	}
+
+	// Concurrent reads
+	for i := 0; i < 5; i++ {
+		go func() {
+			for j := 0; j < 100; j++ {
+				_ = stats.View()
+				_ = stats.GetModelStats()
+			}
+			done <- true
+		}()
+	}
+
+	// Wait for all goroutines
+	for i := 0; i < 15; i++ {
+		<-done
+	}
+
+	payload := stats.View()
+	if payload.Totals.Requests != 1000 {
+		t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
+	}
+}