Files
LocalAI/core/http/auth/usage.go
LocalAI [bot] f0cb02afb8 feat(usage): attribute Sources rows to user accounts in admin view (#9935)
The merged feature (#9920) let admins see per-API-key and per-source
totals but did not surface which user owned each key, and lumped
every user's Web UI traffic into a single global Web UI row. This
makes the admin Sources tab properly per-user attributable:

- KeyTotal gains UserID + UserName, populated from the snapshot the
  usage middleware already records. The by_key roll-up now groups by
  (api_key_id, api_key_name, user_id, user_name).
- New SourceTotals.ByUserSource roll-up groups (source, user_id,
  user_name) for sources without a key identity (web, legacy). Only
  populated on the admin path (includeLegacy=true); the non-admin
  endpoint stays unchanged for backwards compatibility.
- SourcesTable accepts showUserColumn={isAdmin}; admin view renders
  a User column, makes the search match user name/id, and expands
  Web UI / legacy pseudo-rows from the global aggregate to one row
  per user using by_user_source.

Refs: #9862

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-21 23:23:06 +02:00

448 lines
15 KiB
Go

package auth
import (
"fmt"
"strings"
"time"
"github.com/mudler/xlog"
"gorm.io/gorm"
)
// Source classification for a UsageRecord. These are the values stored in
// UsageRecord.Source; each one names the way the request was authenticated.
// UsageSourceLegacy rows are filtered out of non-admin queries (see
// applyFilters and GetUserUsageBySource).
const (
UsageSourceAPIKey = "apikey" // request authenticated with a named UserAPIKey
UsageSourceWeb = "web" // request authenticated with a session cookie (web UI)
UsageSourceLegacy = "legacy" // request authenticated with an env-configured legacy key
)
// UsageRecord represents a single API request's token usage.
// It is the GORM model that RecordUsage inserts and that the query helpers in
// this file aggregate; BackfillUsageSource addresses its table as
// usage_records in raw SQL.
type UsageRecord struct {
// ID is the auto-incrementing primary key.
ID uint `gorm:"primaryKey;autoIncrement"`
// UserID shares the composite index idx_usage_user_time with CreatedAt.
// BackfillUsageSource treats the value "legacy-api-key" as legacy traffic.
UserID string `gorm:"size:36;index:idx_usage_user_time"`
// UserName is stored next to UserID on every record; the admin roll-ups in
// this file group by both columns.
UserName string `gorm:"size:255"`
// Source classifies how the request authenticated. One of UsageSource* constants.
// Empty for pre-feature rows until the InitDB backfill runs.
Source string `gorm:"size:16;index:idx_usage_source"`
// APIKeyID is the UserAPIKey.ID when Source == UsageSourceAPIKey. Nil otherwise.
APIKeyID *string `gorm:"size:36;index:idx_usage_apikey"`
// APIKeyName is a snapshot of UserAPIKey.Name at write time. Survives key deletion.
APIKeyName string `gorm:"size:255"`
// Model is indexed; the time-bucket queries group by it.
Model string `gorm:"size:255;index"`
Endpoint string `gorm:"size:255"`
// Token counters. The aggregations in this file SUM each column
// independently; TotalTokens is also what the source roll-ups report as "tokens".
PromptTokens int64
CompletionTokens int64
TotalTokens int64
Duration int64 // milliseconds
// CreatedAt is the column every period filter and time bucket is computed from.
CreatedAt time.Time `gorm:"index:idx_usage_user_time"`
}
// RecordUsage persists one usage record through db and returns whatever
// error the insert reported (nil on success).
func RecordUsage(db *gorm.DB, record *UsageRecord) error {
	result := db.Create(record)
	return result.Error
}
// UsageBucket is an aggregated time bucket for the dashboard.
// It is the scan target of the Get*Usage* queries in this file. Which of the
// optional identity fields are filled depends on the columns each query
// selects: GetUserUsage fills Model only, GetAllUsage adds UserID/UserName,
// and the *BySource variants fill Source/APIKeyID/APIKeyName (plus
// UserID/UserName in the admin variant). Unfilled string fields are omitted
// from the JSON output.
type UsageBucket struct {
// Bucket is the formatted time bucket label produced by the SQL expression
// from periodToWindow (hourly, daily or monthly depending on the period).
Bucket string `json:"bucket"`
Model string `json:"model,omitempty"`
UserID string `json:"user_id,omitempty"`
UserName string `json:"user_name,omitempty"`
Source string `json:"source,omitempty"`
// APIKeyID is selected as COALESCE(api_key_id, '') so rows without a key scan as "".
APIKeyID string `json:"api_key_id,omitempty"`
APIKeyName string `json:"api_key_name,omitempty"`
// The counters below are SUM(...) of the matching UsageRecord columns;
// RequestCount is COUNT(*) over the rows in the bucket.
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
RequestCount int64 `json:"request_count"`
}
// UsageTotals is a summary of all usage.
// Its fields and JSON names mirror the counter fields of UsageBucket, without
// any bucket or identity dimension.
type UsageTotals struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
TotalTokens int64 `json:"total_tokens"`
RequestCount int64 `json:"request_count"`
}
// periodToWindow maps a dashboard period onto the lower bound of its time
// window and the SQL expression that formats created_at into a bucket label.
//
// Recognised periods:
//   - "day":  last 24 hours, hourly buckets
//   - "week": last 7 days, daily buckets
//   - "all":  no lower bound (zero time.Time), monthly buckets
//   - anything else is treated as "month": last 30 days, daily buckets
//
// isSQLite selects the strftime-based expression; otherwise the
// to_char/date_trunc form is returned.
func periodToWindow(period string, isSQLite bool) (time.Time, string) {
	const day = 24 * time.Hour

	// One bucket expression per granularity, in both SQL dialects.
	type bucketSQL struct {
		sqlite   string
		postgres string
	}
	hourly := bucketSQL{
		sqlite:   "strftime('%Y-%m-%d %H:00', created_at)",
		postgres: "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')",
	}
	daily := bucketSQL{
		sqlite:   "strftime('%Y-%m-%d', created_at)",
		postgres: "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')",
	}
	monthly := bucketSQL{
		sqlite:   "strftime('%Y-%m', created_at)",
		postgres: "to_char(date_trunc('month', created_at), 'YYYY-MM')",
	}

	var (
		lookback  time.Duration
		unbounded bool
		chosen    bucketSQL
	)
	switch period {
	case "day":
		lookback, chosen = day, hourly
	case "week":
		lookback, chosen = 7*day, daily
	case "all":
		unbounded, chosen = true, monthly
	default: // "month"
		lookback, chosen = 30*day, daily
	}

	// The zero time signals "no filter" to the callers.
	var since time.Time
	if !unbounded {
		since = time.Now().Add(-lookback)
	}

	if isSQLite {
		return since, chosen.sqlite
	}
	return since, chosen.postgres
}
// isSQLiteDB reports whether the name of db's dialector contains "sqlite".
// The usage queries use it to pick the dialect-specific date formatting.
func isSQLiteDB(db *gorm.DB) bool {
	dialect := db.Dialector.Name()
	return strings.Contains(dialect, "sqlite")
}
// GetUserUsage aggregates the usage rows of one user into (bucket, model)
// groups, ordered by bucket ascending.
//
// period is interpreted by periodToWindow; when it yields a zero start time
// no created_at filter is applied. On query failure the error is returned
// with a nil slice.
func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) {
	since, dateFmt := periodToWindow(period, isSQLiteDB(db))

	columns := fmt.Sprintf(
		"%s as bucket, model, "+
			"SUM(prompt_tokens) as prompt_tokens, "+
			"SUM(completion_tokens) as completion_tokens, "+
			"SUM(total_tokens) as total_tokens, "+
			"COUNT(*) as request_count",
		dateFmt,
	)

	q := db.Model(&UsageRecord{}).
		Select(columns).
		Where("user_id = ?", userID).
		Group("bucket, model").
		Order("bucket ASC")
	if hasLowerBound := !since.IsZero(); hasLowerBound {
		q = q.Where("created_at >= ?", since)
	}

	var result []UsageBucket
	err := q.Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// BackfillUsageSource fills in the source column of usage rows written before
// the column existed. Only rows whose source is NULL or empty are touched, so
// running it again is a no-op.
//
// Two passes run in order:
//  1. rows with user_id "legacy-api-key" become UsageSourceLegacy;
//  2. every row still unset becomes UsageSourceWeb.
//
// The first failing statement aborts the backfill and is returned wrapped.
func BackfillUsageSource(db *gorm.DB) error {
	// run executes one UPDATE and wraps a failure with the pass label.
	run := func(label, stmt string, args ...interface{}) error {
		if err := db.Exec(stmt, args...).Error; err != nil {
			return fmt.Errorf("backfill %s usage source: %w", label, err)
		}
		return nil
	}

	// The narrower legacy predicate must go first; the web pass would
	// otherwise claim those rows.
	if err := run(
		"legacy",
		`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '') AND user_id = ?`,
		UsageSourceLegacy, "legacy-api-key",
	); err != nil {
		return err
	}

	return run(
		"web",
		`UPDATE usage_records SET source = ? WHERE (source IS NULL OR source = '')`,
		UsageSourceWeb,
	)
}
// GetAllUsage is the admin counterpart of GetUserUsage: it aggregates usage
// across users into (bucket, model, user_id, user_name) groups, ordered by
// bucket ascending.
//
// period is interpreted by periodToWindow. A non-empty userID restricts the
// result to that user. On query failure the error is returned with a nil slice.
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) {
	since, dateFmt := periodToWindow(period, isSQLiteDB(db))

	columns := fmt.Sprintf(
		"%s as bucket, model, user_id, user_name, "+
			"SUM(prompt_tokens) as prompt_tokens, "+
			"SUM(completion_tokens) as completion_tokens, "+
			"SUM(total_tokens) as total_tokens, "+
			"COUNT(*) as request_count",
		dateFmt,
	)

	q := db.Model(&UsageRecord{}).
		Select(columns).
		Group("bucket, model, user_id, user_name").
		Order("bucket ASC")
	if hasLowerBound := !since.IsZero(); hasLowerBound {
		q = q.Where("created_at >= ?", since)
	}
	if len(userID) > 0 {
		q = q.Where("user_id = ?", userID)
	}

	var result []UsageBucket
	err := q.Find(&result).Error
	if err != nil {
		return nil, err
	}
	return result, nil
}
// TotalsEntry is a token+request roll-up.
// computeSourceTotals fills Tokens from SUM(total_tokens) and Requests from
// COUNT(*); it is used both per source and for the grand total.
type TotalsEntry struct {
Tokens int64 `json:"tokens"`
Requests int64 `json:"requests"`
}
// KeyTotal is the per-key roll-up returned by sources endpoints. UserID and
// UserName are snapshotted from the UsageRecord so revoked-and-deleted keys
// still carry their owner attribution in admin views.
//
// One KeyTotal is produced per distinct (api_key_id, api_key_name, user_id,
// user_name) combination found in the usage rows.
type KeyTotal struct {
APIKeyID string `json:"api_key_id"`
APIKeyName string `json:"api_key_name"`
UserID string `json:"user_id"`
UserName string `json:"user_name"`
// Tokens is SUM(total_tokens) and Requests is COUNT(*) over the group.
Tokens int64 `json:"tokens"`
Requests int64 `json:"requests"`
// LastUsed is MAX(created_at) over the group; it is the zero time when the
// database value could not be parsed (see parseLastUsedString).
LastUsed time.Time `json:"last_used"`
}
// UserSourceTotal is a per-(user, source) roll-up for sources that don't carry
// a named API key identity (web, legacy). It exists so admin views can show
// which user generated each block of Web UI / legacy traffic; the per-apikey
// breakdown for source=apikey already lives in KeyTotal.
//
// Rows are produced by computeSourceTotals from usage records whose source is
// not UsageSourceAPIKey, grouped by (source, user_id, user_name).
type UserSourceTotal struct {
Source string `json:"source"`
UserID string `json:"user_id"`
UserName string `json:"user_name"`
// Tokens is SUM(total_tokens) and Requests is COUNT(*) over the group.
Tokens int64 `json:"tokens"`
Requests int64 `json:"requests"`
}
// SourceTotals summarises a per-source breakdown.
// It is built by computeSourceTotals and returned alongside the time buckets
// by GetUserUsageBySource and GetAllUsageBySource.
type SourceTotals struct {
// BySource is keyed by the stored source value (see the UsageSource* constants).
BySource map[string]TotalsEntry `json:"by_source"`
// ByKey holds at most maxKeyTotals entries; it stays nil when the by-key query fails.
ByKey []KeyTotal `json:"by_key"` // server-sorted desc by tokens, capped
ByUserSource []UserSourceTotal `json:"by_user_source,omitempty"` // populated only when includeLegacy=true
// GrandTotal is the sum of all BySource entries.
GrandTotal TotalsEntry `json:"grand_total"`
}
// maxKeyTotals caps the number of rows in SourceTotals.ByKey (applied as the
// LIMIT of the by-key query). GetAllUsageBySource reports truncation when the
// number of distinct api_key_id values exceeds it.
const maxKeyTotals = 200
// GetUserUsageBySource returns one user's usage broken down by source and
// API key: time buckets grouped by (bucket, source, api_key_id, api_key_name)
// in ascending bucket order, plus the matching SourceTotals roll-up.
//
// Legacy-source rows are always excluded here; they are only exposed through
// the admin variant. period is interpreted by periodToWindow. If the bucket
// query fails, the error is returned with a nil slice and empty totals.
func GetUserUsageBySource(db *gorm.DB, userID, period string) ([]UsageBucket, SourceTotals, error) {
	since, dateFmt := periodToWindow(period, isSQLiteDB(db))

	columns := fmt.Sprintf(
		"%s as bucket, source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
			"SUM(prompt_tokens) as prompt_tokens, "+
			"SUM(completion_tokens) as completion_tokens, "+
			"SUM(total_tokens) as total_tokens, "+
			"COUNT(*) as request_count",
		dateFmt,
	)

	q := db.Model(&UsageRecord{}).
		Select(columns).
		Where("user_id = ?", userID).
		Where("source <> ?", UsageSourceLegacy).
		Group("bucket, source, api_key_id, api_key_name").
		Order("bucket ASC")
	if hasLowerBound := !since.IsZero(); hasLowerBound {
		q = q.Where("created_at >= ?", since)
	}

	var result []UsageBucket
	err := q.Find(&result).Error
	if err != nil {
		return nil, SourceTotals{}, err
	}

	// No API-key filter, legacy hidden: this is the non-admin view.
	return result, computeSourceTotals(db, userID, "", since, false), nil
}
// computeSourceTotals rolls up by_source / by_key / grand_total (and, for
// admin callers, by_user_source) over the usage rows matching the filters.
//
// Parameters:
//   - userID, apiKeyID: optional filters; an empty string disables the filter.
//   - since: lower bound on created_at; the zero time disables the filter.
//   - includeLegacy: when false, legacy-source rows are excluded everywhere and
//     ByUserSource is left unset; when true (admin) legacy rows are included
//     and ByUserSource is populated.
//
// The function is best-effort: query failures are logged and the roll-up that
// failed is left empty rather than returned as an error. A failure of the
// by-source query returns immediately with empty totals. Individual by-key
// rows that cannot be scanned are logged and skipped.
func computeSourceTotals(db *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) SourceTotals {
	totals := SourceTotals{BySource: map[string]TotalsEntry{}}

	// by_source, which also feeds grand_total.
	bySourceQ := db.Model(&UsageRecord{}).
		Select("source, SUM(total_tokens) as tokens, COUNT(*) as requests").
		Group("source")
	bySourceQ = applyFilters(bySourceQ, userID, apiKeyID, since, includeLegacy)
	var bySourceRows []struct {
		Source   string
		Tokens   int64
		Requests int64
	}
	if err := bySourceQ.Scan(&bySourceRows).Error; err != nil {
		xlog.Warn("computeSourceTotals: by-source Scan failed", "error", err)
		return totals
	}
	for _, r := range bySourceRows {
		totals.BySource[r.Source] = TotalsEntry{Tokens: r.Tokens, Requests: r.Requests}
		totals.GrandTotal.Tokens += r.Tokens
		totals.GrandTotal.Requests += r.Requests
	}

	// by_key: one row per (key, key name, user) snapshot, heaviest first,
	// capped at maxKeyTotals.
	byKeyQ := db.Model(&UsageRecord{}).
		Select("COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
			"user_id, user_name, "+
			"SUM(total_tokens) as tokens, COUNT(*) as requests, MAX(created_at) as last_used").
		Where("api_key_id IS NOT NULL AND api_key_id <> ''").
		Group("api_key_id, api_key_name, user_id, user_name").
		Order("tokens DESC").
		Limit(maxKeyTotals)
	byKeyQ = applyFilters(byKeyQ, userID, apiKeyID, since, includeLegacy)

	// Iterate Rows() manually because MAX(created_at) is returned as a string by
	// the SQLite driver, and Go's database/sql refuses to scan that into
	// *time.Time. Postgres returns a proper timestamp. We accept both shapes
	// via a Rows.Scan into a string column, then parse uniformly.
	rows, err := byKeyQ.Rows()
	if err != nil {
		xlog.Warn("computeSourceTotals: by-key Rows() failed", "error", err)
	} else {
		defer func() { _ = rows.Close() }()
		// Non-nil even when empty so the JSON output is [] rather than null.
		out := make([]KeyTotal, 0)
		for rows.Next() {
			var (
				apiKeyID, apiKeyName, userIDCol, userName, lastUsedRaw string
				tokens, requests                                       int64
			)
			if scanErr := rows.Scan(&apiKeyID, &apiKeyName, &userIDCol, &userName, &tokens, &requests, &lastUsedRaw); scanErr != nil {
				// Keep going (best effort), but leave a trace: a silently
				// dropped row would make a key vanish from the admin view
				// with nothing to diagnose it by.
				xlog.Warn("computeSourceTotals: by-key row Scan failed, skipping row", "error", scanErr)
				continue
			}
			out = append(out, KeyTotal{
				APIKeyID:   apiKeyID,
				APIKeyName: apiKeyName,
				UserID:     userIDCol,
				UserName:   userName,
				Tokens:     tokens,
				Requests:   requests,
				LastUsed:   parseLastUsedString(lastUsedRaw),
			})
		}
		if rerr := rows.Err(); rerr != nil {
			xlog.Warn("computeSourceTotals: by-key rows iteration failed", "error", rerr)
		}
		totals.ByKey = out
	}

	// by_user_source: only populated for admin callers (includeLegacy=true) so
	// they can attribute Web UI / legacy traffic to specific users. Per-apikey
	// rows already carry user info via KeyTotal above, so this query only
	// covers source != apikey.
	if includeLegacy {
		byUserSourceQ := db.Model(&UsageRecord{}).
			Select("source, user_id, user_name, "+
				"SUM(total_tokens) as tokens, COUNT(*) as requests").
			Where("source <> ?", UsageSourceAPIKey).
			Group("source, user_id, user_name").
			Order("tokens DESC")
		byUserSourceQ = applyFilters(byUserSourceQ, userID, apiKeyID, since, includeLegacy)
		var byUserSourceRows []UserSourceTotal
		if scanErr := byUserSourceQ.Scan(&byUserSourceRows).Error; scanErr != nil {
			xlog.Warn("computeSourceTotals: by-user-source Scan failed", "error", scanErr)
		} else {
			totals.ByUserSource = byUserSourceRows
		}
	}
	return totals
}
// parseLastUsedString turns the textual form of MAX(created_at), as surfaced
// by SQLite (or any driver that hands timestamps back as strings), into a
// time.Time.
//
// An empty input yields the zero time without logging. Otherwise the known
// layouts are tried in order and the first successful parse wins; if none
// matches, a warning is logged and the zero time is returned.
func parseLastUsedString(s string) time.Time {
	var zero time.Time
	if len(s) == 0 {
		return zero
	}

	// Layouts GORM's SQLite driver is known to emit (Go's default time
	// formatting and its variants), followed by the RFC 3339 fallbacks.
	candidates := [...]string{
		"2006-01-02 15:04:05.999999999 -0700 MST",
		"2006-01-02 15:04:05.999999999-07:00",
		"2006-01-02 15:04:05.999999999",
		"2006-01-02 15:04:05",
		time.RFC3339Nano,
		time.RFC3339,
	}
	for i := range candidates {
		parsed, err := time.Parse(candidates[i], s)
		if err != nil {
			continue
		}
		return parsed
	}

	xlog.Warn("parseLastUsedString: unrecognised format", "value", s)
	return zero
}
// GetAllUsageBySource is the admin variant of GetUserUsageBySource: legacy
// rows are included and the buckets additionally carry user_id / user_name.
//
// userID and apiKeyID are optional filters (empty string = no filter); period
// is interpreted by periodToWindow.
//
// The boolean result reports truncation of the per-key roll-up: it is true
// when the number of distinct api_key_id values matching the filters exceeds
// maxKeyTotals. If that count query fails, a warning is logged and false is
// reported. A failure of the bucket query is returned as an error with empty
// results.
func GetAllUsageBySource(db *gorm.DB, period, userID, apiKeyID string) ([]UsageBucket, SourceTotals, bool, error) {
	since, dateFmt := periodToWindow(period, isSQLiteDB(db))

	columns := fmt.Sprintf(
		"%s as bucket, source, COALESCE(api_key_id, '') as api_key_id, api_key_name, "+
			"user_id, user_name, "+
			"SUM(prompt_tokens) as prompt_tokens, "+
			"SUM(completion_tokens) as completion_tokens, "+
			"SUM(total_tokens) as total_tokens, "+
			"COUNT(*) as request_count",
		dateFmt,
	)

	bucketQ := db.Model(&UsageRecord{}).
		Select(columns).
		Group("bucket, source, api_key_id, api_key_name, user_id, user_name").
		Order("bucket ASC")
	bucketQ = applyFilters(bucketQ, userID, apiKeyID, since, true)

	var result []UsageBucket
	if err := bucketQ.Find(&result).Error; err != nil {
		return nil, SourceTotals{}, false, err
	}

	totals := computeSourceTotals(db, userID, apiKeyID, since, true)

	// Count the distinct keys behind the roll-up; more than maxKeyTotals means
	// the by_key slice was capped.
	keyCountQ := db.Model(&UsageRecord{}).
		Distinct("api_key_id").
		Where("api_key_id IS NOT NULL AND api_key_id <> ''")
	keyCountQ = applyFilters(keyCountQ, userID, apiKeyID, since, true)

	var distinctKeys int64
	if err := keyCountQ.Count(&distinctKeys).Error; err != nil {
		xlog.Warn("GetAllUsageBySource: distinct api_key_id count failed", "error", err)
		return result, totals, false, nil
	}
	return result, totals, distinctKeys > maxKeyTotals, nil
}
// applyFilters narrows q with the optional usage filters and returns the
// resulting query.
//
// Each filter is skipped when its argument is the zero value: empty userID,
// empty apiKeyID, zero since. Unless includeLegacy is true, rows whose source
// is UsageSourceLegacy are excluded.
func applyFilters(q *gorm.DB, userID, apiKeyID string, since time.Time, includeLegacy bool) *gorm.DB {
	scoped := q
	if len(userID) > 0 {
		scoped = scoped.Where("user_id = ?", userID)
	}
	if len(apiKeyID) > 0 {
		scoped = scoped.Where("api_key_id = ?", apiKeyID)
	}
	if hasLowerBound := !since.IsZero(); hasLowerBound {
		scoped = scoped.Where("created_at >= ?", since)
	}
	if includeLegacy {
		return scoped
	}
	return scoped.Where("source <> ?", UsageSourceLegacy)
}