mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-29 19:19:19 -04:00
Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe <io@richiejp.com>
443 lines
12 KiB
Go
443 lines
12 KiB
Go
package mitm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
|
"github.com/mudler/LocalAI/core/services/routing/pii"
|
|
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
|
|
"github.com/mudler/xlog"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// PIIHandlerOptions configures NewPIIHandler.
|
|
type PIIHandlerOptions struct {
|
|
// Redactor is the regex PII redactor. nil disables redaction.
|
|
Redactor *pii.Redactor
|
|
|
|
// EventStore receives PIIEvent rows. nil discards events.
|
|
EventStore pii.EventStore
|
|
|
|
// UpstreamTLS overrides the tls.Config used when dialing the
|
|
// real upstream. Defaults to a system-trust HTTPS client.
|
|
UpstreamTLS *tls.Config
|
|
|
|
// CorrelationIDHeader names the request header carrying a
|
|
// caller-supplied correlation ID. Defaults to "X-Correlation-ID".
|
|
CorrelationIDHeader string
|
|
|
|
// DialHost optionally remaps the host used for the outbound
|
|
// upstream URL. Identity by default; tests inject a httptest
|
|
// listener address.
|
|
DialHost func(host string) string
|
|
|
|
// HostsWithPIIDisabled lists destination hosts whose request
|
|
// bodies should NOT run through the redactor. TLS termination,
|
|
// upstream forwarding, and audit events still happen — only the
|
|
// regex pass is bypassed. Useful for telemetry/probe endpoints
|
|
// whose bodies aren't PII-shaped.
|
|
HostsWithPIIDisabled []string
|
|
}
|
|
|
|
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
|
tlsCfg := opts.UpstreamTLS
|
|
if tlsCfg == nil {
|
|
tlsCfg = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
|
} else if len(tlsCfg.NextProtos) == 0 {
|
|
tlsCfg.NextProtos = []string{"h2", "http/1.1"}
|
|
}
|
|
transport := &http.Transport{
|
|
TLSClientConfig: tlsCfg,
|
|
ForceAttemptHTTP2: true,
|
|
}
|
|
if err := http2.ConfigureTransport(transport); err != nil {
|
|
xlog.Debug("mitm: http2.ConfigureTransport failed", "error", err)
|
|
}
|
|
|
|
corrHeader := opts.CorrelationIDHeader
|
|
if corrHeader == "" {
|
|
corrHeader = "X-Correlation-ID"
|
|
}
|
|
|
|
dialHost := opts.DialHost
|
|
if dialHost == nil {
|
|
dialHost = func(h string) string { return h }
|
|
}
|
|
|
|
patternAction := map[string]pii.Action{}
|
|
if opts.Redactor != nil {
|
|
for _, p := range opts.Redactor.Patterns() {
|
|
patternAction[p.ID] = p.Action
|
|
}
|
|
}
|
|
|
|
piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
|
|
for _, h := range opts.HostsWithPIIDisabled {
|
|
piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
|
|
}
|
|
|
|
d := &piiDispatcher{
|
|
client: &http.Client{Transport: transport},
|
|
redactor: opts.Redactor,
|
|
store: opts.EventStore,
|
|
patternAction: patternAction,
|
|
corrHeader: corrHeader,
|
|
dialHost: dialHost,
|
|
piiDisabled: piiDisabled,
|
|
}
|
|
return d.serve
|
|
}
|
|
|
|
type piiDispatcher struct {
|
|
client *http.Client
|
|
redactor *pii.Redactor
|
|
store pii.EventStore
|
|
patternAction map[string]pii.Action
|
|
corrHeader string
|
|
dialHost func(host string) string
|
|
piiDisabled map[string]bool
|
|
eventSeq atomic.Uint64
|
|
}
|
|
|
|
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
|
|
start := time.Now()
|
|
cw := &countingResponseWriter{ResponseWriter: w}
|
|
w = cw
|
|
|
|
var (
|
|
correlationID string
|
|
bytesSent int64
|
|
)
|
|
defer func() {
|
|
d.recordTrafficEvent(host, correlationID, bytesSent, cw.bytes, cw.status, start)
|
|
}()
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, "mitm: read body: "+err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
_ = r.Body.Close()
|
|
|
|
correlationID = r.Header.Get(d.corrHeader)
|
|
if correlationID == "" {
|
|
correlationID = r.Header.Get("x-request-id")
|
|
}
|
|
|
|
shape := classifyRequestShape(host, r.URL.Path)
|
|
if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
|
|
redacted, blocked, err := d.redactRequest(body, shape, correlationID)
|
|
switch {
|
|
case err != nil:
|
|
xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err)
|
|
case blocked:
|
|
writePIIBlocked(w, correlationID)
|
|
return
|
|
default:
|
|
body = redacted
|
|
}
|
|
}
|
|
|
|
upstreamURL := "https://" + d.dialHost(host) + r.URL.RequestURI()
|
|
upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
http.Error(w, "mitm: build upstream request: "+err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
upstreamReq.Header = cloneHopByHopFiltered(r.Header)
|
|
upstreamReq.ContentLength = int64(len(body))
|
|
upstreamReq.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
|
bytesSent = int64(len(body))
|
|
|
|
resp, err := d.client.Do(upstreamReq)
|
|
if err != nil {
|
|
http.Error(w, "mitm: upstream: "+err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
for k, vs := range resp.Header {
|
|
if isHopByHop(k) || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Content-Length") {
|
|
continue
|
|
}
|
|
for _, v := range vs {
|
|
w.Header().Add(k, v)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) {
|
|
d.streamWithPII(w, resp.Body, shape, correlationID)
|
|
return
|
|
}
|
|
|
|
if isSSE(contentType) {
|
|
flusher, _ := w.(http.Flusher)
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
n, rErr := resp.Body.Read(buf)
|
|
if n > 0 {
|
|
if _, wErr := w.Write(buf[:n]); wErr != nil {
|
|
return
|
|
}
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
if rErr != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
_, _ = io.Copy(w, resp.Body)
|
|
}
|
|
|
|
// requestShape identifies which provider wire format a request uses, which
// in turn selects the PII adapter and the SSE dialect.
type requestShape int

const (
	// shapeUnknown marks traffic that is forwarded without PII handling.
	shapeUnknown requestShape = iota
	// shapeOpenAIChat is an OpenAI chat-completions request.
	shapeOpenAIChat
	// shapeAnthropicMessages is an Anthropic messages request.
	shapeAnthropicMessages
)

// classifyRequestShape maps a destination host and URL path to a
// requestShape. The host is compared case-insensitively and must match
// exactly; the path only needs to end with the provider's endpoint.
// Anything else yields shapeUnknown.
func classifyRequestShape(host, path string) requestShape {
	switch strings.ToLower(host) {
	case "api.openai.com":
		if strings.HasSuffix(path, "/v1/chat/completions") {
			return shapeOpenAIChat
		}
	case "api.anthropic.com":
		if strings.HasSuffix(path, "/v1/messages") {
			return shapeAnthropicMessages
		}
	}
	return shapeUnknown
}
|
|
|
|
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
|
|
var parsed any
|
|
var adapter pii.Adapter
|
|
switch shape {
|
|
case shapeOpenAIChat:
|
|
req := &schema.OpenAIRequest{}
|
|
if err := json.Unmarshal(body, req); err != nil {
|
|
return nil, false, fmt.Errorf("parse openai: %w", err)
|
|
}
|
|
parsed = req
|
|
adapter = piiadapter.OpenAI()
|
|
case shapeAnthropicMessages:
|
|
req := &schema.AnthropicRequest{}
|
|
if err := json.Unmarshal(body, req); err != nil {
|
|
return nil, false, fmt.Errorf("parse anthropic: %w", err)
|
|
}
|
|
parsed = req
|
|
adapter = piiadapter.Anthropic()
|
|
default:
|
|
return body, false, nil
|
|
}
|
|
|
|
texts := adapter.Scan(parsed)
|
|
if len(texts) == 0 {
|
|
return body, false, nil
|
|
}
|
|
|
|
updates := make([]pii.ScannedText, 0, len(texts))
|
|
blocked := false
|
|
for _, st := range texts {
|
|
if st.Text == "" {
|
|
continue
|
|
}
|
|
res := d.redactor.RedactWithOverrides(st.Text, nil)
|
|
if len(res.Spans) == 0 {
|
|
continue
|
|
}
|
|
d.recordEvents(res.Spans, correlationID)
|
|
if res.Blocked {
|
|
blocked = true
|
|
}
|
|
updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
|
|
}
|
|
|
|
if len(updates) > 0 {
|
|
adapter.Apply(parsed, updates)
|
|
}
|
|
|
|
out, err := json.Marshal(parsed)
|
|
if err != nil {
|
|
return nil, false, fmt.Errorf("re-marshal: %w", err)
|
|
}
|
|
return out, blocked, nil
|
|
}
|
|
|
|
func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
|
|
if d.store == nil {
|
|
return
|
|
}
|
|
for _, span := range spans {
|
|
ev := pii.PIIEvent{
|
|
ID: fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)),
|
|
Kind: pii.KindPII,
|
|
CorrelationID: correlationID,
|
|
Direction: pii.DirectionIn,
|
|
PatternID: span.Pattern,
|
|
ByteOffset: span.Start,
|
|
Length: span.End - span.Start,
|
|
HashPrefix: span.HashPrefix,
|
|
Action: d.patternAction[span.Pattern],
|
|
CreatedAt: time.Now(),
|
|
}
|
|
if err := d.store.Record(context.Background(), ev); err != nil {
|
|
xlog.Debug("mitm: failed to record pii event", "error", err, "pattern", span.Pattern)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
|
|
flusher, _ := w.(http.Flusher)
|
|
filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")
|
|
|
|
provider := ssewire.OpenAI
|
|
if shape == shapeAnthropicMessages {
|
|
provider = ssewire.Anthropic
|
|
}
|
|
|
|
emit := func(s string) {
|
|
_, _ = w.Write([]byte(s))
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
scanner := ssewire.NewScanner(src)
|
|
for scanner.Scan() {
|
|
ev := scanner.Event()
|
|
if ssewire.IsTerminalMarker(ev.DataLine, provider) {
|
|
if residual := filter.Drain(); residual != "" {
|
|
emit(ssewire.SynthResidualEvent(provider, residual))
|
|
}
|
|
emit(ev.Raw)
|
|
continue
|
|
}
|
|
out := ev.Raw
|
|
if ev.DataLine != "" {
|
|
rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter)
|
|
if drop {
|
|
continue
|
|
}
|
|
if rewritten != ev.DataLine {
|
|
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
|
|
}
|
|
}
|
|
emit(out)
|
|
}
|
|
if residual := filter.Drain(); residual != "" {
|
|
emit(ssewire.SynthResidualEvent(provider, residual))
|
|
}
|
|
}
|
|
|
|
// writePIIBlocked answers a request that the redactor blocked: HTTP 400 with
// a JSON body holding an "error" object (message and type "pii_blocked") and
// the request's correlation ID. The encode error is deliberately ignored;
// nothing more can be sent to the client at that point.
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
	// Fields are declared in the same order a map-based encoding would emit
	// its keys (alphabetical), keeping the wire format stable.
	type blockedDetail struct {
		Message string `json:"message"`
		Type    string `json:"type"`
	}
	type blockedBody struct {
		CorrelationID string        `json:"correlation_id"`
		Error         blockedDetail `json:"error"`
	}

	payload := blockedBody{
		CorrelationID: correlationID,
		Error: blockedDetail{
			Message: "request blocked by LocalAI MITM proxy (sensitive data detected)",
			Type:    "pii_blocked",
		},
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	_ = json.NewEncoder(w).Encode(payload)
}
|
|
|
|
// isSSE reports whether contentType denotes a server-sent event stream.
//
// Only the media type is examined: parameters after the first ';' (for
// example "; charset=utf-8") are ignored, surrounding whitespace is trimmed,
// and the comparison is case-insensitive because media types are
// case-insensitive. An empty or unrelated value returns false.
func isSSE(contentType string) bool {
	mediaType, _, _ := strings.Cut(contentType, ";")
	return strings.EqualFold(strings.TrimSpace(mediaType), "text/event-stream")
}
|
|
|
|
// hopByHopHeaders are connection-scoped headers the proxy does not forward
// (RFC 7230 §6.1). Keys are in canonical MIME header form so that isHopByHop
// can find them after http.CanonicalHeaderKey.
var hopByHopHeaders = map[string]struct{}{
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Proxy-Connection":    {}, // non-standard, but still sent by some clients
	"Te":                  {},
	"Trailer":             {}, // the real field name (RFC 7230 erratum 4522)
	"Trailers":            {}, // legacy spelling, kept so existing filtering is unchanged
	"Transfer-Encoding":   {},
	"Upgrade":             {},
}

// isHopByHop reports whether name is a hop-by-hop header. The match is
// case-insensitive: name is canonicalized before the lookup.
func isHopByHop(name string) bool {
	_, ok := hopByHopHeaders[http.CanonicalHeaderKey(name)]
	return ok
}
|
|
|
|
// countingResponseWriter decorates an http.ResponseWriter, remembering the
// status code and the number of body bytes sent downstream. It provides
// Flush itself because the SSE paths flush after each event: without it the
// `w.(http.Flusher)` assertion on the wrapper would fail and flushing would
// silently become a no-op.
type countingResponseWriter struct {
	http.ResponseWriter
	// bytes is the running total of body bytes the wrapped writer accepted.
	bytes int64
	// status is the last code passed to WriteHeader, or 200 once Write has
	// been called without one; 0 means nothing has been written yet.
	status int
}

// Write forwards p to the wrapped writer and adds the number of bytes it
// accepted to the running total. A write with no prior WriteHeader is
// recorded as 200.
func (c *countingResponseWriter) Write(p []byte) (int, error) {
	if c.status == 0 {
		c.status = http.StatusOK
	}
	written, err := c.ResponseWriter.Write(p)
	c.bytes += int64(written)
	return written, err
}

// WriteHeader records code and forwards it to the wrapped writer.
func (c *countingResponseWriter) WriteHeader(code int) {
	c.status = code
	c.ResponseWriter.WriteHeader(code)
}

// Flush flushes the wrapped writer when it supports flushing and is a
// no-op otherwise.
func (c *countingResponseWriter) Flush() {
	flusher, ok := c.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}
|
|
|
|
func (d *piiDispatcher) recordTrafficEvent(host, correlationID string, sent, received int64, status int, start time.Time) {
|
|
if d.store == nil {
|
|
return
|
|
}
|
|
ev := pii.PIIEvent{
|
|
ID: fmt.Sprintf("proxy_traffic_%s_%d", correlationID, d.eventSeq.Add(1)),
|
|
Kind: pii.KindProxyTraffic,
|
|
CorrelationID: correlationID,
|
|
Host: host,
|
|
BytesSent: sent,
|
|
BytesReceived: received,
|
|
StatusCode: status,
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
if err := d.store.Record(context.Background(), ev); err != nil {
|
|
xlog.Debug("mitm: failed to record proxy_traffic event", "error", err, "host", host)
|
|
}
|
|
}
|
|
|
|
func cloneHopByHopFiltered(in http.Header) http.Header {
|
|
out := make(http.Header, len(in))
|
|
for k, vs := range in {
|
|
if isHopByHop(k) {
|
|
continue
|
|
}
|
|
copied := make([]string, len(vs))
|
|
copy(copied, vs)
|
|
out[k] = copied
|
|
}
|
|
return out
|
|
}
|