Files
LocalAI/core/http/auth/usage_test.go
LocalAI [bot] f0cb02afb8 feat(usage): attribute Sources rows to user accounts in admin view (#9935)
The merged feature (#9920) let admins see per-API-key and per-source
totals but did not surface which user owned each key, and lumped
every user's Web UI traffic into a single global Web UI row. This
makes the admin Sources tab properly per-user attributable:

- KeyTotal gains UserID + UserName, populated from the snapshot the
  usage middleware already records. The by_key roll-up now groups by
  (api_key_id, api_key_name, user_id, user_name).
- New SourceTotals.ByUserSource roll-up groups (source, user_id,
  user_name) for sources without a key identity (web, legacy). Only
  populated on the admin path (includeLegacy=true); the non-admin
  endpoint stays unchanged for backwards compatibility.
- SourcesTable accepts showUserColumn={isAdmin}; admin view renders
  a User column, makes the search match user name/id, and expands
  Web UI / legacy pseudo-rows from the global aggregate to one row
  per user using by_user_source.

Refs: #9862

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-21 23:23:06 +02:00

435 lines
14 KiB
Go

//go:build auth
package auth_test
import (
"fmt"
"time"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
var _ = Describe("Usage", func() {
Describe("RecordUsage", func() {
It("inserts a usage record", func() {
db := testDB()
record := &auth.UsageRecord{
UserID: "user-1",
UserName: "Test User",
Model: "gpt-4",
Endpoint: "/v1/chat/completions",
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
Duration: 1200,
CreatedAt: time.Now(),
}
err := auth.RecordUsage(db, record)
Expect(err).ToNot(HaveOccurred())
Expect(record.ID).ToNot(BeZero())
})
})
Describe("GetUserUsage", func() {
It("returns aggregated usage for a specific user", func() {
db := testDB()
// Insert records for two users
for range 3 {
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-a",
UserName: "Alice",
Model: "gpt-4",
Endpoint: "/v1/chat/completions",
PromptTokens: 100,
TotalTokens: 150,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
}
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-b",
UserName: "Bob",
Model: "gpt-4",
PromptTokens: 200,
TotalTokens: 300,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
buckets, err := auth.GetUserUsage(db, "user-a", "month")
Expect(err).ToNot(HaveOccurred())
Expect(buckets).ToNot(BeEmpty())
// All returned buckets should be for user-a's model
totalPrompt := int64(0)
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(300)))
})
It("filters by period", func() {
db := testDB()
// Record in the past (beyond day window)
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-c",
UserName: "Carol",
Model: "gpt-4",
PromptTokens: 100,
TotalTokens: 100,
CreatedAt: time.Now().Add(-48 * time.Hour),
})
Expect(err).ToNot(HaveOccurred())
// Record now
err = auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-c",
UserName: "Carol",
Model: "gpt-4",
PromptTokens: 200,
TotalTokens: 200,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
// Day period should only include recent record
buckets, err := auth.GetUserUsage(db, "user-c", "day")
Expect(err).ToNot(HaveOccurred())
totalPrompt := int64(0)
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(200)))
// Month period should include both
buckets, err = auth.GetUserUsage(db, "user-c", "month")
Expect(err).ToNot(HaveOccurred())
totalPrompt = 0
for _, b := range buckets {
totalPrompt += b.PromptTokens
}
Expect(totalPrompt).To(Equal(int64(300)))
})
})
Describe("GetAllUsage", func() {
It("returns usage for all users", func() {
db := testDB()
for _, uid := range []string{"user-x", "user-y"} {
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: uid,
UserName: uid,
Model: "gpt-4",
PromptTokens: 100,
TotalTokens: 150,
CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
}
buckets, err := auth.GetAllUsage(db, "month", "")
Expect(err).ToNot(HaveOccurred())
Expect(len(buckets)).To(BeNumerically(">=", 2))
})
It("filters by user ID when specified", func() {
db := testDB()
err := auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-p", UserName: "Pat", Model: "gpt-4",
PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
err = auth.RecordUsage(db, &auth.UsageRecord{
UserID: "user-q", UserName: "Quinn", Model: "gpt-4",
PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(),
})
Expect(err).ToNot(HaveOccurred())
buckets, err := auth.GetAllUsage(db, "month", "user-p")
Expect(err).ToNot(HaveOccurred())
for _, b := range buckets {
Expect(b.UserID).To(Equal("user-p"))
}
})
})
Describe("Usage source backfill", func() {
It("backfills 'web' for pre-feature rows", func() {
db := testDB()
rawDB, err := db.DB()
Expect(err).ToNot(HaveOccurred())
_, err = rawDB.Exec(
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
"user-x", "gpt-4", time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(auth.BackfillUsageSource(db)).To(Succeed())
var loaded auth.UsageRecord
Expect(db.Where("user_id = ?", "user-x").First(&loaded).Error).To(Succeed())
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
})
It("backfills 'legacy' for pre-feature rows with legacy-api-key user_id", func() {
db := testDB()
rawDB, err := db.DB()
Expect(err).ToNot(HaveOccurred())
_, err = rawDB.Exec(
`INSERT INTO usage_records (user_id, source, model, created_at, total_tokens, prompt_tokens, completion_tokens, duration) VALUES (?, '', ?, ?, 0, 0, 0, 0)`,
"legacy-api-key", "gpt-4", time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(auth.BackfillUsageSource(db)).To(Succeed())
var loaded auth.UsageRecord
Expect(db.Where("user_id = ?", "legacy-api-key").First(&loaded).Error).To(Succeed())
Expect(loaded.Source).To(Equal(auth.UsageSourceLegacy))
})
It("is idempotent on re-run", func() {
db := testDB()
Expect(auth.BackfillUsageSource(db)).To(Succeed())
Expect(auth.BackfillUsageSource(db)).To(Succeed())
})
})
Describe("UsageRecord with source fields", func() {
It("persists Source, APIKeyID, APIKeyName", func() {
db := testDB()
keyID := "key-uuid-1"
record := &auth.UsageRecord{
UserID: "user-1",
UserName: "Test User",
Source: auth.UsageSourceAPIKey,
APIKeyID: &keyID,
APIKeyName: "ci-runner",
Model: "gpt-4",
Endpoint: "/v1/chat/completions",
TotalTokens: 150,
CreatedAt: time.Now(),
}
Expect(auth.RecordUsage(db, record)).To(Succeed())
var loaded auth.UsageRecord
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
Expect(loaded.Source).To(Equal(auth.UsageSourceAPIKey))
Expect(loaded.APIKeyID).ToNot(BeNil())
Expect(*loaded.APIKeyID).To(Equal("key-uuid-1"))
Expect(loaded.APIKeyName).To(Equal("ci-runner"))
})
It("allows nil APIKeyID for web/legacy sources", func() {
db := testDB()
record := &auth.UsageRecord{
UserID: "user-1",
Source: auth.UsageSourceWeb,
Model: "gpt-4",
CreatedAt: time.Now(),
}
Expect(auth.RecordUsage(db, record)).To(Succeed())
var loaded auth.UsageRecord
Expect(db.First(&loaded, record.ID).Error).To(Succeed())
Expect(loaded.Source).To(Equal(auth.UsageSourceWeb))
Expect(loaded.APIKeyID).To(BeNil())
Expect(loaded.APIKeyName).To(BeEmpty())
})
})
Describe("GetUserUsageBySource", func() {
insert := func(db *gorm.DB, userID, source, keyID, keyName string, tokens int64, when time.Time) {
rec := &auth.UsageRecord{
UserID: userID,
Source: source,
Model: "gpt-4",
TotalTokens: tokens,
CreatedAt: when,
}
if keyID != "" {
rec.APIKeyID = &keyID
rec.APIKeyName = keyName
}
Expect(auth.RecordUsage(db, rec)).To(Succeed())
}
It("returns only the caller's rows, never legacy", func() {
db := testDB()
now := time.Now()
insert(db, "alice", auth.UsageSourceAPIKey, "k1", "ci", 100, now)
insert(db, "alice", auth.UsageSourceWeb, "", "", 50, now)
insert(db, "alice", auth.UsageSourceLegacy, "", "", 30, now)
insert(db, "bob", auth.UsageSourceAPIKey, "k2", "bobk", 90, now)
buckets, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
Expect(err).ToNot(HaveOccurred())
for _, b := range buckets {
Expect(b.UserID).To(Or(BeEmpty(), Equal("alice")))
Expect(b.Source).ToNot(Equal(auth.UsageSourceLegacy))
}
Expect(totals.GrandTotal.Tokens).To(Equal(int64(150)))
Expect(totals.BySource[auth.UsageSourceAPIKey].Tokens).To(Equal(int64(100)))
Expect(totals.BySource[auth.UsageSourceWeb].Tokens).To(Equal(int64(50)))
_, hasLegacy := totals.BySource[auth.UsageSourceLegacy]
Expect(hasLegacy).To(BeFalse())
})
It("snapshots survive key deletion", func() {
db := testDB()
now := time.Now()
insert(db, "alice", auth.UsageSourceAPIKey, "deleted-key", "old-name", 42, now)
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
Expect(err).ToNot(HaveOccurred())
Expect(totals.ByKey).To(HaveLen(1))
Expect(totals.ByKey[0].APIKeyName).To(Equal("old-name"))
Expect(totals.ByKey[0].APIKeyID).To(Equal("deleted-key"))
Expect(totals.ByKey[0].LastUsed).ToNot(BeZero())
Expect(totals.ByKey[0].LastUsed).To(BeTemporally("~", now, 2*time.Second))
})
})
Describe("GetAllUsageBySource", func() {
insert := func(db *gorm.DB, userID, source, keyID string, tokens int64) {
rec := &auth.UsageRecord{
UserID: userID,
Source: source,
Model: "gpt-4",
TotalTokens: tokens,
CreatedAt: time.Now(),
}
if keyID != "" {
rec.APIKeyID = &keyID
rec.APIKeyName = "name-" + keyID
}
Expect(auth.RecordUsage(db, rec)).To(Succeed())
}
It("includes legacy for admins", func() {
db := testDB()
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
insert(db, "legacy-api-key", auth.UsageSourceLegacy, "", 5)
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
Expect(err).ToNot(HaveOccurred())
Expect(totals.BySource).To(HaveKey(auth.UsageSourceLegacy))
Expect(totals.BySource[auth.UsageSourceLegacy].Tokens).To(Equal(int64(5)))
})
It("filters by user_id AND api_key_id", func() {
db := testDB()
insert(db, "alice", auth.UsageSourceAPIKey, "k1", 10)
insert(db, "alice", auth.UsageSourceAPIKey, "k2", 20)
insert(db, "bob", auth.UsageSourceAPIKey, "k3", 30)
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "alice", "k2")
Expect(err).ToNot(HaveOccurred())
Expect(totals.GrandTotal.Tokens).To(Equal(int64(20)))
})
It("sets truncated=true when by_key exceeds the cap", func() {
db := testDB()
for i := 0; i < 210; i++ {
insert(db, "alice", auth.UsageSourceAPIKey, fmt.Sprintf("key-%03d", i), int64(210-i))
}
_, totals, truncated, err := auth.GetAllUsageBySource(db, "month", "", "")
Expect(err).ToNot(HaveOccurred())
Expect(truncated).To(BeTrue())
Expect(totals.ByKey).To(HaveLen(200))
Expect(totals.ByKey[0].Tokens > totals.ByKey[199].Tokens).To(BeTrue())
})
// insertNamed records a row with explicit user_id, user_name, source,
// and optional api key snapshot. Used by the user-attribution tests
// below which the older insert helper can't express.
insertNamed := func(db *gorm.DB, userID, userName, source, keyID, keyName string, tokens int64) {
rec := &auth.UsageRecord{
UserID: userID,
UserName: userName,
Source: source,
Model: "gpt-4",
TotalTokens: tokens,
CreatedAt: time.Now(),
}
if keyID != "" {
rec.APIKeyID = &keyID
rec.APIKeyName = keyName
}
Expect(auth.RecordUsage(db, rec)).To(Succeed())
}
It("attributes each KeyTotal to its owner user", func() {
db := testDB()
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 100)
insertNamed(db, "bob", "Bob", auth.UsageSourceAPIKey, "k2", "lap", 50)
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
Expect(err).ToNot(HaveOccurred())
Expect(totals.ByKey).To(HaveLen(2))
byID := map[string]auth.KeyTotal{}
for _, k := range totals.ByKey {
byID[k.APIKeyID] = k
}
Expect(byID["k1"].UserID).To(Equal("alice"))
Expect(byID["k1"].UserName).To(Equal("Alice"))
Expect(byID["k2"].UserID).To(Equal("bob"))
Expect(byID["k2"].UserName).To(Equal("Bob"))
})
It("breaks Web UI and legacy traffic out per user in by_user_source for admin", func() {
db := testDB()
// Alice and Bob both have Web UI traffic; a synthetic legacy user
// also contributes. ByUserSource should expose one row per
// (source, user) pair, never for source=apikey.
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
insertNamed(db, "bob", "Bob", auth.UsageSourceWeb, "", "", 70)
insertNamed(db, "legacy-api-key", "API Key User", auth.UsageSourceLegacy, "", "", 10)
insertNamed(db, "alice", "Alice", auth.UsageSourceAPIKey, "k1", "ci-runner", 5)
_, totals, _, err := auth.GetAllUsageBySource(db, "month", "", "")
Expect(err).ToNot(HaveOccurred())
Expect(totals.ByUserSource).ToNot(BeEmpty())
for _, r := range totals.ByUserSource {
Expect(r.Source).ToNot(Equal(auth.UsageSourceAPIKey))
}
webByUser := map[string]int64{}
legacyByUser := map[string]int64{}
for _, r := range totals.ByUserSource {
switch r.Source {
case auth.UsageSourceWeb:
webByUser[r.UserID] = r.Tokens
case auth.UsageSourceLegacy:
legacyByUser[r.UserID] = r.Tokens
}
}
Expect(webByUser["alice"]).To(Equal(int64(30)))
Expect(webByUser["bob"]).To(Equal(int64(70)))
Expect(legacyByUser["legacy-api-key"]).To(Equal(int64(10)))
})
It("does NOT populate by_user_source in the non-admin path", func() {
db := testDB()
insertNamed(db, "alice", "Alice", auth.UsageSourceWeb, "", "", 30)
_, totals, err := auth.GetUserUsageBySource(db, "alice", "month")
Expect(err).ToNot(HaveOccurred())
// Non-admin path uses includeLegacy=false, so by_user_source stays nil.
Expect(totals.ByUserSource).To(BeNil())
})
})
})