feat(usage): add GetUserUsageBySource aggregator

Groups by (bucket, source, api_key_id, api_key_name). Filters out
legacy by default. Returns both per-bucket detail and roll-ups
(by_source, by_key sorted desc and capped at 200, grand_total).

The MAX(created_at) projection is iterated via Rows().Scan into a
string column and parsed manually because the SQLite driver surfaces
the aggregated timestamp as a string, which database/sql refuses to
scan directly into time.Time. Postgres returns a real timestamp; the
same string path handles its RFC3339 form too.

Refs: #9862
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-20 22:25:52 +00:00
parent e63a0e8fd7
commit bc615cb013
2 changed files with 212 additions and 0 deletions

View File

@@ -190,3 +190,163 @@ func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
}
return buckets, nil
}
// TotalsEntry is a token+request roll-up.
type TotalsEntry struct {
Tokens int64 `json:"tokens"`
Requests int64 `json:"requests"`
}
// KeyTotal is the per-key roll-up returned by sources endpoints.
type KeyTotal struct {
APIKeyID string `json:"api_key_id"`
APIKeyName string `json:"api_key_name"`
Tokens int64 `json:"tokens"`
Requests int64 `json:"requests"`
LastUsed time.Time `json:"last_used"`
}
// SourceTotals summarises a per-source breakdown.
type SourceTotals struct {
BySource map[string]TotalsEntry `json:"by_source"`
ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped
GrandTotal TotalsEntry `json:"grand_total"`
}
const maxKeyTotals = 200
// GetUserUsageBySource returns per-source aggregated usage for one user. Legacy
// is excluded by design (visible to admins only via the admin variant).
func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
sqlite := isSQLiteDB(db)
since, dateFmt := periodToWindow(period, sqlite)
bucketExpr := fmt.Sprintf("%s as bucket", dateFmt)
query := db.Model(&UsageRecord{}).
Select(bucketExpr+", source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
"SUM(prompt_tokens) as prompt_tokens, "+
"SUM(completion_tokens) as completion_tokens, "+
"SUM(total_tokens) as total_tokens, "+
"COUNT(*) as request_count").
Where("user_id = ?", userID).
Where("source <> ?", UsageSourceLegacy).
Group("bucket, source, api_key_id, api_key_name").
Order("bucket ASC")
if !since.IsZero() {
query = query.Where("created_at >= ?", since)
}
var buckets []UsageBucket
if err := query.Find(&buckets).Error; err != nil {
return nil, SourceTotals{}, err
}
totals := computeSourceTotals(db, userID, "", since, false)
return buckets, totals, nil
}
// computeSourceTotals rolls up by_source / by_key / grand_total.
// userID/apiKeyID are optional filters. includeLegacy controls whether the
// legacy bucket is exposed (admin-only).
func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
totals := SourceTotals{BySource: map[string]TotalsEntry{}}
bySourceQ := db.Model(&UsageRecord{}).
Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
Group("source")
bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
var bySourceRows []struct {
Source string
Tokens int64
Requests int64
}
if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
return totals
}
for _, r := range bySourceRows {
totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
totals.GrandTotal.Tokens += r.Tokens
totals.GrandTotal.Requests += r.Requests
}
byKeyQ := db.Model(&UsageRecord{}).
Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
Where("api_key_id IS NOT NULL AND api_key_id <> ''").
Group("api_key_id, api_key_name").
Order("tokens DESC").
Limit(maxKeyTotals)
byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)
// Iterate Rows() manually because MAX(created_at) is returned as a string by
// the SQLite driver, and Go's database/sql refuses to scan that into
// *time.Time. Postgres returns a proper timestamp. We accept both shapes
// via a Rows.Scan into a string column, then parse uniformly.
rows, err := byKeyQ.Rows()
if err == nil {
defer rows.Close()
out := make([]KeyTotal, 0)
for rows.Next() {
var (
apiKeyID, apiKeyName, lastUsedRaw string
tokens, requests int64
)
if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
continue
}
out = append(out, KeyTotal{
APIKeyID: apiKeyID,
APIKeyName: apiKeyName,
Tokens: tokens,
Requests: requests,
LastUsed: parseLastUsedString(lastUsedRaw),
})
}
totals.ByKey = out
}
return totals
}
// parseLastUsedString converts the textual MAX(created_at) value returned by
// SQLite (or any driver that surfaces the timestamp as a string) into a
// time.Time. Returns the zero time on parse failure.
func parseLastUsedString(s string) time.Time {
if s == "" {
return time.Time{}
}
// GORM's SQLite driver emits Go's default time formatting. Try the formats
// it commonly produces, falling back to RFC3339Nano.
layouts := []string{
"2006-01-02 15:04:05.999999999 -0700 MST",
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02 15:04:05",
time.RFC3339Nano,
time.RFC3339,
}
for _, layout := range layouts {
if t, err := time.Parse(layout, s); err == nil {
return t
}
}
return time.Time{}
}
func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
if userID != "" {
q = q.Where("user_id = ?", userID)
}
if apiKeyID != "" {
q = q.Where("api_key_id = ?", apiKeyID)
}
if !since.IsZero() {
q = q.Where("created_at >= ?", since)
}
if !includeLegacy {
q = q.Where("source <> ?", UsageSourceLegacy)
}
return q
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("Usage", func() {
@@ -243,4 +244,55 @@ var _ = Describe("Usage", func() {
Expect(loaded.APIKeyName).To(BeEmpty())
})
})
Describe("GetUserUsageBySource", func() {
insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
rec := &auth.UsageRecord{
UserID: userID,
Source: source,
Model: "gpt-4",
TotalTokens: tokens,
CreatedAt: when,
}
if keyID != "" {
rec.APIKeyID = &keyID
rec.APIKeyName = keyName
}
Expect(auth.RecordUsage(db, rec)).To(Succeed())
}
It("returns only the caller's rows, never legacy", func() {
db := testDB()
now := time.Now()
insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)
buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
Expect(err).ToNot(HaveOccurred())
for _, b := range buckets {
Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
}
Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
Expect(hasLegacy).To(BeFalse())
})
It("snapshots survive key deletion", func() {
db := testDB()
now := time.Now()
insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
Expect(err).ToNot(HaveOccurred())
Expect(totals.ByKey).To(HaveLen(1))
Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
})
})
})