mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 05:36:49 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
367 lines
9.8 KiB
Go
367 lines
9.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// QuotaRule defines a rate/token limit for a user, optionally scoped to a model.
|
|
type QuotaRule struct {
|
|
ID string `gorm:"primaryKey;size:36"`
|
|
UserID string `gorm:"size:36;uniqueIndex:idx_quota_user_model"`
|
|
Model string `gorm:"size:255;uniqueIndex:idx_quota_user_model"` // "" = all models
|
|
MaxRequests *int64 // nil = no request limit
|
|
MaxTotalTokens *int64 // nil = no token limit
|
|
WindowSeconds int64 // e.g., 3600 = 1h
|
|
CreatedAt time.Time
|
|
UpdatedAt time.Time
|
|
User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
|
|
}
|
|
|
|
// QuotaStatus is returned to clients with current usage included.
|
|
type QuotaStatus struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
MaxRequests *int64 `json:"max_requests"`
|
|
MaxTotalTokens *int64 `json:"max_total_tokens"`
|
|
Window string `json:"window"`
|
|
CurrentRequests int64 `json:"current_requests"`
|
|
CurrentTokens int64 `json:"current_total_tokens"`
|
|
ResetsAt string `json:"resets_at,omitempty"`
|
|
}
|
|
|
|
// ── CRUD ──
|
|
|
|
// CreateOrUpdateQuotaRule upserts a quota rule for the given user+model.
|
|
func CreateOrUpdateQuotaRule(db *gorm.DB, userID, model string, maxReqs, maxTokens *int64, windowSecs int64) (*QuotaRule, error) {
|
|
var existing QuotaRule
|
|
err := db.Where("user_id = ? AND model = ?", userID, model).First(&existing).Error
|
|
if err == gorm.ErrRecordNotFound {
|
|
rule := QuotaRule{
|
|
ID: uuid.New().String(),
|
|
UserID: userID,
|
|
Model: model,
|
|
MaxRequests: maxReqs,
|
|
MaxTotalTokens: maxTokens,
|
|
WindowSeconds: windowSecs,
|
|
}
|
|
if err := db.Create(&rule).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
quotaCache.invalidateUser(userID)
|
|
return &rule, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
existing.MaxRequests = maxReqs
|
|
existing.MaxTotalTokens = maxTokens
|
|
existing.WindowSeconds = windowSecs
|
|
if err := db.Save(&existing).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
quotaCache.invalidateUser(userID)
|
|
return &existing, nil
|
|
}
|
|
|
|
// ListQuotaRules returns all quota rules for a user.
|
|
func ListQuotaRules(db *gorm.DB, userID string) ([]QuotaRule, error) {
|
|
var rules []QuotaRule
|
|
if err := db.Where("user_id = ?", userID).Order("model ASC").Find(&rules).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
// DeleteQuotaRule removes a quota rule by ID (scoped to user for safety).
|
|
func DeleteQuotaRule(db *gorm.DB, ruleID, userID string) error {
|
|
result := db.Where("id = ? AND user_id = ?", ruleID, userID).Delete(&QuotaRule{})
|
|
if result.RowsAffected == 0 {
|
|
return fmt.Errorf("quota rule not found")
|
|
}
|
|
quotaCache.invalidateUser(userID)
|
|
return nil
|
|
}
|
|
|
|
// ── Usage queries ──
|
|
|
|
type usageCounts struct {
|
|
RequestCount int64
|
|
TotalTokens int64
|
|
}
|
|
|
|
// getUsageSince counts requests and tokens for a user since the given time.
|
|
func getUsageSince(db *gorm.DB, userID string, since time.Time, model string) (usageCounts, error) {
|
|
var result usageCounts
|
|
q := db.Model(&UsageRecord{}).
|
|
Select("COUNT(*) as request_count, COALESCE(SUM(total_tokens), 0) as total_tokens").
|
|
Where("user_id = ? AND created_at >= ?", userID, since)
|
|
if model != "" {
|
|
q = q.Where("model = ?", model)
|
|
}
|
|
if err := q.Row().Scan(&result.RequestCount, &result.TotalTokens); err != nil {
|
|
return result, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// GetQuotaStatuses returns all quota rules for a user with current usage.
|
|
func GetQuotaStatuses(db *gorm.DB, userID string) ([]QuotaStatus, error) {
|
|
rules, err := ListQuotaRules(db, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
statuses := make([]QuotaStatus, 0, len(rules))
|
|
now := time.Now()
|
|
for _, r := range rules {
|
|
windowStart := now.Add(-time.Duration(r.WindowSeconds) * time.Second)
|
|
counts, err := getUsageSince(db, userID, windowStart, r.Model)
|
|
if err != nil {
|
|
counts = usageCounts{}
|
|
}
|
|
statuses = append(statuses, QuotaStatus{
|
|
ID: r.ID,
|
|
Model: r.Model,
|
|
MaxRequests: r.MaxRequests,
|
|
MaxTotalTokens: r.MaxTotalTokens,
|
|
Window: formatWindowDuration(r.WindowSeconds),
|
|
CurrentRequests: counts.RequestCount,
|
|
CurrentTokens: counts.TotalTokens,
|
|
ResetsAt: windowStart.Add(time.Duration(r.WindowSeconds) * time.Second).UTC().Format(time.RFC3339),
|
|
})
|
|
}
|
|
return statuses, nil
|
|
}
|
|
|
|
// ── Quota check (used by middleware) ──
|
|
|
|
// QuotaExceeded checks whether the user has exceeded any applicable quota rule.
|
|
// Returns (exceeded bool, retryAfterSeconds int64, message string).
|
|
func QuotaExceeded(db *gorm.DB, userID, model string) (bool, int64, string) {
|
|
rules := quotaCache.getRules(db, userID)
|
|
if len(rules) == 0 {
|
|
return false, 0, ""
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
for _, r := range rules {
|
|
// Check if rule applies: model-specific rules match that model, global (empty) applies to all.
|
|
if r.Model != "" && r.Model != model {
|
|
continue
|
|
}
|
|
|
|
windowStart := now.Add(-time.Duration(r.WindowSeconds) * time.Second)
|
|
retryAfter := r.WindowSeconds // worst case: full window
|
|
|
|
// Try cache first
|
|
counts, ok := quotaCache.getUsage(userID, r.Model, windowStart)
|
|
if !ok {
|
|
var err error
|
|
counts, err = getUsageSince(db, userID, windowStart, r.Model)
|
|
if err != nil {
|
|
continue // on error, don't block the request
|
|
}
|
|
quotaCache.setUsage(userID, r.Model, windowStart, counts)
|
|
}
|
|
|
|
if r.MaxRequests != nil && counts.RequestCount >= *r.MaxRequests {
|
|
scope := "all models"
|
|
if r.Model != "" {
|
|
scope = "model " + r.Model
|
|
}
|
|
return true, retryAfter, fmt.Sprintf(
|
|
"Request quota exceeded for %s: %d/%d requests in %s window",
|
|
scope, counts.RequestCount, *r.MaxRequests, formatWindowDuration(r.WindowSeconds),
|
|
)
|
|
}
|
|
if r.MaxTotalTokens != nil && counts.TotalTokens >= *r.MaxTotalTokens {
|
|
scope := "all models"
|
|
if r.Model != "" {
|
|
scope = "model " + r.Model
|
|
}
|
|
return true, retryAfter, fmt.Sprintf(
|
|
"Token quota exceeded for %s: %d/%d tokens in %s window",
|
|
scope, counts.TotalTokens, *r.MaxTotalTokens, formatWindowDuration(r.WindowSeconds),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Optimistic increment: bump cached counters so subsequent requests in the
|
|
// same cache window see an updated count without re-querying the DB.
|
|
for _, r := range rules {
|
|
if r.Model != "" && r.Model != model {
|
|
continue
|
|
}
|
|
windowStart := now.Add(-time.Duration(r.WindowSeconds) * time.Second)
|
|
quotaCache.incrementUsage(userID, r.Model, windowStart)
|
|
}
|
|
|
|
return false, 0, ""
|
|
}
|
|
|
|
// ── In-memory cache ──
|
|
|
|
var quotaCache = newQuotaCacheStore()
|
|
|
|
type quotaCacheStore struct {
|
|
mu sync.RWMutex
|
|
rules map[string]cachedRules // userID -> rules
|
|
usage map[string]cachedUsage // "userID|model|windowStart" -> counts
|
|
}
|
|
|
|
type cachedRules struct {
|
|
rules []QuotaRule
|
|
fetchedAt time.Time
|
|
}
|
|
|
|
type cachedUsage struct {
|
|
counts usageCounts
|
|
fetchedAt time.Time
|
|
}
|
|
|
|
func newQuotaCacheStore() *quotaCacheStore {
|
|
c := "aCacheStore{
|
|
rules: make(map[string]cachedRules),
|
|
usage: make(map[string]cachedUsage),
|
|
}
|
|
go c.cleanupLoop()
|
|
return c
|
|
}
|
|
|
|
const (
|
|
rulesCacheTTL = 30 * time.Second
|
|
usageCacheTTL = 10 * time.Second
|
|
)
|
|
|
|
func (c *quotaCacheStore) getRules(db *gorm.DB, userID string) []QuotaRule {
|
|
c.mu.RLock()
|
|
cached, ok := c.rules[userID]
|
|
c.mu.RUnlock()
|
|
if ok && time.Since(cached.fetchedAt) < rulesCacheTTL {
|
|
return cached.rules
|
|
}
|
|
|
|
rules, err := ListQuotaRules(db, userID)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
c.mu.Lock()
|
|
c.rules[userID] = cachedRules{rules: rules, fetchedAt: time.Now()}
|
|
c.mu.Unlock()
|
|
return rules
|
|
}
|
|
|
|
func (c *quotaCacheStore) invalidateUser(userID string) {
|
|
c.mu.Lock()
|
|
delete(c.rules, userID)
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func usageKey(userID, model string, windowStart time.Time) string {
|
|
return userID + "|" + model + "|" + windowStart.Truncate(time.Second).Format(time.RFC3339)
|
|
}
|
|
|
|
func (c *quotaCacheStore) getUsage(userID, model string, windowStart time.Time) (usageCounts, bool) {
|
|
key := usageKey(userID, model, windowStart)
|
|
c.mu.RLock()
|
|
cached, ok := c.usage[key]
|
|
c.mu.RUnlock()
|
|
if ok && time.Since(cached.fetchedAt) < usageCacheTTL {
|
|
return cached.counts, true
|
|
}
|
|
return usageCounts{}, false
|
|
}
|
|
|
|
func (c *quotaCacheStore) setUsage(userID, model string, windowStart time.Time, counts usageCounts) {
|
|
key := usageKey(userID, model, windowStart)
|
|
c.mu.Lock()
|
|
c.usage[key] = cachedUsage{counts: counts, fetchedAt: time.Now()}
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func (c *quotaCacheStore) incrementUsage(userID, model string, windowStart time.Time) {
|
|
key := usageKey(userID, model, windowStart)
|
|
c.mu.Lock()
|
|
if cached, ok := c.usage[key]; ok {
|
|
cached.counts.RequestCount++
|
|
c.usage[key] = cached
|
|
}
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func (c *quotaCacheStore) cleanupLoop() {
|
|
ticker := time.NewTicker(60 * time.Second)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
c.mu.Lock()
|
|
now := time.Now()
|
|
for k, v := range c.rules {
|
|
if now.Sub(v.fetchedAt) > rulesCacheTTL*2 {
|
|
delete(c.rules, k)
|
|
}
|
|
}
|
|
for k, v := range c.usage {
|
|
if now.Sub(v.fetchedAt) > usageCacheTTL*2 {
|
|
delete(c.usage, k)
|
|
}
|
|
}
|
|
c.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// ── Helpers ──
|
|
|
|
// ParseWindowDuration converts a human-friendly window string to seconds.
|
|
func ParseWindowDuration(s string) (int64, error) {
|
|
switch s {
|
|
case "1m":
|
|
return 60, nil
|
|
case "5m":
|
|
return 300, nil
|
|
case "1h":
|
|
return 3600, nil
|
|
case "6h":
|
|
return 21600, nil
|
|
case "1d":
|
|
return 86400, nil
|
|
case "7d":
|
|
return 604800, nil
|
|
case "30d":
|
|
return 2592000, nil
|
|
}
|
|
// Try Go duration parsing as fallback
|
|
d, err := time.ParseDuration(s)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("invalid window duration: %s", s)
|
|
}
|
|
return int64(d.Seconds()), nil
|
|
}
|
|
|
|
// formatWindowDuration converts seconds to a human-friendly string.
|
|
func formatWindowDuration(secs int64) string {
|
|
switch secs {
|
|
case 60:
|
|
return "1m"
|
|
case 300:
|
|
return "5m"
|
|
case 3600:
|
|
return "1h"
|
|
case 21600:
|
|
return "6h"
|
|
case 86400:
|
|
return "1d"
|
|
case 604800:
|
|
return "7d"
|
|
case 2592000:
|
|
return "30d"
|
|
default:
|
|
d := time.Duration(secs) * time.Second
|
|
return d.String()
|
|
}
|
|
}
|