mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
* feat(ui): add users and authentication support Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: allow the admin user to impersonificate users Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: ui improvements, disable 'Users' button in navbar when no auth is configured Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: add OIDC support Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: gate models Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: cache requests to optimize speed Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * small UI enhancements Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(ui): style improvements Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: cover other paths by auth Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: separate local auth, refactor Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * security hardening, approval mode Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: fix tests and expectations Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: update localagi/localrecall Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
186 lines
4.3 KiB
Go
186 lines
4.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/LocalAI/core/http/auth"
|
|
"github.com/mudler/xlog"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// usageFlushInterval is the period between background flushes of buffered
// usage records to the database.
const usageFlushInterval = 5 * time.Second

// usageMaxPending is the pending-queue size limit consulted when records are
// re-queued after a failed flush.
const usageMaxPending = 5000
|
|
|
|
// usageBatcher accumulates usage records and flushes them to the DB periodically.
|
|
type usageBatcher struct {
|
|
mu sync.Mutex
|
|
pending []*auth.UsageRecord
|
|
db *gorm.DB
|
|
}
|
|
|
|
func (b *usageBatcher) add(r *auth.UsageRecord) {
|
|
b.mu.Lock()
|
|
b.pending = append(b.pending, r)
|
|
b.mu.Unlock()
|
|
}
|
|
|
|
func (b *usageBatcher) flush() {
|
|
b.mu.Lock()
|
|
batch := b.pending
|
|
b.pending = nil
|
|
b.mu.Unlock()
|
|
|
|
if len(batch) == 0 {
|
|
return
|
|
}
|
|
|
|
if err := b.db.Create(&batch).Error; err != nil {
|
|
xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err)
|
|
// Re-queue failed records with a cap to avoid unbounded growth
|
|
b.mu.Lock()
|
|
if len(b.pending) < usageMaxPending {
|
|
b.pending = append(batch, b.pending...)
|
|
}
|
|
b.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
var batcher *usageBatcher
|
|
|
|
// InitUsageRecorder starts a background goroutine that periodically flushes
|
|
// accumulated usage records to the database.
|
|
func InitUsageRecorder(db *gorm.DB) {
|
|
if db == nil {
|
|
return
|
|
}
|
|
batcher = &usageBatcher{db: db}
|
|
go func() {
|
|
ticker := time.NewTicker(usageFlushInterval)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
batcher.flush()
|
|
}
|
|
}()
|
|
}
|
|
|
|
// usageResponseBody is the minimal structure we need from the response JSON.
|
|
type usageResponseBody struct {
|
|
Model string `json:"model"`
|
|
Usage *struct {
|
|
PromptTokens int64 `json:"prompt_tokens"`
|
|
CompletionTokens int64 `json:"completion_tokens"`
|
|
TotalTokens int64 `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
// UsageMiddleware extracts token usage from OpenAI-compatible response JSON
|
|
// and records it per-user.
|
|
func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if db == nil || batcher == nil {
|
|
return next(c)
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Wrap response writer to capture body
|
|
resBody := new(bytes.Buffer)
|
|
origWriter := c.Response().Writer
|
|
mw := &bodyWriter{
|
|
ResponseWriter: origWriter,
|
|
body: resBody,
|
|
}
|
|
c.Response().Writer = mw
|
|
|
|
handlerErr := next(c)
|
|
|
|
// Restore original writer
|
|
c.Response().Writer = origWriter
|
|
|
|
// Only record on successful responses
|
|
if c.Response().Status < 200 || c.Response().Status >= 300 {
|
|
return handlerErr
|
|
}
|
|
|
|
// Get authenticated user
|
|
user := auth.GetUser(c)
|
|
if user == nil {
|
|
return handlerErr
|
|
}
|
|
|
|
// Try to parse usage from response
|
|
responseBytes := resBody.Bytes()
|
|
if len(responseBytes) == 0 {
|
|
return handlerErr
|
|
}
|
|
|
|
// Check content type
|
|
ct := c.Response().Header().Get("Content-Type")
|
|
isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json"))
|
|
isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream"))
|
|
|
|
if !isJSON && !isSSE {
|
|
return handlerErr
|
|
}
|
|
|
|
var resp usageResponseBody
|
|
if isSSE {
|
|
last, ok := lastSSEData(responseBytes)
|
|
if !ok {
|
|
return handlerErr
|
|
}
|
|
if err := json.Unmarshal(last, &resp); err != nil {
|
|
return handlerErr
|
|
}
|
|
} else {
|
|
if err := json.Unmarshal(responseBytes, &resp); err != nil {
|
|
return handlerErr
|
|
}
|
|
}
|
|
|
|
if resp.Usage == nil {
|
|
return handlerErr
|
|
}
|
|
|
|
record := &auth.UsageRecord{
|
|
UserID: user.ID,
|
|
UserName: user.Name,
|
|
Model: resp.Model,
|
|
Endpoint: c.Request().URL.Path,
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
Duration: time.Since(startTime).Milliseconds(),
|
|
CreatedAt: startTime,
|
|
}
|
|
|
|
batcher.add(record)
|
|
|
|
return handlerErr
|
|
}
|
|
}
|
|
}
|
|
|
|
// lastSSEData returns the payload of the last Server-Sent Events "data:" line
// whose payload is neither empty nor the "[DONE]" sentinel.
//
// As the SSE format allows, a single space after "data:" is optional, so both
// "data: {...}" and "data:{...}" are recognized. Lines may end in "\n" or
// "\r\n". The body is scanned backwards from the end, so a long stream is not
// split into lines in full. The returned slice aliases b. ok is false when no
// such line exists.
func lastSSEData(b []byte) ([]byte, bool) {
	field := []byte("data:")
	done := []byte("[DONE]")

	rest := b
	for len(rest) > 0 {
		// Peel off the final line of what remains.
		var line []byte
		if i := bytes.LastIndexByte(rest, '\n'); i >= 0 {
			line, rest = rest[i+1:], rest[:i]
		} else {
			line, rest = rest, nil
		}
		line = bytes.TrimRight(line, "\r")

		payload, found := bytes.CutPrefix(line, field)
		if !found {
			continue
		}
		payload = bytes.TrimPrefix(payload, []byte(" "))
		if len(payload) == 0 || bytes.Equal(payload, done) {
			continue
		}
		return payload, true
	}
	return nil, false
}
|