From f15b9178ec25ae0057bbb09abcf13dd44061a51f Mon Sep 17 00:00:00 2001 From: "LocalAI [bot]" <139863280+localai-bot@users.noreply.github.com> Date: Thu, 21 May 2026 16:34:02 +0200 Subject: [PATCH] feat(usage): track and visualise usage per API key (#9920) * feat(usage): add Source, APIKeyID, APIKeyName columns to UsageRecord Adds three additive columns plus UsageSource* constants. The columns are auto-migrated by InitDB. APIKeyID is a nullable foreign reference to UserAPIKey.ID; APIKeyName is snapshotted on each row so revoked keys keep showing their name in history. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): backfill Source on pre-feature usage rows InitDB now classifies any pre-existing usage_record with an empty source: 'legacy-api-key' user -> legacy, everything else -> web. The backfill is idempotent (only touches NULL/empty rows). Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): add GetUserUsageBySource aggregator Groups by (bucket, source, api_key_id, api_key_name). Filters out legacy by default. Returns both per-bucket detail and roll-ups (by_source, by_key sorted desc and capped at 200, grand_total). The MAX(created_at) projection is iterated via Rows().Scan into a string column and parsed manually because the SQLite driver surfaces the aggregated timestamp as a string, which database/sql refuses to scan directly into time.Time. Postgres returns a real timestamp; the same string path handles its RFC3339 form too. Refs: #9862 Signed-off-by: Ettore Di Giacinto * fix(usage): log Rows() errors and assert LastUsed in tests Adds rows.Err() and Rows() open-failure logging in computeSourceTotals so silent data drops surface in logs. Logs on parseLastUsedString format misses for the same reason. Strengthens the snapshot-survival test to assert LastUsed is a recent timestamp, locking the SQLite time-string parser behaviour. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): add admin GetAllUsageBySource with filters and truncation Optional user_id and api_key_id filters (composed with AND). Legacy bucket is included for admin callers. truncated=true when more than 200 distinct keys would be in the by_key roll-up. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(auth): plumb auth_source and auth_apikey through Echo context tryAuthenticate now sets auth_source on every successful branch (web for session/Bearer-session, apikey for Bearer-key/x-api-key/ token-cookie, legacy for legacy env key match). For named-key branches it also stores the resolved *UserAPIKey under auth_apikey so downstream middlewares can snapshot id+name without re-validating. Refs: #9862 Signed-off-by: Ettore Di Giacinto * fix(auth): expand tryAuthenticate godoc and cover Bearer-session branch Documents all three context-keys side effects (auth_source, auth_apikey, _auth_session) plus the split of responsibilities with the parent Middleware. Adds a test for the Bearer-as-session-token classification so future regressions there fail loudly. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): UsageMiddleware records source + snapshots key name Reads auth_source and auth_apikey from the Echo context (set by auth.Middleware in the previous task). Snapshots UserAPIKey.ID and Name onto each row so revoked keys remain readable in history. Falls back to source=web when no auth_source is set (auth disabled or unrecognised path). Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): add /api/auth/usage/sources and admin variant Self endpoint filters legacy server-side; admin endpoint includes legacy and accepts user_id + api_key_id filters. Response includes buckets, totals.{by_source, by_key, grand_total}, and a truncated flag set when the per-key roll-up was capped at 200. Refs: #9862 Signed-off-by: Ettore Di Giacinto * docs(routes): mark test mirror handlers as keep-in-sync with production The newTestAuthApp helper duplicates production route handlers inline because it cannot use RegisterAuthRoutes (which requires a *application.Application). Naming the source path on each mirror makes the drift contract explicit for future maintainers. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(ui): add usageApi.getMySources/getAdminSources + i18n strings Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(ui): add Sources tab skeleton with data fetch Adds Usage page tab that fetches /api/auth/usage/sources (or the admin variant). Renders raw totals plus a placeholder key list; real visualisations land in subsequent commits. Restructures the existing tab button block so Models and Sources are visible to non-admins (Users remains admin-only). Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(ui): source mix ribbon + searchable/sortable sources table Replaces the SourcesTab placeholder rendering with two reusable components: SourceMixRibbon (one segmented bar per source class) and SourcesTable (search + sort + revoked-key dim). Pulls the current API key list to detect revoked keys. Refs: #9862 Signed-off-by: Ettore Di Giacinto * fix(ui): skip revoked-key detection until the key list is known existingKeyIds defaulted to an empty Set, which made every live api_key row render as (revoked) during the brief window before apiKeysApi.list() resolved, and permanently after a fetch failure. Use null as the unknown state and suppress the revoked badge until the parent provides a real Set. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(ui): top-N stacked time chart and drill-in chip for Sources tab Top 7 sources by total tokens get distinct colours; the rest roll up into 'Other'. Clicking a row in the SourcesTable dims everything except that series in the chart; the chip is the canonical clear. Refs: #9862 Signed-off-by: Ettore Di Giacinto * docs(usage): document per-API-key Sources tab and endpoints Extends features/authentication.md Usage Tracking section with: - A 'Sources' tab description and source-class taxonomy - Endpoint documentation for /api/auth/usage/sources and the admin variant - Response shape example with by_source / by_key / grand_total - Migration note about pre-feature row backfill Refs: #9862 Signed-off-by: Ettore Di Giacinto * fix(usage): silence errcheck on deferred rows.Close CI errcheck flagged the bare 'defer rows.Close()' in computeSourceTotals. Wrap in a closure that discards the close error explicitly; an error here is non-actionable since we have already drained the rows and logged any iteration failure. Refs: #9862 Signed-off-by: Ettore Di Giacinto * refactor(usage): bound batcher intake and add Shutdown/FlushNow hooks The pre-existing usage batcher had no cap on its add() path; the usageMaxPending=5000 constant only guarded the re-queue path after a failed write, leaving memory growth unbounded if the DB fell behind. This commit: - Adds the cap to add() so saturation drops new records (rate-limited warn at 1/1024) instead of growing unbounded. - Raises usageMaxPending to 50000 to absorb realistic inference bursts. - Replaces the package-level batcher global with a mutex-guarded pair plus a currentBatcher() accessor so Init / Shutdown cycles are race-free. - Adds ShutdownUsageRecorder() for graceful drain on process exit (not yet wired into app shutdown, just published). - Adds FlushNow() for deterministic tests; the middleware suite no longer needs 6s sleeps per spec and now runs in ~50ms instead of 18s. - Re-queue on failed flush is now cap-aware: prepends as much of the failed batch as fits alongside concurrent arrivals, instead of dropping the whole batch when full. Refs: #9862 Signed-off-by: Ettore Di Giacinto * feat(usage): drain usage batcher on graceful shutdown Registers ShutdownUsageRecorder with the existing signals.RegisterGracefulTerminationHandler so SIGINT/SIGTERM synchronously flushes any in-memory usage records before the process exits. Without this, up to one flush interval (5s) of recorded usage was lost when LocalAI restarted. Refs: #9862 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto Co-authored-by: Ettore Di Giacinto --- core/http/app.go | 6 +- core/http/auth/db.go | 8 +- core/http/auth/middleware.go | 59 +++- core/http/auth/middleware_test.go | 118 ++++++++ core/http/auth/usage.go | 263 +++++++++++++++++- core/http/auth/usage_test.go | 192 +++++++++++++ core/http/middleware/usage.go | 144 ++++++++-- core/http/middleware/usage_test.go | 140 ++++++++++ .../react-ui/public/locales/en/admin.json | 24 +- core/http/react-ui/src/pages/Usage.jsx | 39 ++- .../src/pages/Usage/SourceMixRibbon.jsx | 83 ++++++ .../src/pages/Usage/SourceTimeChart.jsx | 147 ++++++++++ .../react-ui/src/pages/Usage/SourcesTab.jsx | 175 ++++++++++++ .../react-ui/src/pages/Usage/SourcesTable.jsx | 203 ++++++++++++++ core/http/react-ui/src/utils/api.js | 8 + core/http/routes/auth.go | 45 +++ core/http/routes/auth_test.go | 145 ++++++++++ docs/content/features/authentication.md | 79 +++++- 18 files changed, 1822 insertions(+), 56 deletions(-) create mode 100644 core/http/middleware/usage_test.go create mode 100644 core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx create mode 100644 core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx create mode 100644 core/http/react-ui/src/pages/Usage/SourcesTab.jsx create mode 100644 core/http/react-ui/src/pages/Usage/SourcesTable.jsx diff --git a/core/http/app.go b/core/http/app.go index 99d11bd69..464e506db 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -28,6 +28,7 @@ import ( "github.com/mudler/LocalAI/core/services/monitoring" "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/quantization" + "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/xlog" ) @@ -267,9 +268,12 @@ func API(application *application.Application) (*echo.Echo, error) { e.Static("/generated-videos", videoPath) } - // Initialize usage recording when auth DB is available + // Initialize usage recording when auth DB is available, and ensure the + // batcher drains its in-memory queue on graceful shutdown so the last + // few seconds of usage don't disappear when the process exits. if application.AuthDB() != nil { httpMiddleware.InitUsageRecorder(application.AuthDB()) + signals.RegisterGracefulTerminationHandler(httpMiddleware.ShutdownUsageRecorder) } // Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is diff --git a/core/http/auth/db.go b/core/http/auth/db.go index d860e5068..5d94557ea 100644 --- a/core/http/auth/db.go +++ b/core/http/auth/db.go @@ -38,9 +38,15 @@ func InitDB(databaseURL string) (*gorm.DB, error) { } // Backfill: users created before the provider column existed have an empty - // provider — treat them as local accounts so the UI can identify them. + // provider - treat them as local accounts so the UI can identify them. db.Exec("UPDATE users SET provider = ? WHERE provider = '' OR provider IS NULL", ProviderLocal) + // Backfill: pre-feature usage_records have no source column. Classify them so the + // new per-source aggregators include them. + if err := BackfillUsageSource(db); err != nil { + return nil, fmt.Errorf("failed to backfill usage source: %w", err) + } + // Create composite index on users(provider, subject) for fast OAuth lookups if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil { // Ignore error on postgres if index already exists diff --git a/core/http/auth/middleware.go b/core/http/auth/middleware.go index 01ec33a68..c67954640 100644 --- a/core/http/auth/middleware.go +++ b/core/http/auth/middleware.go @@ -16,8 +16,10 @@ import ( ) const ( - contextKeyUser = "auth_user" - contextKeyRole = "auth_role" + contextKeyUser = "auth_user" + contextKeyRole = "auth_role" + contextKeyAPIKey = "auth_apikey" + contextKeySource = "auth_source" ) // Middleware returns an Echo middleware that handles authentication. @@ -75,6 +77,7 @@ func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.Middlewar } c.Set(contextKeyUser, syntheticUser) c.Set(contextKeyRole, RoleAdmin) + c.Set(contextKeySource, UsageSourceLegacy) authenticated = true } } @@ -213,6 +216,20 @@ func GetUserRole(c echo.Context) string { return role } +// GetAPIKey returns the resolved API key from the echo context, or nil. +// Nil for session-cookie and legacy-env-key authentication. +func GetAPIKey(c echo.Context) *UserAPIKey { + k, _ := c.Get(contextKeyAPIKey).(*UserAPIKey) + return k +} + +// GetSource returns the request's authentication source: UsageSourceAPIKey, +// UsageSourceWeb, UsageSourceLegacy, or empty if no authentication was performed. +func GetSource(c echo.Context) string { + s, _ := c.Get(contextKeySource).(string) + return s +} + // RequireRouteFeature returns a global middleware that checks the user has access // to the feature required by the matched route. It uses the RouteFeatureRegistry // to look up the required feature for each route pattern + HTTP method. @@ -421,47 +438,67 @@ func RequireQuota(db *gorm.DB) echo.MiddlewareFunc { } // tryAuthenticate attempts to authenticate the request using the database. +// +// On success it returns the user and, as a side effect, sets the following +// values on the Echo context: +// - contextKeySource ("auth_source"): always set, one of UsageSourceWeb / +// UsageSourceAPIKey. UsageSourceLegacy is set elsewhere by the parent +// Middleware when a legacy env key matches. +// - contextKeyAPIKey ("auth_apikey"): set to the resolved *UserAPIKey for +// named-key branches (Bearer, x-api-key, xi-api-key, token cookie). +// - "_auth_session": session record, used by Middleware to drive cookie +// rotation. Only set on the session-cookie branch. +// +// contextKeyUser and contextKeyRole are populated by the parent Middleware +// after this function returns. func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User { hmacSecret := appConfig.Auth.APIKeyHMACSecret - // a. Session cookie + // a. Session cookie -> web UI if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" { if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil { // Store session for rotation check in middleware c.Set("_auth_session", session) + c.Set(contextKeySource, UsageSourceWeb) return user } } - // b. Authorization: Bearer token + // b. Authorization: Bearer authHeader := c.Request().Header.Get("Authorization") if strings.HasPrefix(authHeader, "Bearer ") { token := strings.TrimPrefix(authHeader, "Bearer ") - // Try as session ID first + // b1. Session token via Bearer -> still web UI if user, _ := ValidateSession(db, token, hmacSecret); user != nil { + c.Set(contextKeySource, UsageSourceWeb) return user } - // Try as user API key + // b2. Named API key if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil { + c.Set(contextKeySource, UsageSourceAPIKey) + c.Set(contextKeyAPIKey, key) return &key.User } } - // c. x-api-key / xi-api-key headers + // c. x-api-key / xi-api-key -> named API key for _, header := range []string{"x-api-key", "xi-api-key"} { - if key := c.Request().Header.Get(header); key != "" { - if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil { + if k := c.Request().Header.Get(header); k != "" { + if apiKey, err := ValidateAPIKey(db, k, hmacSecret); err == nil { + c.Set(contextKeySource, UsageSourceAPIKey) + c.Set(contextKeyAPIKey, apiKey) return &apiKey.User } } } - // d. token cookie (legacy) + // d. token cookie -> named API key if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { - // Try as user API key if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil { + c.Set(contextKeySource, UsageSourceAPIKey) + c.Set(contextKeyAPIKey, key) return &key.User } } diff --git a/core/http/auth/middleware_test.go b/core/http/auth/middleware_test.go index e7b4daa60..5137851e1 100644 --- a/core/http/auth/middleware_test.go +++ b/core/http/auth/middleware_test.go @@ -303,4 +303,122 @@ var _ = Describe("Auth Middleware", func() { } }) }) + + Describe("auth context plumbing for usage source", func() { + // probeApp builds a minimal echo app with the auth middleware and a single + // "/probe" route that captures the user, source, and apikey from context. + type probe struct { + user *auth.User + source string + key *auth.UserAPIKey + } + probeApp := func(db *gorm.DB, appConfig *config.ApplicationConfig, p *probe) *echo.Echo { + e := echo.New() + e.Use(auth.Middleware(db, appConfig)) + e.GET("/probe", func(c echo.Context) error { + p.user = auth.GetUser(c) + p.source = auth.GetSource(c) + p.key = auth.GetAPIKey(c) + return c.NoContent(http.StatusOK) + }) + return e + } + + It("session cookie sets source=web, apikey=nil", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal) + token := createTestSession(db, user.ID) + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withSessionCookie(token)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.user).ToNot(BeNil()) + Expect(p.user.ID).To(Equal(user.ID)) + Expect(p.source).To(Equal(auth.UsageSourceWeb)) + Expect(p.key).To(BeNil()) + }) + + It("Bearer session token sets source=web, apikey=nil", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal) + token := createTestSession(db, user.ID) + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(token)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.user).ToNot(BeNil()) + Expect(p.user.ID).To(Equal(user.ID)) + Expect(p.source).To(Equal(auth.UsageSourceWeb)) + Expect(p.key).To(BeNil()) + }) + + It("Bearer API key sets source=apikey and exposes the resolved *UserAPIKey", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal) + plaintext, key, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withBearerToken(plaintext)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.source).To(Equal(auth.UsageSourceAPIKey)) + Expect(p.key).ToNot(BeNil()) + Expect(p.key.ID).To(Equal(key.ID)) + }) + + It("x-api-key header sets source=apikey", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal) + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withXApiKey(plaintext)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.source).To(Equal(auth.UsageSourceAPIKey)) + Expect(p.key).ToNot(BeNil()) + }) + + It("token cookie sets source=apikey", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + user := createTestUser(db, "alice@example.com", auth.RoleUser, auth.ProviderLocal) + plaintext, _, err := auth.CreateAPIKey(db, user.ID, "ci", auth.RoleUser, appConfig.Auth.APIKeyHMACSecret, nil) + Expect(err).ToNot(HaveOccurred()) + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withTokenCookie(plaintext)) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.source).To(Equal(auth.UsageSourceAPIKey)) + Expect(p.key).ToNot(BeNil()) + }) + + It("legacy env key sets source=legacy, apikey=nil", func() { + db := testDB() + appConfig := config.NewApplicationConfig() + appConfig.ApiKeys = []string{"legacy-secret"} + + var p probe + app := probeApp(db, appConfig, &p) + rec := doRequest(app, http.MethodGet, "/probe", withBearerToken("legacy-secret")) + + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(p.source).To(Equal(auth.UsageSourceLegacy)) + Expect(p.key).To(BeNil()) + }) + }) }) diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go index 31c3202b2..98c11093e 100644 --- a/core/http/auth/usage.go +++ b/core/http/auth/usage.go @@ -5,14 +5,31 @@ import ( "strings" "time" + "github.com/mudler/xlog" "gorm.io/gorm" ) +// Source classification for a UsageRecord. +const ( + UsageSourceAPIKey = "apikey" // request authenticated with a named UserAPIKey + UsageSourceWeb = "web" // request authenticated with a session cookie (web UI) + UsageSourceLegacy = "legacy" // request authenticated with an env-configured legacy key +) + // UsageRecord represents a single API request's token usage. type UsageRecord struct { - ID uint `gorm:"primaryKey;autoIncrement"` - UserID string `gorm:"size:36;index:idx_usage_user_time"` - UserName string `gorm:"size:255"` + ID uint `gorm:"primaryKey;autoIncrement"` + UserID string `gorm:"size:36;index:idx_usage_user_time"` + UserName string `gorm:"size:255"` + + // Source classifies how the request authenticated. One of UsageSource* constants. + // Empty for pre-feature rows until the InitDB backfill runs. + Source string `gorm:"size:16;index:idx_usage_source"` + // APIKeyID is the UserAPIKey.ID when Source == UsageSourceAPIKey. Nil otherwise. + APIKeyID *string `gorm:"size:36;index:idx_usage_apikey"` + // APIKeyName is a snapshot of UserAPIKey.Name at write time. Survives key deletion. + APIKeyName string `gorm:"size:255"` + Model string `gorm:"size:255;index"` Endpoint string `gorm:"size:255"` PromptTokens int64 @@ -30,9 +47,12 @@ func RecordUsage(db *gorm.DB, record *UsageRecord) error { // UsageBucket is an aggregated time bucket for the dashboard. type UsageBucket struct { Bucket string `json:"bucket"` - Model string `json:"model"` + Model string `json:"model,omitempty"` UserID string `json:"user_id,omitempty"` UserName string `json:"user_name,omitempty"` + Source string `json:"source,omitempty"` + APIKeyID string `json:"api_key_id,omitempty"` + APIKeyName string `json:"api_key_name,omitempty"` PromptTokens int64 `json:"prompt_tokens"` CompletionTokens int64 `json:"completion_tokens"` TotalTokens int64 `json:"total_tokens"` @@ -119,6 +139,28 @@ func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) { return buckets, nil } +// BackfillUsageSource sets the Source column on pre-feature usage rows. +// Idempotent: only touches rows where source is NULL or empty. +// - rows whose user_id == "legacy-api-key" -> UsageSourceLegacy +// - everything else -> UsageSourceWeb +func BackfillUsageSource(db *gorm.DB) error { + // Legacy first (more specific predicate) + if err := db.Exec( + `UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '') AND user_id = ?`, + UsageSourceLegacy, "legacy-api-key", + ).Error; err != nil { + return fmt.Errorf("backfill legacy usage source: %w", err) + } + // Everything else -> web + if err := db.Exec( + `UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '')`, + UsageSourceWeb, + ).Error; err != nil { + return fmt.Errorf("backfill web usage source: %w", err) + } + return nil +} + // GetAllUsage returns aggregated usage for all users (admin). Optional userID filter. func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { sqlite := isSQLiteDB(db) @@ -149,3 +191,216 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { } return buckets, nil } + +// TotalsEntry is a token+request roll-up. +type TotalsEntry struct { + Tokens int64 `json:"tokens"` + Requests int64 `json:"requests"` +} + +// KeyTotal is the per-key roll-up returned by sources endpoints. +type KeyTotal struct { + APIKeyID string `json:"api_key_id"` + APIKeyName string `json:"api_key_name"` + Tokens int64 `json:"tokens"` + Requests int64 `json:"requests"` + LastUsed time.Time `json:"last_used"` +} + +// SourceTotals summarises a per-source breakdown. +type SourceTotals struct { + BySource map[string]TotalsEntry `json:"by_source"` + ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped + GrandTotal TotalsEntry `json:"grand_total"` +} + +const maxKeyTotals = 200 + +// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy +// is excluded by design (visible to admins only via the admin variant). +func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Where("user_id = ?", userID). + Where("source <> ?", UsageSourceLegacy). + Group("bucket, source, api_key_id, api_key_name"). + Order("bucket ASC") + + if !since.IsZero() { + query = query.Where("created_at >= ?", since) + } + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, SourceTotals{}, err + } + + totals := computeSourceTotals(db, userID, "", since, false) + return buckets, totals, nil +} + +// computeSourceTotals rolls up by_source / by_key / grand_total. +// userID/apiKeyID are optional filters. includeLegacy controls whether the +// legacy bucket is exposed (admin-only). +func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals { + totals := SourceTotals{BySource: map[string]TotalsEntry{}} + + bySourceQ := db.Model(&UsageRecord{}). + Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests"). + Group("source") + bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy) + + var bySourceRows []struct { + Source string + Tokens int64 + Requests int64 + } + if err := bySourceQ.Scan(&bySourceRows).Error; err != nil { + xlog.Warn("computeSourceTotals: by-source Scan failed", "error", err) + return totals + } + for _, r := range bySourceRows { + totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests} + totals.GrandTotal.Tokens += r.Tokens + totals.GrandTotal.Requests += r.Requests + } + + byKeyQ := db.Model(&UsageRecord{}). + Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+ + "SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used"). + Where("api_key_id IS NOT NULL AND api_key_id <> ''"). + Group("api_key_id, api_key_name"). + Order("tokens DESC"). + Limit(maxKeyTotals) + byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy) + + // Iterate Rows() manually because MAX(created_at) is returned as a string by + // the SQLite driver, and Go's database/sql refuses to scan that into + // *time.Time. Postgres returns a proper timestamp. We accept both shapes + // via a Rows.Scan into a string column, then parse uniformly. + rows, err := byKeyQ.Rows() + if err != nil { + xlog.Warn("computeSourceTotals: by-key Rows() failed", "error", err) + } else { + defer func() { _ = rows.Close() }() + out := make([]KeyTotal, 0) + for rows.Next() { + var ( + apiKeyID, apiKeyName, lastUsedRaw string + tokens, requests int64 + ) + if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &tokens, &requests, &lastUsedRaw); scanErr != nil { + continue + } + out = append(out, KeyTotal{ + APIKeyID: apiKeyID, + APIKeyName: apiKeyName, + Tokens: tokens, + Requests: requests, + LastUsed: parseLastUsedString(lastUsedRaw), + }) + } + if rerr := rows.Err(); rerr != nil { + xlog.Warn("computeSourceTotals: by-key rows iteration failed", "error", rerr) + } + totals.ByKey = out + } + + return totals +} + +// parseLastUsedString converts the textual MAX(created_at) value returned by +// SQLite (or any driver that surfaces the timestamp as a string) into a +// time.Time. Returns the zero time on parse failure. +func parseLastUsedString(s string) time.Time { + if s == "" { + return time.Time{} + } + // GORM's SQLite driver emits Go's default time formatting. Try the formats + // it commonly produces, falling back to RFC3339Nano. + layouts := []string{ + "2006-01-02 15:04:05.999999999 -0700 MST", + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02 15:04:05", + time.RFC3339Nano, + time.RFC3339, + } + for _, layout := range layouts { + if t, err := time.Parse(layout, s); err == nil { + return t + } + } + xlog.Warn("parseLastUsedString: unrecognised format", "value", s) + return time.Time{} +} + +// GetAllUsageBySource is the admin variant of GetUserUsageBySource. +// Optional filters: userID and apiKeyID. Legacy is included. +// truncated == true iff the per-key roll-up was capped at maxKeyTotals. +func GetAllUsageBySource(db *gorm.DB, period, userID, apiKeyID string) ([]UsageBucket, SourceTotals, bool, error) { + sqlite := isSQLiteDB(db) + since, dateFmt := periodToWindow(period, sqlite) + bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) + + query := db.Model(&UsageRecord{}). + Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+ + "user_id, user_name, "+ + "SUM(prompt_tokens) as prompt_tokens, "+ + "SUM(completion_tokens) as completion_tokens, "+ + "SUM(total_tokens) as total_tokens, "+ + "COUNT(*) as request_count"). + Group("bucket, source, api_key_id, api_key_name, user_id, user_name"). + Order("bucket ASC") + + query = applyFilters(query, userID, apiKeyID, since, true) + + var buckets []UsageBucket + if err := query.Find(&buckets).Error; err != nil { + return nil, SourceTotals{}, false, err + } + + totals := computeSourceTotals(db, userID, apiKeyID, since, true) + + // Count distinct api_key_ids matching the filters. If > maxKeyTotals, + // the by_key slice was capped and we signal truncation to the caller. + truncated := false + var distinct int64 + countQ := applyFilters( + db.Model(&UsageRecord{}). + Distinct("api_key_id"). + Where("api_key_id IS NOT NULL AND api_key_id <> ''"), + userID, apiKeyID, since, true, + ) + if err := countQ.Count(&distinct).Error; err != nil { + xlog.Warn("GetAllUsageBySource: distinct api_key_id count failed", "error", err) + } else { + truncated = distinct > maxKeyTotals + } + + return buckets, totals, truncated, nil +} + +func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB { + if userID != "" { + q = q.Where("user_id = ?", userID) + } + if apiKeyID != "" { + q = q.Where("api_key_id = ?", apiKeyID) + } + if !since.IsZero() { + q = q.Where("created_at >= ?", since) + } + if !includeLegacy { + q = q.Where("source <> ?", UsageSourceLegacy) + } + return q +} diff --git a/core/http/auth/usage_test.go b/core/http/auth/usage_test.go index 8782ac095..7b8a457a2 100644 --- a/core/http/auth/usage_test.go +++ b/core/http/auth/usage_test.go @@ -3,11 +3,13 @@ package auth_test import ( + "fmt" "time" "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "gorm.io/gorm" ) var _ = Describe("Usage", func() { @@ -158,4 +160,194 @@ var _ = Describe("Usage", func() { } }) }) + + Describe("Usage source backfill", func() { + It("backfills 'web' for pre-feature rows", func() { + db := testDB() + + rawDB, err := db.DB() + Expect(err).ToNot(HaveOccurred()) + _, err = rawDB.Exec( + `INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`, + "user-x", "gpt-4", time.Now()) + Expect(err).ToNot(HaveOccurred()) + + Expect(auth.BackfillUsageSource(db)).To(Succeed()) + + var loaded auth.UsageRecord + Expect(db.Where("user_id = ?", "user-x").First(&loaded).Error).To(Succeed()) + Expect(loaded.Source).To(Equal(auth.UsageSourceWeb)) + }) + + It("backfills 'legacy' for pre-feature rows with legacy-api-key user_id", func() { + db := testDB() + + rawDB, err := db.DB() + Expect(err).ToNot(HaveOccurred()) + _, err = rawDB.Exec( + `INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`, + "legacy-api-key", "gpt-4", time.Now()) + Expect(err).ToNot(HaveOccurred()) + + Expect(auth.BackfillUsageSource(db)).To(Succeed()) + + var loaded auth.UsageRecord + Expect(db.Where("user_id = ?", "legacy-api-key").First(&loaded).Error).To(Succeed()) + Expect(loaded.Source).To(Equal(auth.UsageSourceLegacy)) + }) + + It("is idempotent on re-run", func() { + db := testDB() + Expect(auth.BackfillUsageSource(db)).To(Succeed()) + Expect(auth.BackfillUsageSource(db)).To(Succeed()) + }) + }) + + Describe("UsageRecord with source fields", func() { + It("persists Source, APIKeyID, APIKeyName", func() { + db := testDB() + keyID := "key-uuid-1" + record := &auth.UsageRecord{ + UserID: "user-1", + UserName: "Test User", + Source: auth.UsageSourceAPIKey, + APIKeyID: &keyID, + APIKeyName: "ci-runner", + Model: "gpt-4", + Endpoint: "/v1/chat/completions", + TotalTokens: 150, + CreatedAt: time.Now(), + } + Expect(auth.RecordUsage(db, record)).To(Succeed()) + + var loaded auth.UsageRecord + Expect(db.First(&loaded, record.ID).Error).To(Succeed()) + Expect(loaded.Source).To(Equal(auth.UsageSourceAPIKey)) + Expect(loaded.APIKeyID).ToNot(BeNil()) + Expect(*loaded.APIKeyID).To(Equal("key-uuid-1")) + Expect(loaded.APIKeyName).To(Equal("ci-runner")) + }) + + It("allows nil APIKeyID for web/legacy sources", func() { + db := testDB() + record := &auth.UsageRecord{ + UserID: "user-1", + Source: auth.UsageSourceWeb, + Model: "gpt-4", + CreatedAt: time.Now(), + } + Expect(auth.RecordUsage(db, record)).To(Succeed()) + + var loaded auth.UsageRecord + Expect(db.First(&loaded, record.ID).Error).To(Succeed()) + Expect(loaded.Source).To(Equal(auth.UsageSourceWeb)) + Expect(loaded.APIKeyID).To(BeNil()) + Expect(loaded.APIKeyName).To(BeEmpty()) + }) + }) + + Describe("GetUserUsageBySource", func() { + insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) { + rec := &auth.UsageRecord{ + UserID: userID, + Source: source, + Model: "gpt-4", + TotalTokens: tokens, + CreatedAt: when, + } + if keyID != "" { + rec.APIKeyID = &keyID + rec.APIKeyName = keyName + } + Expect(auth.RecordUsage(db, rec)).To(Succeed()) + } + + It("returns only the caller's rows, never legacy", func() { + db := testDB() + now := time.Now() + insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now) + insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now) + insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now) + insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now) + + buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month") + Expect(err).ToNot(HaveOccurred()) + + for _, b := range buckets { + Expect(b.UserID).To(Or(BeEmpty(), Equal("alice"))) + Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy)) + } + + Expect(totals.GrandTotal.Tokens).To(Equal(int64(150))) + Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100))) + Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50))) + _, hasLegacy := totals.BySource[auth.UsageSourceLegacy] + Expect(hasLegacy).To(BeFalse()) + }) + + It("snapshots survive key deletion", func() { + db := testDB() + now := time.Now() + insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now) + _, totals, err := auth.GetUserUsageBySource(db, "alice", "month") + Expect(err).ToNot(HaveOccurred()) + Expect(totals.ByKey).To(HaveLen(1)) + Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name")) + Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key")) + Expect(totals.ByKey[0].LastUsed).ToNot(BeZero()) + Expect(totals.ByKey[0].LastUsed).To(BeTemporally("~", now, 2*time.Second)) + }) + }) + + Describe("GetAllUsageBySource", func() { + insert := func(db *gorm.DB, userID, source, keyID string, tokens int64) { + rec := &auth.UsageRecord{ + UserID: userID, + Source: source, + Model: "gpt-4", + TotalTokens: tokens, + CreatedAt: time.Now(), + } + if keyID != "" { + rec.APIKeyID = &keyID + rec.APIKeyName = "name-" + keyID + } + Expect(auth.RecordUsage(db, rec)).To(Succeed()) + } + + It("includes legacy for admins", func() { + db := testDB() + insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10) + insert(db, "legacy-api-key", auth.UsageSourceLegacy, "", 5) + + _, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "") + Expect(err).ToNot(HaveOccurred()) + Expect(totals.BySource).To(HaveKey(auth.UsageSourceLegacy)) + Expect(totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(5))) + }) + + It("filters by user_id AND api_key_id", func() { + db := testDB() + insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10) + insert(db, "alice", auth.UsageSourceAPIKey, "k2", 20) + insert(db, "bob", auth.UsageSourceAPIKey, "k3", 30) + + _, totals, _, err := auth.GetAllUsageBySource(db, "month", "alice", "k2") + Expect(err).ToNot(HaveOccurred()) + Expect(totals.GrandTotal.Tokens).To(Equal(int64(20))) + }) + + It("sets truncated=true when by_key exceeds the cap", func() { + db := testDB() + for i := 0; i < 210; i++ { + insert(db, "alice", auth.UsageSourceAPIKey, fmt.Sprintf("key-%03d", i), int64(210-i)) + } + + _, totals, truncated, err := auth.GetAllUsageBySource(db, "month", "", "") + Expect(err).ToNot(HaveOccurred()) + Expect(truncated).To(BeTrue()) + Expect(totals.ByKey).To(HaveLen(200)) + Expect(totals.ByKey[0].Tokens > totals.ByKey[199].Tokens).To(BeTrue()) + }) + }) }) diff --git a/core/http/middleware/usage.go b/core/http/middleware/usage.go index b82c1ee3f..6dc4699b8 100644 --- a/core/http/middleware/usage.go +++ b/core/http/middleware/usage.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "sync" + "sync/atomic" "time" "github.com/labstack/echo/v4" @@ -14,18 +15,37 @@ import ( const ( usageFlushInterval = 5 * time.Second - usageMaxPending = 5000 + // usageMaxPending bounds the in-memory queue. Sized for bursty inference + // traffic on a self-hosted instance with a slow or unavailable DB. + usageMaxPending = 50000 ) // usageBatcher accumulates usage records and flushes them to the DB periodically. type usageBatcher struct { - mu sync.Mutex - pending []*auth.UsageRecord - db *gorm.DB + mu sync.Mutex + pending []*auth.UsageRecord + db *gorm.DB + stop chan struct{} + done chan struct{} + stopOnce sync.Once } +// droppedRecords counts records discarded because the in-memory queue was full. +// Used to rate-limit the warn log so a sustained outage doesn't flood it. +var droppedRecords atomic.Uint64 + func (b *usageBatcher) add(r *auth.UsageRecord) { b.mu.Lock() + if len(b.pending) >= usageMaxPending { + b.mu.Unlock() + // Rate-limit: one warn per 1024 drops keeps the log readable. + n := droppedRecords.Add(1) + if n&1023 == 1 { + xlog.Warn("usage batcher full, dropping record", + "cap", usageMaxPending, "total_dropped", n) + } + return + } b.pending = append(b.pending, r) b.mu.Unlock() } @@ -42,31 +62,102 @@ func (b *usageBatcher) flush() { if err := b.db.Create(&batch).Error; err != nil { xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err) - // Re-queue failed records with a cap to avoid unbounded growth + // Cap-aware re-queue: prepend as much of the failed batch as fits + // alongside any records added concurrently with the failed write. b.mu.Lock() - if len(b.pending) < usageMaxPending { - b.pending = append(batch, b.pending...) + room := usageMaxPending - len(b.pending) + if room > 0 { + if room > len(batch) { + room = len(batch) + } + b.pending = append(batch[:room], b.pending...) } b.mu.Unlock() } } -var batcher *usageBatcher +func (b *usageBatcher) run() { + defer close(b.done) + ticker := time.NewTicker(usageFlushInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + b.flush() + case <-b.stop: + b.flush() // final drain + return + } + } +} + +func (b *usageBatcher) shutdown() { + b.stopOnce.Do(func() { + close(b.stop) + <-b.done + }) +} + +// The package-level batcher is guarded by batcherMu so Init / Shutdown cycles +// (the test pattern) don't race against UsageMiddleware reads. +var ( + batcherMu sync.RWMutex + batcher *usageBatcher +) + +func currentBatcher() *usageBatcher { + batcherMu.RLock() + defer batcherMu.RUnlock() + return batcher +} // InitUsageRecorder starts a background goroutine that periodically flushes -// accumulated usage records to the database. +// accumulated usage records to the database. Calling it more than once +// shuts down the previous batcher first so its goroutine doesn't leak. func InitUsageRecorder(db *gorm.DB) { if db == nil { return } - batcher = &usageBatcher{db: db} - go func() { - ticker := time.NewTicker(usageFlushInterval) - defer ticker.Stop() - for range ticker.C { - batcher.flush() - } - }() + + batcherMu.Lock() + old := batcher + batcher = nil + batcherMu.Unlock() + if old != nil { + old.shutdown() + } + + b := &usageBatcher{ + db: db, + stop: make(chan struct{}), + done: make(chan struct{}), + } + batcherMu.Lock() + batcher = b + batcherMu.Unlock() + + go b.run() +} + +// ShutdownUsageRecorder stops the background flusher and synchronously drains +// pending records once. Safe to call multiple times. Not yet wired into the +// application lifecycle; intended for graceful process exit and tests. +func ShutdownUsageRecorder() { + batcherMu.Lock() + b := batcher + batcher = nil + batcherMu.Unlock() + if b != nil { + b.shutdown() + } +} + +// FlushNow synchronously flushes any pending usage records. Intended for tests +// that need deterministic behaviour without waiting for the ticker. +func FlushNow() { + if b := currentBatcher(); b != nil { + b.flush() + } } // usageResponseBody is the minimal structure we need from the response JSON. @@ -84,7 +175,8 @@ type usageResponseBody struct { func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if db == nil || batcher == nil { + b := currentBatcher() + if db == nil || b == nil { return next(c) } @@ -149,9 +241,17 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { return handlerErr } + source := auth.GetSource(c) + if source == "" { + // Auth disabled or unrecognised path: classify as web so the row is still + // bucketable rather than silently dropped from per-source aggregates. + source = auth.UsageSourceWeb + } + record := &auth.UsageRecord{ UserID: user.ID, UserName: user.Name, + Source: source, Model: resp.Model, Endpoint: c.Request().URL.Path, PromptTokens: resp.Usage.PromptTokens, @@ -161,7 +261,13 @@ func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { CreatedAt: startTime, } - batcher.add(record) + if key := auth.GetAPIKey(c); key != nil { + id := key.ID + record.APIKeyID = &id + record.APIKeyName = key.Name + } + + b.add(record) return handlerErr } diff --git a/core/http/middleware/usage_test.go b/core/http/middleware/usage_test.go new file mode 100644 index 000000000..7db03a6ba --- /dev/null +++ b/core/http/middleware/usage_test.go @@ -0,0 +1,140 @@ +//go:build auth + +package middleware_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + "github.com/mudler/LocalAI/core/http/middleware" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +// testAuthDB returns a fresh in-memory SQLite auth DB. +func testAuthDB() *gorm.DB { + db, err := auth.InitDB(":memory:") + if err != nil { + panic(err) + } + return db +} + +var _ = Describe("UsageMiddleware", func() { + var ( + e *echo.Echo + db *gorm.DB + ) + + BeforeEach(func() { + db = testAuthDB() + e = echo.New() + middleware.InitUsageRecorder(db) + }) + + AfterEach(func() { + middleware.ShutdownUsageRecorder() + }) + + okHandler := func(c echo.Context) error { + body, _ := json.Marshal(map[string]any{ + "model": "gpt-4", + "usage": map[string]int{ + "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, + }, + }) + c.Response().Header().Set("Content-Type", "application/json") + c.Response().WriteHeader(http.StatusOK) + _, _ = c.Response().Write(body) + return nil + } + + // FlushNow drains pending records synchronously, replacing the 6s sleep + // that was previously needed to wait for the batcher's ticker. + flush := middleware.FlushNow + + It("records source=web when auth_source is web", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + c.Set("auth_source", auth.UsageSourceWeb) + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceWeb)) + Expect(rec.APIKeyID).To(BeNil()) + Expect(rec.APIKeyName).To(BeEmpty()) + }) + + It("records source=apikey with snapshotted name when auth_apikey is set", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + c.Set("auth_source", auth.UsageSourceAPIKey) + c.Set("auth_apikey", &auth.UserAPIKey{ID: "key-1", Name: "ci-runner"}) + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceAPIKey)) + Expect(rec.APIKeyID).ToNot(BeNil()) + Expect(*rec.APIKeyID).To(Equal("key-1")) + Expect(rec.APIKeyName).To(Equal("ci-runner")) + }) + + It("FlushNow drains pending records synchronously", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "carol", Name: "Carol"}) + c.Set("auth_source", auth.UsageSourceWeb) + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + + // No sleep: FlushNow should drain immediately. + middleware.FlushNow() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "carol").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceWeb)) + }) + + It("falls back to source=web when auth_source is empty", func() { + e.POST("/v1/chat/completions", okHandler, func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth_user", &auth.User{ID: "alice", Name: "Alice"}) + // no auth_source set + return next(c) + } + }, middleware.UsageMiddleware(db)) + + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + e.ServeHTTP(httptest.NewRecorder(), req) + flush() + + var rec auth.UsageRecord + Expect(db.Where("user_id = ?", "alice").First(&rec).Error).To(Succeed()) + Expect(rec.Source).To(Equal(auth.UsageSourceWeb)) + }) +}) diff --git a/core/http/react-ui/public/locales/en/admin.json b/core/http/react-ui/public/locales/en/admin.json index a76ef8df1..4b5ce0bb0 100644 --- a/core/http/react-ui/public/locales/en/admin.json +++ b/core/http/react-ui/public/locales/en/admin.json @@ -53,7 +53,29 @@ }, "usage": { "title": "Usage", - "subtitle": "API token usage statistics" + "subtitle": "API token usage statistics", + "sources": { + "tab": "Sources", + "mixTitle": "Source mix", + "ribbonAria": "{{apikey}}% API keys, {{web}}% Web UI, {{legacy}}% Legacy", + "topSources": "Top sources over time", + "searchPlaceholder": "Search by name or prefix", + "sortBy": "Sort", + "sortTokens": "Tokens", + "sortRequests": "Requests", + "sortLastUsed": "Last used", + "sortName": "Name", + "webUI": "Web UI", + "legacy": "Legacy", + "revoked": "revoked", + "filteredTo": "Filtered to: {{name}}", + "clearFilter": "Clear filter", + "other": "Other ({{count}})", + "noTrafficShort": "No requests in this period.", + "noKeysYet": "Once requests come in, you'll see them broken down here.", + "createKey": "Create your first API key", + "truncatedWarning": "Showing top 200 keys. Apply a filter to narrow further." + } }, "explorer": { "title": "Explorer", diff --git a/core/http/react-ui/src/pages/Usage.jsx b/core/http/react-ui/src/pages/Usage.jsx index 9d5b51f69..468c6dd07 100644 --- a/core/http/react-ui/src/pages/Usage.jsx +++ b/core/http/react-ui/src/pages/Usage.jsx @@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next' import { useAuth } from '../context/AuthContext' import { apiUrl } from '../utils/basePath' import LoadingSpinner from '../components/LoadingSpinner' +import SourcesTab from './Usage/SourcesTab' const PERIODS = [ { key: 'day', label: 'Day' }, @@ -724,23 +725,27 @@ export default function Usage() { {p.label} ))} +
+ {isAdmin && ( - <> -
- - - + )} +
) )} + + {activeTab === 'sources' && ( + + )} )}
diff --git a/core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx b/core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx new file mode 100644 index 000000000..007f54939 --- /dev/null +++ b/core/http/react-ui/src/pages/Usage/SourceMixRibbon.jsx @@ -0,0 +1,83 @@ +import { useTranslation } from 'react-i18next' + +const SEGMENT_COLORS = { + apikey: 'var(--color-primary)', + web: 'var(--color-info, #3b82f6)', + legacy: 'var(--color-warning, #f59e0b)', +} + +// SourceMixRibbon renders one segmented horizontal bar showing the share of +// tokens by source class (apikey / web / legacy). Clicking a segment invokes +// onSelectSourceClass with the segment key so the parent can filter the view. +// +// Props: +// bySource: { apikey?: {tokens, requests}, web?: {...}, legacy?: {...} } +// keyCount: number of distinct API keys in the dataset (for the legend) +// onSelectSourceClass: (cls: 'apikey'|'web'|'legacy') => void (optional) +export default function SourceMixRibbon({ bySource = {}, keyCount = 0, onSelectSourceClass }) { + const { t } = useTranslation('admin') + + const apikey = (bySource.apikey?.tokens) || 0 + const web = (bySource.web?.tokens) || 0 + const legacy = (bySource.legacy?.tokens) || 0 + const total = apikey + web + legacy || 1 + + const pct = (n) => Math.round((n / total) * 100) + const apiPct = pct(apikey) + const webPct = pct(web) + const legacyPct = pct(legacy) + + const segments = [ + { key: 'apikey', label: `${apiPct}% API keys (${keyCount})`, pct: apiPct, color: SEGMENT_COLORS.apikey }, + { key: 'web', label: `${webPct}% ${t('usage.sources.webUI')}`, pct: webPct, color: SEGMENT_COLORS.web }, + { key: 'legacy', label: `${legacyPct}% ${t('usage.sources.legacy')}`, pct: legacyPct, color: SEGMENT_COLORS.legacy }, + ].filter((s) => s.pct > 0) + + return ( +
+
+ {t('usage.sources.mixTitle')} +
+
+ {segments.map((s) => ( +
+
+ {segments.map((s) => ( + + + {s.label} + + ))} +
+
+ ) +} diff --git a/core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx b/core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx new file mode 100644 index 000000000..6472e6c61 --- /dev/null +++ b/core/http/react-ui/src/pages/Usage/SourceTimeChart.jsx @@ -0,0 +1,147 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' + +const TOP_N = 7 +// Distinct, accessible-ish series colors that read on both light and dark themes. +const SERIES_COLORS = [ + 'var(--color-primary)', + 'var(--color-success, #10b981)', + 'var(--color-warning, #f59e0b)', + 'var(--color-info, #3b82f6)', + 'var(--color-danger, #ef4444)', + '#a855f7', + '#ec4899', +] +const OTHER_COLOR = 'var(--color-text-muted, #94a3b8)' + +function identityFor(bucket) { + return bucket.api_key_id || bucket.source || 'unknown' +} + +// buckets: UsageBucket[] from /api/auth/usage/sources (server-sorted ASC by bucket) +// selectedKey: 'web' | 'legacy' | api_key_id | null +// totals: SourceTotals (for the "Other (count)" legend label) +export default function SourceTimeChart({ buckets = [], selectedKey, totals }) { + const { t } = useTranslation('admin') + + // Find the top-N identities by total tokens across the period. + const topIds = useMemo(() => { + const sums = new Map() + for (const b of buckets) { + const id = identityFor(b) + sums.set(id, (sums.get(id) || 0) + (b.total_tokens || 0)) + } + return [...sums.entries()] + .sort((a, b) => b[1] - a[1]) + .slice(0, TOP_N) + .map(([id]) => id) + }, [buckets]) + + const topSet = useMemo(() => new Set(topIds), [topIds]) + + // Resolve a display label for an identity (api_key_id -> snapshotted name, or source name). + const labelByIdentity = useMemo(() => { + const m = new Map() + for (const b of buckets) { + const id = identityFor(b) + if (m.has(id)) continue + if (b.source === 'web') { m.set(id, t('usage.sources.webUI')); continue } + if (b.source === 'legacy') { m.set(id, t('usage.sources.legacy')); continue } + m.set(id, b.api_key_name || b.api_key_id || id) + } + return m + }, [buckets, t]) + + // Build a dense per-bucket row, splitting top-N vs Other. + const series = useMemo(() => { + const byBucket = new Map() + for (const b of buckets) { + const id = identityFor(b) + const seriesId = topSet.has(id) ? id : '__other__' + const row = byBucket.get(b.bucket) || { bucket: b.bucket, total: 0 } + row[seriesId] = (row[seriesId] || 0) + (b.total_tokens || 0) + row.total += b.total_tokens || 0 + byBucket.set(b.bucket, row) + } + return [...byBucket.values()] + }, [buckets, topSet]) + + const max = useMemo( + () => series.reduce((m, r) => Math.max(m, r.total), 0) || 1, + [series] + ) + + const seriesIds = [...topIds, '__other__'] + const colorOf = (id) => + id === '__other__' + ? OTHER_COLOR + : SERIES_COLORS[topIds.indexOf(id) % SERIES_COLORS.length] + + const labelOfId = (id) => { + if (id === '__other__') return null // computed inline (need count) + return labelByIdentity.get(id) || id + } + + const otherCount = Math.max(0, (totals?.by_key?.length || 0) - TOP_N) + + // SVG geometry: 24px wide per bar (2px gap), 100px tall, viewBox stretches with bar count. + const barWidth = 20 + const barGap = 4 + const slotWidth = barWidth + barGap + const height = 100 + const width = Math.max(series.length * slotWidth, 200) + + return ( +
+
+ {t('usage.sources.topSources')} +
+ + + {series.map((row, i) => { + let y = height + return ( + + {seriesIds.map(id => { + const v = row[id] || 0 + if (!v) return null + const h = (v / max) * height + y -= h + const dim = selectedKey && selectedKey !== id ? 0.25 : 1 + const title = id === '__other__' + ? t('usage.sources.other', { count: otherCount }) + : labelOfId(id) + return ( + + {`${row.bucket} - ${title}: ${v.toLocaleString()}`} + + ) + })} + + ) + })} + + +
+ {seriesIds.map(id => ( + + + {id === '__other__' + ? t('usage.sources.other', { count: otherCount }) + : labelOfId(id)} + + ))} +
+
+ ) +} diff --git a/core/http/react-ui/src/pages/Usage/SourcesTab.jsx b/core/http/react-ui/src/pages/Usage/SourcesTab.jsx new file mode 100644 index 000000000..79e5e93d4 --- /dev/null +++ b/core/http/react-ui/src/pages/Usage/SourcesTab.jsx @@ -0,0 +1,175 @@ +import { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { usageApi, apiKeysApi } from '../../utils/api' +import { useAuth } from '../../context/AuthContext' +import LoadingSpinner from '../../components/LoadingSpinner' +import SourceMixRibbon from './SourceMixRibbon' +import SourcesTable from './SourcesTable' +import SourceTimeChart from './SourceTimeChart' + +const EMPTY_DATA = { + buckets: [], + totals: { by_source: {}, by_key: [], grand_total: { tokens: 0, requests: 0 } }, + truncated: false, +} + +// Resolve a human label for the currently selected key (web/legacy class or api_key_id). +function labelForSelected(totals, selectedKey, t) { + if (!selectedKey) return '' + if (selectedKey === 'web') return t('usage.sources.webUI') + if (selectedKey === 'legacy') return t('usage.sources.legacy') + const row = (totals?.by_key || []).find(k => k.api_key_id === selectedKey) + return row ? (row.api_key_name || selectedKey) : selectedKey +} + +// SourcesTab fetches and renders per-source / per-API-key usage breakdown. +// Task 10 replaces the raw JSON / list placeholders with SourceMixRibbon and +// SourcesTable. Task 11 will add the time chart and drill-in chip. +export default function SourcesTab({ period, adminUserId }) { + const { t } = useTranslation('admin') + const { isAdmin } = useAuth() + + const [data, setData] = useState(EMPTY_DATA) + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + + const [selectedKey, setSelectedKey] = useState(null) + const [search, setSearch] = useState('') + const [sortKey, setSortKey] = useState('tokens') + + // Pull the current set of API key ids so the table can mark unknown keys as + // revoked. null = "don't know yet" so the table won't dim live keys during + // the fetch or after a failure. + const [existingKeyIds, setExistingKeyIds] = useState(null) + useEffect(() => { + apiKeysApi + .list() + .then((resp) => { + const list = Array.isArray(resp) ? resp : (resp?.keys || []) + setExistingKeyIds(new Set(list.map((k) => k.id))) + }) + .catch(() => { /* leave existingKeyIds null so revoked detection is skipped */ }) + }, []) + + useEffect(() => { + let cancelled = false + setLoading(true) + setError(null) + const p = isAdmin + ? usageApi.getAdminSources(period, adminUserId) + : usageApi.getMySources(period) + p + .then((d) => { if (!cancelled) setData(d || EMPTY_DATA) }) + .catch((e) => { if (!cancelled) setError(e) }) + .finally(() => { if (!cancelled) setLoading(false) }) + return () => { cancelled = true } + }, [isAdmin, period, adminUserId]) + + const totals = data.totals || EMPTY_DATA.totals + const buckets = data.buckets || EMPTY_DATA.buckets + const grandT = totals.grand_total || { tokens: 0, requests: 0 } + const truncated = data.truncated || false + + const isEmpty = !loading && (grandT.tokens || 0) === 0 && (grandT.requests || 0) === 0 + + if (loading) { + return ( +
+ +
+ ) + } + + if (error) { + return ( +
+
+

Failed to load

+

{String(error.message || error)}

+
+ ) + } + + if (isEmpty) { + return ( +
+
+

{t('usage.sources.noTrafficShort')}

+

{t('usage.sources.noKeysYet')}

+
+ ) + } + + return ( +
+
+ setSelectedKey(cls)} + /> +
+ + {selectedKey && ( +
+ + + {t('usage.sources.filteredTo', { name: labelForSelected(totals, selectedKey, t) })} + + +
+ )} + +
+ +
+ +
+ +
+ + {truncated && ( +
+ {t('usage.sources.truncatedWarning')} +
+ )} +
+ ) +} diff --git a/core/http/react-ui/src/pages/Usage/SourcesTable.jsx b/core/http/react-ui/src/pages/Usage/SourcesTable.jsx new file mode 100644 index 000000000..00abc5f75 --- /dev/null +++ b/core/http/react-ui/src/pages/Usage/SourcesTable.jsx @@ -0,0 +1,203 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' + +const SORT_FNS = { + tokens: (a, b) => (b.tokens || 0) - (a.tokens || 0), + requests: (a, b) => (b.requests || 0) - (a.requests || 0), + last_used: (a, b) => new Date(b.last_used || 0).getTime() - new Date(a.last_used || 0).getTime(), + name: (a, b) => (a.name || '').localeCompare(b.name || ''), +} + +function formatTokens(n) { + if (!n) return '0' + if (n >= 1_000_000) return (n / 1_000_000).toFixed(1) + 'M' + if (n >= 1_000) return (n / 1_000).toFixed(1) + 'k' + return String(n) +} + +function formatRelative(iso) { + if (!iso) return '-' + const t = new Date(iso).getTime() + if (Number.isNaN(t) || t <= 0) return '-' + const diff = Date.now() - t + if (diff < 60_000) return 'just now' + if (diff < 3_600_000) return Math.round(diff / 60_000) + 'm ago' + if (diff < 86_400_000) return Math.round(diff / 3_600_000) + 'h ago' + return Math.round(diff / 86_400_000) + 'd ago' +} + +// SourcesTable is the searchable, sortable list of key totals plus pseudo-rows +// for the web UI and legacy (unkeyed) source classes. Clicking a row selects +// it; the parent decides what to do with the selection (the drill-in panel +// will be wired in Task 11). +// +// Props: +// totals: SourceTotals payload (from /api/auth/usage/sources) +// selectedKey: currently-selected row id (api_key_id | 'web' | 'legacy' | null) +// onSelectKey: (id|null) => void +// search / setSearch: free-text filter state lifted to the parent +// sortKey / setSortKey: sort column state lifted to the parent +// existingKeyIds: Set of current (non-revoked) api key ids, or null +// when the parent hasn't yet learned which keys exist. Null suppresses the +// revoked badge entirely so live keys aren't dimmed during the fetch or +// after a failure. +export default function SourcesTable({ + totals, + selectedKey, + onSelectKey, + search, + setSearch, + sortKey, + setSortKey, + existingKeyIds = null, +}) { + const { t } = useTranslation('admin') + + const rows = useMemo(() => { + const named = (totals?.by_key || []).map((k) => ({ + kind: 'apikey', + id: k.api_key_id, + name: k.api_key_name || k.api_key_id, + prefix: '', + tokens: k.tokens, + requests: k.requests, + last_used: k.last_used, + revoked: existingKeyIds != null && !existingKeyIds.has(k.api_key_id), + })) + const web = totals?.by_source?.web + ? [{ + kind: 'web', + id: 'web', + name: t('usage.sources.webUI'), + prefix: '-', + tokens: totals.by_source.web.tokens, + requests: totals.by_source.web.requests, + }] + : [] + const leg = totals?.by_source?.legacy + ? [{ + kind: 'legacy', + id: 'legacy', + name: t('usage.sources.legacy'), + prefix: '-', + tokens: totals.by_source.legacy.tokens, + requests: totals.by_source.legacy.requests, + }] + : [] + return [...named, ...web, ...leg] + }, [totals, existingKeyIds, t]) + + const filtered = useMemo(() => { + const q = (search || '').trim().toLowerCase() + const list = q + ? rows.filter((r) => (r.name || '').toLowerCase().includes(q) || (r.prefix || '').toLowerCase().includes(q)) + : rows + return [...list].sort(SORT_FNS[sortKey] || SORT_FNS.tokens) + }, [rows, search, sortKey]) + + const iconFor = (kind) => + kind === 'apikey' ? 'fas fa-key' : kind === 'web' ? 'fas fa-globe' : 'fas fa-gear' + + return ( +
+
+ setSearch(e.target.value)} + placeholder={t('usage.sources.searchPlaceholder')} + aria-label={t('usage.sources.searchPlaceholder')} + style={{ + flex: '1 1 12rem', + minWidth: 160, + padding: 'var(--spacing-xs) var(--spacing-sm)', + border: '1px solid var(--color-border-subtle)', + borderRadius: 'var(--radius-sm)', + background: 'var(--color-bg-primary)', + color: 'var(--color-text-primary)', + }} + /> + +
+ +
+ + + + + + + + + + + + {filtered.map((r) => { + const isSel = selectedKey === r.id + return ( + onSelectKey?.(isSel ? null : r.id)} + style={{ + cursor: 'pointer', + background: isSel ? 'var(--color-bg-secondary)' : undefined, + opacity: r.revoked ? 0.5 : 1, + }} + > + + + + + + + ) + })} + +
{t('usage.sources.sortName')}Prefix{t('usage.sources.sortRequests')}{t('usage.sources.sortTokens')}{t('usage.sources.sortLastUsed')}
+ + + {r.name} + {r.revoked && ( + + ({t('usage.sources.revoked')}) + + )} + + {r.prefix || '-'} + {Number(r.requests || 0).toLocaleString()} + + {formatTokens(r.tokens || 0)} + + {formatRelative(r.last_used)} +
+
+
+ ) +} diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index 78f0b4f68..a8ffa2f04 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -422,6 +422,14 @@ export const usageApi = { if (userId) url += `&user_id=${encodeURIComponent(userId)}` return fetchJSON(url) }, + getMySources: (period) => + fetchJSON(`/api/auth/usage/sources?period=${period || 'month'}`), + getAdminSources: (period, userId, apiKeyId) => { + let url = `/api/auth/admin/usage/sources?period=${period || 'month'}` + if (userId) url += `&user_id=${encodeURIComponent(userId)}` + if (apiKeyId) url += `&api_key_id=${encodeURIComponent(apiKeyId)}` + return fetchJSON(url) + }, getMyQuotas: () => fetchJSON('/api/auth/quota'), } diff --git a/core/http/routes/auth.go b/core/http/routes/auth.go index 3f42adbf8..ef8372fff 100644 --- a/core/http/routes/auth.go +++ b/core/http/routes/auth.go @@ -789,6 +789,30 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { }) }) + // GET /api/auth/usage/sources - caller's per-source breakdown (no legacy) + e.GET("/api/auth/usage/sources", func(c echo.Context) error { + user := auth.GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + + period := c.QueryParam("period") + if period == "" { + period = "month" + } + + buckets, totals, err := auth.GetUserUsageBySource(db, user.ID, period) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"}) + } + + return c.JSON(http.StatusOK, map[string]any{ + "buckets": buckets, + "totals": totals, + "truncated": false, + }) + }) + // Admin endpoints adminMw := auth.RequireAdmin() @@ -1104,6 +1128,27 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { }) }, adminMw) + // GET /api/auth/admin/usage/sources - all users' per-source breakdown (admin only) + e.GET("/api/auth/admin/usage/sources", func(c echo.Context) error { + period := c.QueryParam("period") + if period == "" { + period = "month" + } + userID := c.QueryParam("user_id") + apiKeyID := c.QueryParam("api_key_id") + + buckets, totals, truncated, err := auth.GetAllUsageBySource(db, period, userID, apiKeyID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"}) + } + + return c.JSON(http.StatusOK, map[string]any{ + "buckets": buckets, + "totals": totals, + "truncated": truncated, + }) + }, adminMw) + // --- Invite management endpoints --- // POST /api/auth/admin/invites - create invite (admin only) diff --git a/core/http/routes/auth_test.go b/core/http/routes/auth_test.go index 89f5b4657..561d3fde9 100644 --- a/core/http/routes/auth_test.go +++ b/core/http/routes/auth_test.go @@ -286,6 +286,45 @@ func newTestAuthApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo return c.JSON(http.StatusOK, map[string]string{"message": "user deleted"}) }, adminMw) + // Mirror of production handler in routes/auth.go GET /api/auth/usage/sources. + // Keep this body in sync with the real handler; this test app cannot call + // RegisterAuthRoutes because it needs a *application.Application. + e.GET("/api/auth/usage/sources", func(c echo.Context) error { + user := auth.GetUser(c) + if user == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + period := c.QueryParam("period") + if period == "" { + period = "month" + } + buckets, totals, err := auth.GetUserUsageBySource(db, user.ID, period) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"}) + } + return c.JSON(http.StatusOK, map[string]any{ + "buckets": buckets, "totals": totals, "truncated": false, + }) + }) + + // Mirror of production handler in routes/auth.go GET /api/auth/admin/usage/sources. + // Keep this body in sync with the real handler. + e.GET("/api/auth/admin/usage/sources", func(c echo.Context) error { + period := c.QueryParam("period") + if period == "" { + period = "month" + } + userID := c.QueryParam("user_id") + apiKeyID := c.QueryParam("api_key_id") + buckets, totals, truncated, err := auth.GetAllUsageBySource(db, period, userID, apiKeyID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to get usage"}) + } + return c.JSON(http.StatusOK, map[string]any{ + "buckets": buckets, "totals": totals, "truncated": truncated, + }) + }, adminMw) + // Regular API endpoint for testing e.POST("/v1/chat/completions", func(c echo.Context) error { return c.String(http.StatusOK, "ok") @@ -931,4 +970,110 @@ var _ = Describe("Auth Routes", Label("auth"), func() { Expect(providers).To(ContainElement(auth.ProviderGitHub)) }) }) + + Describe("GET /api/auth/usage/sources", func() { + It("returns only the caller's data, never legacy", func() { + app := newTestAuthApp(db, appConfig) + + alice := createRouteTestUser(db, "alice@example.com", auth.RoleUser) + aliceToken, err := auth.CreateSession(db, alice.ID, "") + Expect(err).ToNot(HaveOccurred()) + + keyID := "k-alice" + now := time.Now() + Expect(auth.RecordUsage(db, &auth.UsageRecord{ + UserID: alice.ID, Source: auth.UsageSourceAPIKey, + APIKeyID: &keyID, APIKeyName: "alice-key", + Model: "gpt-4", TotalTokens: 100, CreatedAt: now, + })).To(Succeed()) + Expect(auth.RecordUsage(db, &auth.UsageRecord{ + UserID: alice.ID, Source: auth.UsageSourceWeb, + Model: "gpt-4", TotalTokens: 50, CreatedAt: now, + })).To(Succeed()) + Expect(auth.RecordUsage(db, &auth.UsageRecord{ + UserID: "legacy-api-key", Source: auth.UsageSourceLegacy, + Model: "gpt-4", TotalTokens: 30, CreatedAt: now, + })).To(Succeed()) + + rec := doAuthRequest(app, http.MethodGet, "/api/auth/usage/sources?period=month", nil, withSession(aliceToken)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + var resp struct { + Buckets []auth.UsageBucket `json:"buckets"` + Totals auth.SourceTotals `json:"totals"` + Truncated bool `json:"truncated"` + } + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + _, hasLegacy := resp.Totals.BySource[auth.UsageSourceLegacy] + Expect(hasLegacy).To(BeFalse()) + Expect(resp.Totals.GrandTotal.Tokens).To(Equal(int64(150))) + Expect(resp.Truncated).To(BeFalse()) + }) + + It("returns 401 when unauthenticated", func() { + app := newTestAuthApp(db, appConfig) + + // Without a session cookie or bearer token, the global auth middleware + // should refuse the request before our handler runs. + rec := doAuthRequest(app, http.MethodGet, "/api/auth/usage/sources?period=month", nil) + Expect(rec.Code).To(Equal(http.StatusUnauthorized)) + }) + }) + + Describe("GET /api/auth/admin/usage/sources", func() { + It("returns 403 for non-admin", func() { + app := newTestAuthApp(db, appConfig) + + alice := createRouteTestUser(db, "alice@example.com", auth.RoleUser) + aliceToken, _ := auth.CreateSession(db, alice.ID, "") + + rec := doAuthRequest(app, http.MethodGet, "/api/auth/admin/usage/sources?period=month", nil, withSession(aliceToken)) + Expect(rec.Code).To(Equal(http.StatusForbidden)) + }) + + It("returns legacy bucket for admin and applies api_key_id filter", func() { + app := newTestAuthApp(db, appConfig) + + admin := createRouteTestUser(db, "admin@example.com", auth.RoleAdmin) + adminToken, _ := auth.CreateSession(db, admin.ID, "") + + k1 := "k1" + k2 := "k2" + now := time.Now() + Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "alice", Source: auth.UsageSourceAPIKey, APIKeyID: &k1, APIKeyName: "ci", Model: "gpt-4", TotalTokens: 10, CreatedAt: now})).To(Succeed()) + Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "alice", Source: auth.UsageSourceAPIKey, APIKeyID: &k2, APIKeyName: "lap", Model: "gpt-4", TotalTokens: 20, CreatedAt: now})).To(Succeed()) + Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "legacy-api-key", Source: auth.UsageSourceLegacy, Model: "gpt-4", TotalTokens: 5, CreatedAt: now})).To(Succeed()) + + rec := doAuthRequest(app, http.MethodGet, + "/api/auth/admin/usage/sources?period=month&api_key_id=k2", nil, withSession(adminToken)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + var resp struct { + Totals auth.SourceTotals `json:"totals"` + Truncated bool `json:"truncated"` + } + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp.Totals.GrandTotal.Tokens).To(Equal(int64(20))) + }) + + It("includes legacy in by_source for admin with no filter", func() { + app := newTestAuthApp(db, appConfig) + + admin := createRouteTestUser(db, "admin@example.com", auth.RoleAdmin) + adminToken, _ := auth.CreateSession(db, admin.ID, "") + + now := time.Now() + Expect(auth.RecordUsage(db, &auth.UsageRecord{UserID: "legacy-api-key", Source: auth.UsageSourceLegacy, Model: "gpt-4", TotalTokens: 7, CreatedAt: now})).To(Succeed()) + + rec := doAuthRequest(app, http.MethodGet, "/api/auth/admin/usage/sources?period=month", nil, withSession(adminToken)) + Expect(rec.Code).To(Equal(http.StatusOK)) + + var resp struct { + Totals auth.SourceTotals `json:"totals"` + } + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp.Totals.BySource).To(HaveKey(auth.UsageSourceLegacy)) + Expect(resp.Totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(7))) + }) + }) }) diff --git a/docs/content/features/authentication.md b/docs/content/features/authentication.md index 6c0d50637..35f3cc9ae 100644 --- a/docs/content/features/authentication.md +++ b/docs/content/features/authentication.md @@ -253,10 +253,12 @@ User API keys inherit the creating user's role. Admin keys grant admin access; u | `GET` | `/api/auth/api-keys` | List user's API keys | Yes | | `DELETE` | `/api/auth/api-keys/:id` | Revoke API key | Yes | | `GET` | `/api/auth/usage` | User's own usage stats | Yes | +| `GET` | `/api/auth/usage/sources` | User's own per-API-key / per-source breakdown | Yes | | `GET` | `/api/auth/admin/users` | List all users | Admin | | `PUT` | `/api/auth/admin/users/:id/role` | Change user role | Admin | | `DELETE` | `/api/auth/admin/users/:id` | Delete user | Admin | | `GET` | `/api/auth/admin/usage` | All users' usage stats | Admin | +| `GET` | `/api/auth/admin/usage/sources` | All users' per-API-key / per-source breakdown | Admin | | `POST` | `/api/auth/admin/invites` | Create invite link | Admin | | `GET` | `/api/auth/admin/invites` | List all invites | Admin | | `DELETE` | `/api/auth/admin/invites/:id` | Revoke unused invite | Admin | @@ -327,10 +329,79 @@ curl "http://localhost:8080/api/auth/admin/usage?period=month&user_id=" ### Usage Dashboard The web UI Usage page provides: -- **Period selector** — switch between day, week, month, and all-time views -- **Summary cards** — total requests, prompt tokens, completion tokens, total tokens -- **By Model table** — per-model breakdown with visual usage bars -- **By User table** (admin only) — per-user breakdown across all models +- **Period selector** - switch between day, week, month, and all-time views +- **Summary cards** - total requests, prompt tokens, completion tokens, total tokens +- **By Model table** - per-model breakdown with visual usage bars +- **By User table** (admin only) - per-user breakdown across all models +- **Sources tab** - per-API-key and per-source breakdown (described below) + +### Per-API-key Breakdown + +The **Sources** tab on the Usage page surfaces a third dimension of the same data: traffic broken down by API key and by request source. Three source classes are tracked: + +- **API key** - request authenticated with a named user API key (`Authorization: Bearer lai-...`, `x-api-key`, or `token` cookie). Each key shows up with its label (snapshotted at write time, so revoked keys still display the original name). +- **Web UI** - request authenticated with a browser session cookie. +- **Legacy** - request authenticated with an env-configured `LOCALAI_API_KEY`. Visible to admins only. + +The Sources tab is visible to every authenticated user. Non-admins see only their own keys plus their own Web UI traffic (legacy is filtered server-side). Admins see every key from every user. + +The tab is laid out as: + +- A **source mix ribbon** showing the percentage split across the three classes. +- A **top-N + Other stacked time chart** (top 7 sources by total tokens; the rest roll up). +- A **searchable, sortable table** of every key plus the Web UI and Legacy pseudo-rows. Click a row to filter the chart to that source. + +#### Endpoints + +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| `GET` | `/api/auth/usage/sources` | Self | Caller's per-source breakdown. Excludes legacy. | +| `GET` | `/api/auth/admin/usage/sources` | Admin | All users' per-source breakdown. Accepts `user_id` and `api_key_id` filters. Includes legacy. | + +Both endpoints accept the same `period` parameter (`day`, `week`, `month`, `all`) as `/api/auth/usage`. + +```bash +# Your own per-source usage for the last week +curl "http://localhost:8080/api/auth/usage/sources?period=week" \ + -H "Authorization: Bearer " + +# Admin: filter to a single API key across all users +curl "http://localhost:8080/api/auth/admin/usage/sources?period=month&api_key_id=" \ + -H "Authorization: Bearer " +``` + +**Response shape:** + +```json +{ + "buckets": [ + { "bucket": "2026-05-19", "source": "apikey", + "api_key_id": "uuid", "api_key_name": "ci-runner", + "total_tokens": 20000, "request_count": 142, "...": "..." }, + { "bucket": "2026-05-19", "source": "web", + "total_tokens": 300, "request_count": 11, "...": "..." } + ], + "totals": { + "by_source": { + "apikey": { "tokens": 1234567, "requests": 8420 }, + "web": { "tokens": 92000, "requests": 211 } + }, + "by_key": [ + { "api_key_id": "uuid", "api_key_name": "ci-runner", + "tokens": 2100000, "requests": 8420, + "last_used": "2026-05-20T12:34:56Z" } + ], + "grand_total": { "tokens": 1334777, "requests": 8645 } + }, + "truncated": false +} +``` + +The `by_key` list is server-sorted by tokens descending and capped at 200 entries. When more keys would qualify, the response sets `"truncated": true` so the UI can show a notice. + +#### Migration of pre-feature data + +Usage rows recorded before this feature have no `source` column. On startup, `InitDB` backfills them as `legacy` when the synthetic `legacy-api-key` user_id was used, and `web` for everything else. The migration is idempotent; existing aggregations remain correct after the upgrade. ## Combining Auth Modes