From bc615cb013e5d8c47f3567e279e9b9d675515fee Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Wed, 20 May 2026 22:25:52 +0000
Subject: [PATCH] feat(usage): add GetUserUsageBySource aggregator

Groups by (bucket, source, api_key_id, api_key_name). Filters out legacy
by default. Returns both per-bucket detail and roll-ups (by_source, by_key
sorted desc and capped at 200, grand_total).

The MAX(created_at) projection is iterated via Rows().Scan into a string
column and parsed manually because the SQLite driver surfaces the
aggregated timestamp as a string, which database/sql refuses to scan
directly into time.Time. Postgres returns a real timestamp; the same
string path handles its RFC3339 form too.

Roll-up query failures (including per-row scan errors and rows.Err()) are
reported by querySourceTotals and propagated by GetUserUsageBySource.
computeSourceTotals keeps its best-effort, error-free signature on top of
it.

Refs: #9862
Signed-off-by: Ettore Di Giacinto
---
 core/http/auth/usage.go      | 194 +++++++++++++++++++++++++++++++++++
 core/http/auth/usage_test.go |  52 ++++++++++
 2 files changed, 246 insertions(+)

diff --git a/core/http/auth/usage.go b/core/http/auth/usage.go
index 99cecef77..e9d454e0e 100644
--- a/core/http/auth/usage.go
+++ b/core/http/auth/usage.go
@@ -190,3 +190,197 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
 	}
 	return buckets, nil
 }
+
+// TotalsEntry is a token+request roll-up.
+type TotalsEntry struct {
+	Tokens   int64 `json:"tokens"`
+	Requests int64 `json:"requests"`
+}
+
+// KeyTotal is the per-key roll-up returned by sources endpoints.
+type KeyTotal struct {
+	APIKeyID   string    `json:"api_key_id"`
+	APIKeyName string    `json:"api_key_name"`
+	Tokens     int64     `json:"tokens"`
+	Requests   int64     `json:"requests"`
+	LastUsed   time.Time `json:"last_used"`
+}
+
+// SourceTotals summarises a per-source breakdown.
+type SourceTotals struct {
+	BySource   map[string]TotalsEntry `json:"by_source"`
+	ByKey      []KeyTotal             `json:"by_key"` // server-sorted desc by tokens, capped
+	GrandTotal TotalsEntry            `json:"grand_total"`
+}
+
+// maxKeyTotals caps the number of entries returned in SourceTotals.ByKey.
+const maxKeyTotals = 200
+
+// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy
+// is excluded by design (visible to admins only via the admin variant).
+//
+// It returns the per-bucket detail grouped by (bucket, source, api_key_id,
+// api_key_name) together with the roll-ups for the same window. An error from
+// either the detail query or the roll-up queries is returned to the caller.
+func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
+	sqlite := isSQLiteDB(db)
+	since, dateFmt := periodToWindow(period, sqlite)
+	bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
+
+	query := db.Model(&UsageRecord{}).
+		Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
+			"SUM(prompt_tokens) as prompt_tokens, "+
+			"SUM(completion_tokens) as completion_tokens, "+
+			"SUM(total_tokens) as total_tokens, "+
+			"COUNT(*) as request_count").
+		Where("user_id = ?", userID).
+		Where("source <> ?", UsageSourceLegacy).
+		Group("bucket, source, api_key_id, api_key_name").
+		Order("bucket ASC")
+
+	if !since.IsZero() {
+		query = query.Where("created_at >= ?", since)
+	}
+
+	var buckets []UsageBucket
+	if err := query.Find(&buckets).Error; err != nil {
+		return nil, SourceTotals{}, err
+	}
+
+	totals, err := querySourceTotals(db, userID, "", since, false)
+	if err != nil {
+		return nil, SourceTotals{}, err
+	}
+	return buckets, totals, nil
+}
+
+// computeSourceTotals rolls up by_source / by_key / grand_total.
+// userID/apiKeyID are optional filters. includeLegacy controls whether the
+// legacy bucket is exposed (admin-only).
+//
+// It is the best-effort form of querySourceTotals: a query failure is not
+// reported, and whatever was aggregated before the failure is returned.
+func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
+	// The error is dropped on purpose; callers that need it use querySourceTotals.
+	totals, _ := querySourceTotals(db, userID, apiKeyID, since, includeLegacy)
+	return totals
+}
+
+// querySourceTotals runs the by_source and by_key roll-up queries and reports
+// the first error encountered. On error the returned SourceTotals holds only
+// what was aggregated up to that point.
+func querySourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) (SourceTotals, error) {
+	totals := SourceTotals{BySource: map[string]TotalsEntry{}}
+
+	bySourceQ := db.Model(&UsageRecord{}).
+		Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
+		Group("source")
+	bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
+
+	var bySourceRows []struct {
+		Source   string
+		Tokens   int64
+		Requests int64
+	}
+	if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
+		return totals, fmt.Errorf("aggregating usage by source: %w", err)
+	}
+	for _, r := range bySourceRows {
+		totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
+		totals.GrandTotal.Tokens += r.Tokens
+		totals.GrandTotal.Requests += r.Requests
+	}
+
+	byKeyQ := db.Model(&UsageRecord{}).
+		Select("COALESCE(api_key_id, '') as api_key_id, COALESCE(api_key_name, '') as api_key_name, "+
+			"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
+		Where("api_key_id IS NOT NULL AND api_key_id <> ''").
+		Group("api_key_id, api_key_name").
+		Order("tokens DESC").
+		Limit(maxKeyTotals)
+	byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)
+
+	// Iterate Rows() manually because MAX(created_at) is returned as a string by
+	// the SQLite driver, and Go's database/sql refuses to scan that into
+	// *time.Time. Postgres returns a proper timestamp. We accept both shapes
+	// via a Rows.Scan into a string column, then parse uniformly.
+	rows, err := byKeyQ.Rows()
+	if err != nil {
+		return totals, fmt.Errorf("aggregating usage by key: %w", err)
+	}
+	defer rows.Close()
+
+	var firstErr error
+	out := make([]KeyTotal, 0)
+	for rows.Next() {
+		var (
+			keyID, keyName, lastUsedRaw string
+			tokens, requests            int64
+		)
+		if scanErr := rows.Scan(&keyID, &keyName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
+			// Keep going so one bad row does not hide the others, but remember
+			// the failure so it is not lost.
+			if firstErr == nil {
+				firstErr = fmt.Errorf("scanning per-key usage row: %w", scanErr)
+			}
+			continue
+		}
+		out = append(out, KeyTotal{
+			APIKeyID:   keyID,
+			APIKeyName: keyName,
+			Tokens:     tokens,
+			Requests:   requests,
+			LastUsed:   parseLastUsedString(lastUsedRaw),
+		})
+	}
+	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
+	if iterErr := rows.Err(); iterErr != nil && firstErr == nil {
+		firstErr = fmt.Errorf("iterating per-key usage rows: %w", iterErr)
+	}
+	totals.ByKey = out
+	return totals, firstErr
+}
+
+// parseLastUsedString converts the textual MAX(created_at) value returned by
+// SQLite (or any driver that surfaces the timestamp as a string) into a
+// time.Time. Returns the zero time on parse failure.
+func parseLastUsedString(s string) time.Time {
+	if s == "" {
+		return time.Time{}
+	}
+	// GORM's SQLite driver emits Go's default time formatting. Try the formats
+	// it commonly produces, falling back to RFC3339Nano.
+	layouts := []string{
+		"2006-01-02 15:04:05.999999999 -0700 MST",
+		"2006-01-02 15:04:05.999999999-07:00",
+		"2006-01-02 15:04:05.999999999",
+		"2006-01-02 15:04:05",
+		time.RFC3339Nano,
+		time.RFC3339,
+	}
+	for _, layout := range layouts {
+		if t, err := time.Parse(layout, s); err == nil {
+			return t
+		}
+	}
+	return time.Time{}
+}
+
+// applyFilters narrows q by the optional userID and apiKeyID (skipped when
+// empty), by created_at >= since (skipped when since is zero), and excludes
+// the legacy source unless includeLegacy is set.
+func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
+	if userID != "" {
+		q = q.Where("user_id = ?", userID)
+	}
+	if apiKeyID != "" {
+		q = q.Where("api_key_id = ?", apiKeyID)
+	}
+	if !since.IsZero() {
+		q = q.Where("created_at >= ?", since)
+	}
+	if !includeLegacy {
+		q = q.Where("source <> ?", UsageSourceLegacy)
+	}
+	return q
+}
diff --git a/core/http/auth/usage_test.go b/core/http/auth/usage_test.go
index 41ff1bd65..df021f7b1 100644
--- a/core/http/auth/usage_test.go
+++ b/core/http/auth/usage_test.go
@@ -8,6 +8,7 @@ import (
 	"github.com/mudler/LocalAI/core/http/auth"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
+	"gorm.io/gorm"
 )
 
 var _ = Describe("Usage", func() {
@@ -243,4 +244,55 @@ var _ = Describe("Usage", func() {
 			Expect(loaded.APIKeyName).To(BeEmpty())
 		})
 	})
+
+	Describe("GetUserUsageBySource", func() {
+		insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
+			rec := &auth.UsageRecord{
+				UserID:      userID,
+				Source:      source,
+				Model:       "gpt-4",
+				TotalTokens: tokens,
+				CreatedAt:   when,
+			}
+			if keyID != "" {
+				rec.APIKeyID = &keyID
+				rec.APIKeyName = keyName
+			}
+			Expect(auth.RecordUsage(db, rec)).To(Succeed())
+		}
+
+		It("returns only the caller's rows, never legacy", func() {
+			db := testDB()
+			now := time.Now()
+			insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
+			insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
+			insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
+			insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)
+
+			buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
+			Expect(err).ToNot(HaveOccurred())
+
+			for _, b := range buckets {
+				Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
+				Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
+			}
+
+			Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
+			Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
+			Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
+			_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
+			Expect(hasLegacy).To(BeFalse())
+		})
+
+		It("snapshots survive key deletion", func() {
+			db := testDB()
+			now := time.Now()
+			insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)
+			_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
+			Expect(err).ToNot(HaveOccurred())
+			Expect(totals.ByKey).To(HaveLen(1))
+			Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
+			Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
+		})
+	})
 })