Files
LocalAI/core/http/middleware/trace.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

285 lines
7.8 KiB
Go

package middleware
import (
"bufio"
"bytes"
"io"
"mime"
"net"
"net/http"
"slices"
"sync"
"time"
"github.com/emirpasic/gods/v2/queues/circularbuffer"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/xlog"
)
// APIExchangeRequest is the request half of a traced API exchange, as
// populated by TraceMiddleware.
type APIExchangeRequest struct {
// Method is the HTTP method of the incoming request.
Method string `json:"method"`
// Path is the value reported by the echo context's Path().
Path string `json:"path"`
// Headers holds a copy of the request headers with sensitive values
// replaced by a redaction marker.
Headers *http.Header `json:"headers"`
// Body holds a copy of the request body, capped at the configured
// trace body limit.
Body *[]byte `json:"body"`
// BodyTruncated reports whether Body was cut short by the cap.
BodyTruncated bool `json:"body_truncated,omitempty"`
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
}
// APIExchangeResponse is the response half of a traced API exchange, as
// populated by TraceMiddleware.
type APIExchangeResponse struct {
// Status is the HTTP status recorded for the response.
Status int `json:"status"`
// Headers holds a copy of the response headers with sensitive values
// replaced by a redaction marker.
Headers *http.Header `json:"headers"`
// Body holds the captured portion of the response payload.
Body *[]byte `json:"body"`
// BodyTruncated reports whether the capture cap dropped part of the
// response payload.
BodyTruncated bool `json:"body_truncated,omitempty"`
BodyBytes int `json:"body_bytes,omitempty"` // original size before truncation
}
// APIExchange is one traced request/response pair stored in the in-memory
// trace buffer and returned by GetTraces.
type APIExchange struct {
// Timestamp is taken just before the downstream handler is invoked.
Timestamp time.Time `json:"timestamp"`
// Duration is the time elapsed since Timestamp when the exchange was
// assembled.
Duration time.Duration `json:"duration"`
Request APIExchangeRequest `json:"request"`
Response APIExchangeResponse `json:"response"`
// Error is the message of the error returned by the handler, if any.
Error string `json:"error,omitempty"`
// UserID and UserName identify the authenticated user when the auth
// package reports one for the request.
UserID string `json:"user_id,omitempty"`
UserName string `json:"user_name,omitempty"`
}
// Shared state of the tracing subsystem.
var (
	// traceBuffer is the ring of recorded exchanges. It stays nil until
	// tracing has been initialized; access is guarded by mu.
	traceBuffer *circularbuffer.Queue[APIExchange]

	// mu guards traceBuffer.
	mu sync.Mutex

	// logChan hands finished exchanges from request goroutines to the
	// background goroutine that stores them in traceBuffer.
	logChan = make(chan APIExchange, 100)

	// tracingMaxItems is the requested ring capacity, read once when
	// tracing is initialized.
	tracingMaxItems int
)
// doInitializeTracing performs the one-time tracing setup: it allocates the
// ring buffer (capacity tracingMaxItems, or 100 when that is not positive)
// and starts the goroutine that moves exchanges from logChan into the
// buffer. sync.OnceFunc makes repeated calls no-ops.
var doInitializeTracing = sync.OnceFunc(func() {
	capacity := tracingMaxItems
	if capacity < 1 {
		capacity = 100
	}

	mu.Lock()
	traceBuffer = circularbuffer.New[APIExchange](capacity)
	mu.Unlock()

	// drain runs for as long as logChan stays open, storing each exchange
	// under the lock.
	drain := func() {
		for entry := range logChan {
			mu.Lock()
			if traceBuffer != nil {
				traceBuffer.Enqueue(entry)
			}
			mu.Unlock()
		}
	}
	go drain()
})
// bodyWriter decorates an http.ResponseWriter: every byte a handler writes
// is forwarded to the wrapped writer while a bounded copy is retained for
// the trace log.
type bodyWriter struct {
	http.ResponseWriter
	body       *bytes.Buffer // captured prefix of the response payload
	maxBytes   int           // capture limit in bytes; values <= 0 disable the limit
	truncated  bool          // set once any written byte was left out of body
	totalBytes int           // sum of len(b) across all Write calls, captured or not
}

// Write records b in the capture buffer, subject to maxBytes, and then
// forwards the complete slice to the wrapped writer. The returned count and
// error are those of the wrapped writer.
func (w *bodyWriter) Write(b []byte) (int, error) {
	w.totalBytes += len(b)

	// Only the capture is bounded; the client always receives all of b.
	switch room := w.maxBytes - w.body.Len(); {
	case w.maxBytes <= 0, room > 0 && room >= len(b):
		// No limit configured, or the whole chunk still fits.
		w.body.Write(b)
	case room > 0:
		// Partial fit: keep the prefix that fills the buffer to the cap.
		w.body.Write(b[:room])
		w.truncated = true
	default:
		// Buffer already at the cap: capture nothing.
		w.truncated = true
	}

	return w.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it implements http.Flusher and
// is a no-op otherwise.
func (w *bodyWriter) Flush() {
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}
// truncateForTrace copies at most maxBytes bytes of body into a freshly
// allocated slice and reports whether anything had to be left out. A
// maxBytes of zero or less means "no cap". The result never aliases body
// and is non-nil even for a nil or empty input.
func truncateForTrace(body []byte, maxBytes int) ([]byte, bool) {
	keep := len(body)
	if maxBytes > 0 && maxBytes < keep {
		keep = maxBytes
	}

	snapshot := make([]byte, keep)
	copy(snapshot, body)
	return snapshot, keep < len(body)
}
// Hijack exposes the wrapped writer's http.Hijacker implementation so that
// connection upgraders (for example gorilla/websocket, which type-asserts
// the writer to http.Hijacker) keep working through the wrapper. When the
// wrapped writer cannot be hijacked, http.ErrNotSupported is returned.
func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, supported := w.ResponseWriter.(http.Hijacker)
	if !supported {
		return nil, nil, http.ErrNotSupported
	}
	return hijacker.Hijack()
}
// tracingMaxItemsOnce makes sure tracingMaxItems is written exactly once.
var tracingMaxItemsOnce sync.Once

// initializeTracing lazily sets up the trace buffer with room for maxItems
// exchanges. It is called on every traced request; only the first call has
// any effect, and its maxItems is the one used.
//
// The capacity is published through a sync.Once rather than assigned on
// every call: concurrent requests would otherwise all write the
// package-level tracingMaxItems without synchronization, racing with each
// other and with the read performed by doInitializeTracing.
func initializeTracing(maxItems int) {
	tracingMaxItemsOnce.Do(func() { tracingMaxItems = maxItems })
	doInitializeTracing()
}
// sensitiveTraceHeaders is the set of header names whose values must not
// land in the in-memory trace buffer. Keys are in canonical MIME header
// form so they can be matched against canonicalized header names.
var sensitiveTraceHeaders = map[string]struct{}{
	"Authorization":       {},
	"Proxy-Authorization": {},
	"Cookie":              {},
	"Set-Cookie":          {},
	"X-Api-Key":           {},
	"Xi-Api-Key":          {},
	"X-Auth-Token":        {},
}

// redactSensitiveHeaders returns a copy of h in which the values of every
// header listed in sensitiveTraceHeaders are replaced by the single value
// "[redacted]". h itself is not modified; a nil h yields nil.
//
// Header names are canonicalized before the lookup. http.Header only
// canonicalizes keys that go through Set/Add, so a header stored by direct
// map assignment (e.g. h["x-api-key"] = ...) would otherwise slip past the
// filter. The key's original spelling is kept in the returned copy.
func redactSensitiveHeaders(h http.Header) http.Header {
	out := h.Clone()
	for k := range out {
		if _, ok := sensitiveTraceHeaders[http.CanonicalHeaderKey(k)]; ok {
			out[k] = []string{"[redacted]"}
		}
	}
	return out
}
// TraceMiddleware returns an echo middleware that records JSON API
// request/response exchanges for later retrieval via GetTraces.
//
// The middleware is a pass-through when tracing is disabled in the
// application config or when the request's Content-Type is not
// application/json. Otherwise it buffers the request body (restoring it for
// downstream handlers), wraps the response writer to capture the response up
// to TracingMaxBodyBytes, and queues an APIExchange on logChan. If the
// channel is full the trace is dropped instead of blocking the request. The
// handler's error, if any, is returned unchanged.
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if !app.ApplicationConfig().EnableTracing {
				return next(c)
			}

			initializeTracing(app.ApplicationConfig().TracingMaxItems)

			// Only JSON requests are traced; anything else (including a
			// Content-Type that does not parse to application/json) goes
			// straight to the next handler.
			ct, _, _ := mime.ParseMediaType(c.Request().Header.Get("Content-Type"))
			if ct != "application/json" {
				return next(c)
			}

			body, err := io.ReadAll(c.Request().Body)
			if err != nil {
				xlog.Error("Failed to read request body")
				return err
			}

			// Restore the body for downstream handlers
			c.Request().Body = io.NopCloser(bytes.NewBuffer(body))

			startTime := time.Now()

			// Cap captured payload size. Without this, /embeddings and
			// streaming /chat/completions blow the in-memory buffer into the
			// tens of MB, which then locks the admin Traces UI fetching the
			// JSON dump faster than the 5s auto-refresh.
			maxBodyBytes := app.ApplicationConfig().TracingMaxBodyBytes

			// Wrap response writer to capture body
			resBody := new(bytes.Buffer)
			mw := &bodyWriter{
				ResponseWriter: c.Response().Writer,
				body:           resBody,
				maxBytes:       maxBodyBytes,
			}
			c.Response().Writer = mw

			handlerErr := next(c)

			// Restore original writer unconditionally
			c.Response().Writer = mw.ResponseWriter

			// Determine the status to record. When the handler returned an
			// error without writing a response, nothing has been committed
			// yet: echo invokes its HTTP error handler only after the
			// middleware chain returns, so Response.Status still holds its
			// initial value (200 after echo's per-request reset) and would
			// make a failed request look successful in the trace. In that
			// case use the code carried by an *echo.HTTPError, or 500 for
			// any other error, which is what echo's default error handler
			// would send.
			status := c.Response().Status
			if handlerErr != nil && !c.Response().Committed {
				status = http.StatusInternalServerError
				if httpErr, ok := handlerErr.(*echo.HTTPError); ok {
					status = httpErr.Code
				}
			}

			// Create exchange log (always, even on error). Sensitive headers
			// (Authorization, API keys, cookies) are redacted before storage —
			// the trace endpoint is admin-only but the buffer is also reachable
			// via any heap-dump-style introspection, and tokens shouldn't
			// outlive the request that carried them.
			requestHeaders := redactSensitiveHeaders(c.Request().Header)
			requestBody, requestTruncated := truncateForTrace(body, maxBodyBytes)
			responseHeaders := redactSensitiveHeaders(c.Response().Header())

			// Copy the captured response so the stored trace does not share
			// memory with the capture buffer.
			responseBody := make([]byte, resBody.Len())
			copy(responseBody, resBody.Bytes())

			exchange := APIExchange{
				Timestamp: startTime,
				Duration:  time.Since(startTime),
				Request: APIExchangeRequest{
					Method:        c.Request().Method,
					Path:          c.Path(),
					Headers:       &requestHeaders,
					Body:          &requestBody,
					BodyTruncated: requestTruncated,
					BodyBytes:     len(body),
				},
				Response: APIExchangeResponse{
					Status:        status,
					Headers:       &responseHeaders,
					Body:          &responseBody,
					BodyTruncated: mw.truncated,
					BodyBytes:     mw.totalBytes,
				},
			}
			if handlerErr != nil {
				exchange.Error = handlerErr.Error()
			}

			if user := auth.GetUser(c); user != nil {
				exchange.UserID = user.ID
				exchange.UserName = user.Name
			}

			// Non-blocking hand-off: tracing must never stall a request.
			select {
			case logChan <- exchange:
			default:
				xlog.Warn("Trace channel full, dropping trace")
			}

			return handlerErr
		}
	}
}
// GetTraces returns the recorded API exchanges ordered newest first. When
// tracing has not been initialized it returns an empty, non-nil slice.
func GetTraces() []APIExchange {
	var (
		traces      []APIExchange
		initialized bool
	)

	// Take the snapshot under the lock; sorting happens afterwards so the
	// lock is not held while ordering the entries.
	func() {
		mu.Lock()
		defer mu.Unlock()
		if initialized = traceBuffer != nil; initialized {
			traces = traceBuffer.Values()
		}
	}()

	if !initialized {
		return []APIExchange{}
	}

	slices.SortFunc(traces, func(x, y APIExchange) int {
		return y.Timestamp.Compare(x.Timestamp)
	})
	return traces
}
// ClearTraces empties the in-memory trace buffer. It does nothing when
// tracing has not been initialized.
func ClearTraces() {
	mu.Lock()
	defer mu.Unlock()

	if traceBuffer == nil {
		return
	}
	traceBuffer.Clear()
}