Files
LocalAI/core/http/middleware/usage.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

292 lines
9.5 KiB
Go

package middleware
import (
"bytes"
"context"
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services/routing/billing"
"github.com/mudler/xlog"
)
// usageResponseBody is the minimal subset of an OpenAI-shaped JSON response
// that the body-parse fallback needs: the reported model name and the token
// usage block. Anthropic responses are decoded separately (see
// anthropicResponseBody) because their usage block uses different field
// names (input_tokens / output_tokens).
type usageResponseBody struct {
// Model is read from the top-level "model" key.
Model string `json:"model"`
// Usage is a pointer so that a payload without a "usage" key decodes to
// nil, which tokensFromBody treats as "no usage reported".
Usage *struct {
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
// TotalTokens may decode to zero; tokensFromBody then falls back to
// PromptTokens + CompletionTokens.
TotalTokens int64 `json:"total_tokens"`
} `json:"usage"`
}
// anthropicResponseBody covers /v1/messages JSON responses. It mirrors
// usageResponseBody but with Anthropic's usage field names; there is no
// total field, so tokensFromBody derives the total as input + output.
type anthropicResponseBody struct {
// Model is read from the top-level "model" key.
Model string `json:"model"`
// Usage is a pointer so that a payload without a "usage" key decodes to
// nil and can be told apart from a usage block of all zeros.
Usage *struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
} `json:"usage"`
}
// UsageMiddleware builds an echo middleware that writes one
// auth.UsageRecord per successful (2xx) inference request through the
// given billing.Recorder.
//
// Token counts are resolved in two steps:
//
//  1. Handler-stamped (preferred): values the handler placed on the echo
//     context, read by tokensFromContext. This is the dependable route for
//     streaming responses.
//  2. Body-parsed (fallback): the captured response body is handed to
//     tokensFromBody, which looks for an OpenAI- or Anthropic-shaped usage
//     block. This serves passthrough proxies and foreign endpoints.
//
// A nil recorder turns the middleware into a transparent pass-through.
// fallbackUser stands in when auth.GetUser(c) yields nil; if neither
// provides a user with a non-empty ID the request is not recorded.
//
// Responses outside the 2xx range are skipped silently. For 2xx responses
// that still produce no record, billing.CountUnrecorded is called with the
// reason "no_user", "no_usage" or "record_failed" so billing misses stay
// observable. The handler's own error is always returned unchanged.
func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if recorder == nil {
				return next(c)
			}

			began := time.Now()

			// Tee the response through a capturing writer so the fallback
			// parser has a copy of the body. When the handler stamps the
			// context the copy is never read; the extra indirection per
			// Write is the accepted price for having a single billing path.
			captured := &bytes.Buffer{}
			downstream := c.Response().Writer
			c.Response().Writer = &bodyWriter{
				ResponseWriter: downstream,
				body:           captured,
			}

			result := next(c)

			// Re-fetch the response after the handler ran, then undo the wrap.
			resp := c.Response()
			resp.Writer = downstream
			path := c.Request().URL.Path

			if succeeded := resp.Status >= 200 && resp.Status < 300; !succeeded {
				return result
			}

			bg := context.Background()

			who := auth.GetUser(c)
			if who == nil {
				who = fallbackUser
			}
			if who == nil || who.ID == "" {
				billing.CountUnrecorded(bg, path, "no_user")
				return result
			}

			model, prompt, completion, total, found := tokensFromContext(c)
			if !found {
				model, prompt, completion, total, found = tokensFromBody(captured.Bytes(), resp.Header().Get("Content-Type"))
			}
			if !found {
				billing.CountUnrecorded(bg, path, "no_usage")
				return result
			}

			requested, served := modelsFromContext(c, model)
			pre, post := promptTokensFromContext(c, prompt)

			origin := auth.GetSource(c)
			if origin == "" {
				// Auth disabled or unrecognised path: bucket the row as web
				// so it still shows up in per-source aggregates.
				origin = auth.UsageSourceWeb
			}

			entry := &auth.UsageRecord{
				UserID:                 who.ID,
				UserName:               who.Name,
				Source:                 origin,
				Model:                  model,
				Endpoint:               path,
				PromptTokens:           prompt,
				CompletionTokens:       completion,
				TotalTokens:            total,
				Duration:               time.Since(began).Milliseconds(),
				CreatedAt:              began,
				RequestedModel:         requested,
				ServedModel:            served,
				PreFilterPromptTokens:  pre,
				PostFilterPromptTokens: post,
				CorrelationID:          correlationIDFromContext(c),
			}
			if apiKey := auth.GetAPIKey(c); apiKey != nil {
				// Copy the ID so the record does not point into the key.
				keyID := apiKey.ID
				entry.APIKeyID = &keyID
				entry.APIKeyName = apiKey.Name
			}

			if err := recorder.Record(bg, entry); err != nil {
				xlog.Error("usage middleware: recorder.Record failed", "error", err, "user", who.ID, "model", model)
				billing.CountUnrecorded(bg, path, "record_failed")
			}
			return result
		}
	}
}
// tokensFromContext reads the token counts a handler stamped onto the echo
// context. ok is false when no stamp is present, in which case the caller
// tries the body-parse fallback.
//
// Only values stored as int64 (string for the model) are recognised. A
// stamp counts as present when the prompt or the completion count is set;
// a model name alone does not, because a zero-token record would be
// indistinguishable from a never-recorded request in later analytics. When
// no total was stamped it is derived as prompt + completion.
func tokensFromContext(c echo.Context) (model string, prompt, completion, total int64, ok bool) {
	// A failed comma-ok assertion leaves the zero value, which is exactly
	// the "not stamped" default wanted for each field.
	model, _ = c.Get(ContextKeyResponseModel).(string)

	var havePrompt, haveCompletion bool
	prompt, havePrompt = c.Get(ContextKeyPromptTokens).(int64)
	completion, haveCompletion = c.Get(ContextKeyCompletionTokens).(int64)

	stampedTotal, haveTotal := c.Get(ContextKeyTotalTokens).(int64)
	if haveTotal {
		total = stampedTotal
	} else {
		total = prompt + completion
	}

	return model, prompt, completion, total, havePrompt || haveCompletion
}
// tokensFromBody covers the passthrough-proxy / foreign-endpoint case where
// no handler stamped the context. It extracts the model name and token
// counts from a captured response body.
//
// responseBytes is the full body as written to the client; contentType is
// the response Content-Type header. An empty content type is treated as
// JSON. Anything other than application/json or text/event-stream is
// ignored.
//
// For JSON the whole body is decoded. For SSE the data lines are examined
// from the end of the stream backwards and the first payload that carries a
// non-zero usage block wins. Looking only at the very last data line is not
// enough: OpenAI's usage chunk is last (before [DONE]), but an Anthropic
// stream ends with a message_stop event that has no usage, with the usage
// on the message_delta event before it.
//
// ok is false on any parse failure or missing usage; the caller increments
// the unrecorded counter.
func tokensFromBody(responseBytes []byte, contentType string) (model string, prompt, completion, total int64, ok bool) {
	if len(responseBytes) == 0 {
		return
	}

	rawType := []byte(contentType)
	isJSON := contentType == "" || bytes.HasPrefix(rawType, []byte("application/json"))
	isSSE := bytes.HasPrefix(rawType, []byte("text/event-stream"))
	if !isJSON && !isSSE {
		return
	}
	if !isSSE {
		return usageFromPayload(responseBytes)
	}

	// Usage, when present, sits within the last few events of a stream.
	// Bounding the lookback keeps the JSON decoding cost constant for long
	// streams that carry no usage at all.
	const maxSSELookback = 8

	rest := responseBytes
	for examined := 0; examined < maxSSELookback && len(rest) > 0; {
		// Peel the final line off the tail without splitting the whole body.
		line := rest
		if nl := bytes.LastIndexByte(rest, '\n'); nl >= 0 {
			line, rest = rest[nl+1:], rest[:nl]
		} else {
			rest = nil
		}

		// lastSSEData on a single line yields its payload when the line is
		// a data line other than the [DONE] sentinel.
		payload, found := lastSSEData(line)
		if !found {
			continue
		}
		examined++

		if model, prompt, completion, total, ok = usageFromPayload(payload); ok {
			return
		}
	}
	return "", 0, 0, 0, false
}

// usageFromPayload decodes one JSON document and returns its token usage,
// trying the OpenAI shape first and the Anthropic shape second.
//
// The OpenAI shape handles /v1/chat/completions, /v1/completions,
// /v1/embeddings, /v1/edits, and any proxy that translates to OpenAI. A
// usage block whose OpenAI fields all decode to zero is ambiguous (it may
// be an Anthropic body that merely has a "usage" key), so it falls through
// to the Anthropic parser rather than being recorded as zeros.
func usageFromPayload(payload []byte) (model string, prompt, completion, total int64, ok bool) {
	var openAI usageResponseBody
	if err := json.Unmarshal(payload, &openAI); err == nil && openAI.Usage != nil {
		u := openAI.Usage
		if u.PromptTokens != 0 || u.CompletionTokens != 0 || u.TotalTokens != 0 {
			total = u.TotalTokens
			if total == 0 {
				total = u.PromptTokens + u.CompletionTokens
			}
			return openAI.Model, u.PromptTokens, u.CompletionTokens, total, true
		}
	}

	// Anthropic shape (proxy passthrough territory).
	var ant anthropicResponseBody
	if err := json.Unmarshal(payload, &ant); err == nil && ant.Usage != nil {
		u := ant.Usage
		if u.InputTokens != 0 || u.OutputTokens != 0 {
			return ant.Model, u.InputTokens, u.OutputTokens, u.InputTokens + u.OutputTokens, true
		}
	}

	return "", 0, 0, 0, false
}
// modelsFromContext resolves the (requested, served) model pair for a usage
// record. Each side comes from its echo-context key when that key holds a
// non-empty string; otherwise it falls back to the model reported by the
// response. With neither key set, both results equal fallback.
func modelsFromContext(c echo.Context, fallback string) (string, string) {
	// resolve returns the non-empty string stored under key, or fallback.
	resolve := func(key string) string {
		if name, _ := c.Get(key).(string); name != "" {
			return name
		}
		return fallback
	}

	requested := resolve(ContextKeyRequestedModel)
	served := resolve(ContextKeyServedModel)
	return requested, served
}
// promptTokensFromContext resolves the (pre-filter, post-filter) prompt
// token counts for a usage record. Each side comes from its echo-context
// key when that key holds a positive int64; otherwise it falls back to the
// given prompt token count.
func promptTokensFromContext(c echo.Context, fallback int64) (int64, int64) {
	// resolve returns the positive int64 stored under key, or fallback.
	resolve := func(key string) int64 {
		if count, _ := c.Get(key).(int64); count > 0 {
			return count
		}
		return fallback
	}

	pre := resolve(ContextKeyPreFilterPromptTokens)
	post := resolve(ContextKeyPostFilterPromptTokens)
	return pre, post
}
// correlationIDFromContext returns the correlation ID for the current
// request. It prefers the string stored under ContextKeyCorrelationID and
// otherwise reads the X-Correlation-ID response header.
//
// An empty stored string is treated the same as a missing one, so a blank
// context value cannot mask an ID that is available on the header.
func correlationIDFromContext(c echo.Context) string {
	if id, _ := c.Get(ContextKeyCorrelationID).(string); id != "" {
		return id
	}
	// X-Correlation-ID header is set by trace.go middleware; read it as a
	// fallback if the echo-context binding hasn't been populated yet.
	return c.Response().Header().Get("X-Correlation-ID")
}
// lastSSEData returns the payload of the last SSE data line in b that
// carries real content, and whether one was found.
//
// A data line is any line starting with "data:". The space after the colon
// is optional in the event-stream format, so both "data: {...}" and
// "data:{...}" are accepted. Surrounding whitespace (including a trailing
// "\r" from CRLF line endings) is trimmed from the payload. Empty payloads
// and the "[DONE]" sentinel are skipped so they cannot displace the last
// meaningful payload.
//
// The returned slice aliases b. When no qualifying line exists the result
// is (nil, false).
func lastSSEData(b []byte) ([]byte, bool) {
	field := []byte("data:")

	var last []byte
	for _, line := range bytes.Split(b, []byte("\n")) {
		if !bytes.HasPrefix(line, field) {
			continue
		}
		payload := bytes.TrimSpace(line[len(field):])
		if len(payload) == 0 || string(payload) == "[DONE]" {
			continue
		}
		last = payload
	}
	return last, last != nil
}