mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-29 11:07:18 -04:00
Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe <io@richiejp.com>
292 lines
9.5 KiB
Go
292 lines
9.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/mudler/LocalAI/core/http/auth"
|
|
"github.com/mudler/LocalAI/core/services/routing/billing"
|
|
"github.com/mudler/xlog"
|
|
)
|
|
|
|
// usageResponseBody is the minimal structure we need from an OpenAI-shaped
|
|
// JSON response. Anthropic responses are decoded separately because their
|
|
// usage block uses different field names (input_tokens / output_tokens).
|
|
type usageResponseBody struct {
|
|
Model string `json:"model"`
|
|
Usage *struct {
|
|
PromptTokens int64 `json:"prompt_tokens"`
|
|
CompletionTokens int64 `json:"completion_tokens"`
|
|
TotalTokens int64 `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
// anthropicResponseBody covers /v1/messages JSON responses.
|
|
type anthropicResponseBody struct {
|
|
Model string `json:"model"`
|
|
Usage *struct {
|
|
InputTokens int64 `json:"input_tokens"`
|
|
OutputTokens int64 `json:"output_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
// UsageMiddleware records token usage for inference requests via the
|
|
// billing.Recorder. Two paths produce a record:
|
|
//
|
|
// 1. Handler-stamped (preferred): the request handler called
|
|
// middleware.StampUsage with the canonical token counts before
|
|
// returning. This is the only reliable path for streaming responses
|
|
// — clients rarely set OpenAI's stream_options.include_usage, and
|
|
// Anthropic's usage lives in a separate message_delta event.
|
|
// 2. Body-parsed (fallback): the response is parsed for an OpenAI- or
|
|
// Anthropic-shaped usage block. Used by passthrough proxies and
|
|
// foreign endpoints.
|
|
//
|
|
// Recorder being nil (e.g., --disable-stats) makes the middleware a
|
|
// transparent pass-through. fallbackUser is used when auth.GetUser(c)
|
|
// returns nil; without it, an unauthenticated request would be dropped.
|
|
//
|
|
// Every request that fails to produce a record ticks
|
|
// localai_usage_unrecorded_total so silent billing misses are observable.
|
|
func UsageMiddleware(recorder *billing.Recorder, fallbackUser *auth.User) echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if recorder == nil {
|
|
return next(c)
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Wrap response writer to capture body for the fallback parser.
|
|
// When the handler stamps the context we never read this buffer,
|
|
// so the cost is the per-chunk Write going through one extra
|
|
// indirection — accepted overhead in exchange for one billing
|
|
// path that works for both stamping and body-parse callers.
|
|
resBody := new(bytes.Buffer)
|
|
origWriter := c.Response().Writer
|
|
mw := &bodyWriter{
|
|
ResponseWriter: origWriter,
|
|
body: resBody,
|
|
}
|
|
c.Response().Writer = mw
|
|
|
|
handlerErr := next(c)
|
|
|
|
c.Response().Writer = origWriter
|
|
|
|
endpoint := c.Request().URL.Path
|
|
|
|
if c.Response().Status < 200 || c.Response().Status >= 300 {
|
|
return handlerErr
|
|
}
|
|
|
|
user := auth.GetUser(c)
|
|
if user == nil {
|
|
user = fallbackUser
|
|
}
|
|
if user == nil || user.ID == "" {
|
|
billing.CountUnrecorded(context.Background(), endpoint, "no_user")
|
|
return handlerErr
|
|
}
|
|
|
|
model, prompt, completion, total, ok := tokensFromContext(c)
|
|
if !ok {
|
|
model, prompt, completion, total, ok = tokensFromBody(resBody.Bytes(), c.Response().Header().Get("Content-Type"))
|
|
}
|
|
if !ok {
|
|
billing.CountUnrecorded(context.Background(), endpoint, "no_usage")
|
|
return handlerErr
|
|
}
|
|
|
|
requested, served := modelsFromContext(c, model)
|
|
pre, post := promptTokensFromContext(c, prompt)
|
|
|
|
source := auth.GetSource(c)
|
|
if source == "" {
|
|
// Auth disabled or unrecognised path: classify as web so the row is still
|
|
// bucketable rather than silently dropped from per-source aggregates.
|
|
source = auth.UsageSourceWeb
|
|
}
|
|
|
|
record := &auth.UsageRecord{
|
|
UserID: user.ID,
|
|
UserName: user.Name,
|
|
Source: source,
|
|
Model: model,
|
|
Endpoint: endpoint,
|
|
PromptTokens: prompt,
|
|
CompletionTokens: completion,
|
|
TotalTokens: total,
|
|
Duration: time.Since(startTime).Milliseconds(),
|
|
CreatedAt: startTime,
|
|
RequestedModel: requested,
|
|
ServedModel: served,
|
|
PreFilterPromptTokens: pre,
|
|
PostFilterPromptTokens: post,
|
|
CorrelationID: correlationIDFromContext(c),
|
|
}
|
|
|
|
if key := auth.GetAPIKey(c); key != nil {
|
|
id := key.ID
|
|
record.APIKeyID = &id
|
|
record.APIKeyName = key.Name
|
|
}
|
|
|
|
if err := recorder.Record(context.Background(), record); err != nil {
|
|
xlog.Error("usage middleware: recorder.Record failed", "error", err, "user", user.ID, "model", model)
|
|
billing.CountUnrecorded(context.Background(), endpoint, "record_failed")
|
|
}
|
|
|
|
return handlerErr
|
|
}
|
|
}
|
|
}
|
|
|
|
// tokensFromContext returns canonical token counts stamped by a handler
|
|
// via middleware.StampUsage. Returns ok=false when no stamp is present
|
|
// — the caller then tries the body-parse fallback.
|
|
//
|
|
// A model name without token counts is not considered "stamped" because a
|
|
// record with zero tokens looks the same as a never-recorded request to
|
|
// later analytics; the second condition is what gates ok.
|
|
func tokensFromContext(c echo.Context) (model string, prompt, completion, total int64, ok bool) {
|
|
if v, found := c.Get(ContextKeyResponseModel).(string); found {
|
|
model = v
|
|
}
|
|
pPresent := false
|
|
cPresent := false
|
|
if v, found := c.Get(ContextKeyPromptTokens).(int64); found {
|
|
prompt = v
|
|
pPresent = true
|
|
}
|
|
if v, found := c.Get(ContextKeyCompletionTokens).(int64); found {
|
|
completion = v
|
|
cPresent = true
|
|
}
|
|
if v, found := c.Get(ContextKeyTotalTokens).(int64); found {
|
|
total = v
|
|
} else {
|
|
total = prompt + completion
|
|
}
|
|
ok = pPresent || cPresent
|
|
return
|
|
}
|
|
|
|
// tokensFromBody covers the passthrough-proxy / foreign-endpoint case
|
|
// where no handler stamps the context. Returns ok=false on any parse
|
|
// failure or missing-usage; the caller increments the unrecorded counter.
|
|
func tokensFromBody(responseBytes []byte, contentType string) (model string, prompt, completion, total int64, ok bool) {
|
|
if len(responseBytes) == 0 {
|
|
return
|
|
}
|
|
isJSON := contentType == "" || contentType == "application/json" || bytes.HasPrefix([]byte(contentType), []byte("application/json"))
|
|
isSSE := bytes.HasPrefix([]byte(contentType), []byte("text/event-stream"))
|
|
if !isJSON && !isSSE {
|
|
return
|
|
}
|
|
|
|
payload := responseBytes
|
|
if isSSE {
|
|
// For SSE, the canonical usage chunk is the *last* non-[DONE] data
|
|
// line. OpenAI clients only emit one if stream_options.include_usage
|
|
// is set; Anthropic emits a final message_delta with usage. Both
|
|
// fit the "last data: line" rule.
|
|
last, lastOk := lastSSEData(responseBytes)
|
|
if !lastOk {
|
|
return
|
|
}
|
|
payload = last
|
|
}
|
|
|
|
// Try OpenAI shape first (handles /v1/chat/completions, /v1/completions,
|
|
// /v1/embeddings, /v1/edits, and any proxy that translates to OpenAI).
|
|
// A usage block whose token fields all decoded to zero is ambiguous —
|
|
// it could be an Anthropic body that happens to have a `usage` key —
|
|
// so fall through to the Anthropic parser instead of recording zeros.
|
|
var openAI usageResponseBody
|
|
if err := json.Unmarshal(payload, &openAI); err == nil && openAI.Usage != nil {
|
|
if openAI.Usage.PromptTokens != 0 || openAI.Usage.CompletionTokens != 0 || openAI.Usage.TotalTokens != 0 {
|
|
model = openAI.Model
|
|
prompt = openAI.Usage.PromptTokens
|
|
completion = openAI.Usage.CompletionTokens
|
|
total = openAI.Usage.TotalTokens
|
|
if total == 0 {
|
|
total = prompt + completion
|
|
}
|
|
ok = true
|
|
return
|
|
}
|
|
}
|
|
|
|
// Fall through to Anthropic shape (proxy passthrough territory).
|
|
var ant anthropicResponseBody
|
|
if err := json.Unmarshal(payload, &ant); err == nil && ant.Usage != nil {
|
|
if ant.Usage.InputTokens != 0 || ant.Usage.OutputTokens != 0 {
|
|
model = ant.Model
|
|
prompt = ant.Usage.InputTokens
|
|
completion = ant.Usage.OutputTokens
|
|
total = prompt + completion
|
|
ok = true
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// modelsFromContext returns (requested, served) using context-set values
|
|
// when present, falling back to the response-reported model for both.
|
|
// The router middleware (subsystem 2 of the routing plan) populates
|
|
// these; until it lands they are equal.
|
|
func modelsFromContext(c echo.Context, fallback string) (string, string) {
|
|
requested := fallback
|
|
served := fallback
|
|
if v, ok := c.Get(ContextKeyRequestedModel).(string); ok && v != "" {
|
|
requested = v
|
|
}
|
|
if v, ok := c.Get(ContextKeyServedModel).(string); ok && v != "" {
|
|
served = v
|
|
}
|
|
return requested, served
|
|
}
|
|
|
|
func promptTokensFromContext(c echo.Context, fallback int64) (int64, int64) {
|
|
pre := fallback
|
|
post := fallback
|
|
if v, ok := c.Get(ContextKeyPreFilterPromptTokens).(int64); ok && v > 0 {
|
|
pre = v
|
|
}
|
|
if v, ok := c.Get(ContextKeyPostFilterPromptTokens).(int64); ok && v > 0 {
|
|
post = v
|
|
}
|
|
return pre, post
|
|
}
|
|
|
|
func correlationIDFromContext(c echo.Context) string {
|
|
if v, ok := c.Get(ContextKeyCorrelationID).(string); ok {
|
|
return v
|
|
}
|
|
// X-Correlation-ID header is set by trace.go middleware; read it as a
|
|
// fallback if the echo-context binding hasn't been populated yet.
|
|
return c.Response().Header().Get("X-Correlation-ID")
|
|
}
|
|
|
|
// lastSSEData returns the payload of the last "data: " line whose content is not "[DONE]".
|
|
func lastSSEData(b []byte) ([]byte, bool) {
|
|
prefix := []byte("data: ")
|
|
var last []byte
|
|
for _, line := range bytes.Split(b, []byte("\n")) {
|
|
line = bytes.TrimRight(line, "\r")
|
|
if bytes.HasPrefix(line, prefix) {
|
|
payload := line[len(prefix):]
|
|
if !bytes.Equal(payload, []byte("[DONE]")) {
|
|
last = payload
|
|
}
|
|
}
|
|
}
|
|
return last, last != nil
|
|
}
|