Files
LocalAI/core/services/cloudproxy/mitm/handler.go
Richard Palethorpe 12d1f3a697 security(http): refuse redirects on outbound clients via hardened pkg/httpclient (#10087)
LocalAI's outbound HTTP clients used Go's default redirect policy, which
follows up to 10 redirects. On a cross-host redirect Go forwards custom
request headers — including credential headers such as Anthropic's
x-api-key — to the redirect target (Go strips Authorization, Cookie and
WWW-Authenticate cross-host, but NOT arbitrary custom headers). An
attacker able to elicit a redirect from an upstream (a hijacked or
spoofed upstream, DNS trickery, or a malicious upstream_url) then
harvests the operator's provider API key.

This was first reported against the cloud-proxy / MITM PII path
(GHSA-3mj3-57v2-4636); the same class affects every other outbound
client. Rather than patch each call site, add pkg/httpclient as the one
sanctioned constructor for outbound HTTP and route everything through it.

pkg/httpclient:
  - New(...)             refuses redirects, TLS 1.2 floor, no body
                         deadline (streaming/SSE safe)
  - NewWithTimeout(d)    simple request/response calls
  - WithFollowRedirects  opt-in following that still strips credential
                         headers on any cross-host hop; different
                         scheme/host/port == different origin, guarding
                         the curl CVE-2022-27774 port-confusion class
  - WithTransport(rt)    keep a custom transport (IP-pin, HTTP/2, a
                         credential-injecting RoundTripper)
  - HardenedTransport()  base transport with the TLS floor + bounded setup
  - Harden(c)            apply the policy to a library-supplied *http.Client
  - NoRedirect           the CheckRedirect policy; wraps ErrRedirectBlocked

Lint: a forbidigo rule flags http.DefaultClient and http.Get/Post/
PostForm/Head, pointing at pkg/httpclient (.golangci.yml,
.agents/coding-style.md). forbidigo cannot match the &http.Client{}
composite literal without also flagging legitimate *http.Client type
references, so that form is enforced by review.

Migrates every non-test outbound call site across core/, pkg/, cmd/, and
the Go backend (backend/go/cloud-proxy). Credential-bearing and
internal-RPC clients refuse redirects; download / CDN / registry clients
use WithFollowRedirects so they keep working while stripping secrets
cross-host. The only credential-bearing client that follows redirects is
the gated-download path (pkg/downloader/uri.go), which strips the token
on the cross-host hop to the CDN. Hardening this closes, in passing:
  - MCP remote-server bearer token leaking via a redirect (the
    RoundTripper re-injected Authorization on every hop)
  - agent multimedia/webhook clients leaking user-supplied auth headers
  - cors_proxy following redirects, bypassing its SSRF IP-pin
  - downloader's authorized read path leaking the token cross-host

Fixes: GHSA-3mj3-57v2-4636 (cloud-proxy leaks operator provider API key
(x-api-key) to attacker host on cross-host redirect)
Reported-by: tonghuaroot
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-30 12:04:10 +02:00

452 lines
12 KiB
Go

package mitm
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/mudler/xlog"
"golang.org/x/net/http2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
"github.com/mudler/LocalAI/pkg/httpclient"
)
// PIIHandlerOptions configures NewPIIHandler. Every field is optional:
// with the zero value the handler forwards traffic without redaction
// and without recording events.
type PIIHandlerOptions struct {
// Redactor is the regex PII redactor. nil disables redaction.
Redactor *pii.Redactor
// EventStore receives PIIEvent rows. nil discards events.
EventStore pii.EventStore
// UpstreamTLS overrides the tls.Config used when dialing the
// real upstream. When nil, a default config advertising "h2" and
// "http/1.1" via ALPN is used; when set with an empty NextProtos,
// those same protocols are filled in.
UpstreamTLS *tls.Config
// CorrelationIDHeader names the request header carrying a
// caller-supplied correlation ID. Defaults to "X-Correlation-ID".
// When that header is absent on a request, "x-request-id" is
// consulted instead.
CorrelationIDHeader string
// DialHost optionally remaps the host used for the outbound
// upstream URL. Identity by default; tests inject a httptest
// listener address.
DialHost func(host string) string
// HostsWithPIIDisabled lists destination hosts whose request
// bodies should NOT run through the redactor. TLS termination,
// upstream forwarding, and audit events still happen — only the
// regex pass is bypassed. Useful for telemetry/probe endpoints
// whose bodies aren't PII-shaped. Entries are trimmed and matched
// case-insensitively against the request host.
HostsWithPIIDisabled []string
}
// NewPIIHandler builds the InterceptHandler that forwards intercepted
// requests to the real upstream over HTTPS, optionally redacting PII in
// recognised request shapes and recording audit events.
//
// opts is read once at construction time. The tls.Config in
// opts.UpstreamTLS is cloned, so the caller's value is never modified.
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
	// Work on a private copy: the NextProtos default below (and any
	// adjustment made while configuring the transport) must not leak
	// into a config the caller may share with other clients. Clone
	// returns nil for a nil receiver.
	tlsCfg := opts.UpstreamTLS.Clone()
	if tlsCfg == nil {
		tlsCfg = &tls.Config{}
	}
	if len(tlsCfg.NextProtos) == 0 {
		tlsCfg.NextProtos = []string{"h2", "http/1.1"}
	}

	transport := &http.Transport{
		TLSClientConfig:   tlsCfg,
		ForceAttemptHTTP2: true,
	}
	// Best effort: on failure the error is only logged and the
	// transport is used as configured above.
	if err := http2.ConfigureTransport(transport); err != nil {
		xlog.Debug("mitm: http2.ConfigureTransport failed", "error", err)
	}

	corrHeader := opts.CorrelationIDHeader
	if corrHeader == "" {
		corrHeader = "X-Correlation-ID"
	}

	dialHost := opts.DialHost
	if dialHost == nil {
		dialHost = func(h string) string { return h }
	}

	// Snapshot pattern ID -> action so recorded events can carry the
	// action without consulting the redactor again.
	patternAction := map[string]pii.Action{}
	if opts.Redactor != nil {
		for _, p := range opts.Redactor.Patterns() {
			patternAction[p.ID] = p.Action
		}
	}

	// Normalise the opt-out list once; lookups lowercase the host.
	piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
	for _, h := range opts.HostsWithPIIDisabled {
		piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
	}

	d := &piiDispatcher{
		// Refuse redirects: the MITM client forwards to the real
		// upstream over TLS, and a 3xx means the upstream (or something
		// impersonating it) is trying to bounce the request elsewhere.
		// Following it would replay caller headers — including provider
		// API keys such as Anthropic's x-api-key, which Go does NOT
		// strip on cross-host redirects — to an unvetted host. Surface
		// it as an error (handled as a 502) instead.
		client:        httpclient.New(httpclient.WithTransport(transport)),
		redactor:      opts.Redactor,
		store:         opts.EventStore,
		patternAction: patternAction,
		corrHeader:    corrHeader,
		dialHost:      dialHost,
		piiDisabled:   piiDisabled,
	}
	return d.serve
}
// piiDispatcher holds the state shared by every request handled by the
// serve method that NewPIIHandler returns.
type piiDispatcher struct {
// client performs the outbound request to the upstream.
client *http.Client
// redactor scans request text and SSE responses; nil disables redaction.
redactor *pii.Redactor
// store receives PII and proxy-traffic events; nil discards them.
store pii.EventStore
// patternAction maps a redactor pattern ID to its configured action,
// captured when the dispatcher is built.
patternAction map[string]pii.Action
// corrHeader is the request header read first for the correlation ID.
corrHeader string
// dialHost maps the intercepted host to the host placed in the upstream URL.
dialHost func(host string) string
// piiDisabled is the set of lowercased hosts whose request bodies skip redaction.
piiDisabled map[string]bool
// eventSeq is incremented for every recorded event and embedded in its ID.
eventSeq atomic.Uint64
}
// serve handles one intercepted request for host: it buffers the body,
// optionally redacts it, forwards it to https://<dialHost(host)><request-URI>,
// and relays the upstream response to w. A proxy-traffic event is
// recorded on every return path via the deferred call below.
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
start := time.Now()
// Wrap the writer so the deferred audit event can report the bytes
// written downstream and the status code.
cw := &countingResponseWriter{ResponseWriter: w}
w = cw
var (
correlationID string
bytesSent int64
)
// The closure reads correlationID and bytesSent at return time, so it
// sees whatever values were assigned before the function exited
// (bytesSent stays 0 when the request never reaches the upstream call).
defer func() {
d.recordTrafficEvent(host, correlationID, bytesSent, cw.bytes, cw.status, start)
}()
// The whole body is buffered so it can be rewritten and re-sent with an
// exact Content-Length.
// NOTE(review): the read is unbounded — confirm an outer layer caps request size.
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "mitm: read body: "+err.Error(), http.StatusBadGateway)
return
}
_ = r.Body.Close()
// Prefer the configured correlation header, falling back to x-request-id.
correlationID = r.Header.Get(d.corrHeader)
if correlationID == "" {
correlationID = r.Header.Get("x-request-id")
}
shape := classifyRequestShape(host, r.URL.Path)
// Redact only when a redactor is configured, the request shape is
// recognised, and the host has not been opted out.
if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
redacted, blocked, err := d.redactRequest(body, shape, correlationID)
switch {
case err != nil:
// Fail open: a body that cannot be parsed or re-encoded is
// forwarded exactly as received.
xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err)
case blocked:
// A blocking pattern matched: answer locally, never contact upstream.
writePIIBlocked(w, correlationID)
return
default:
body = redacted
}
}
// The upstream is always dialled over HTTPS.
upstreamURL := "https://" + d.dialHost(host) + r.URL.RequestURI()
upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, bytes.NewReader(body))
if err != nil {
http.Error(w, "mitm: build upstream request: "+err.Error(), http.StatusBadGateway)
return
}
upstreamReq.Header = cloneHopByHopFiltered(r.Header)
// The body may have changed size during redaction; overwrite the
// Content-Length copied from the caller.
upstreamReq.ContentLength = int64(len(body))
upstreamReq.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
bytesSent = int64(len(body))
// Any client error is reported to the caller as a 502.
resp, err := d.client.Do(upstreamReq)
if err != nil {
http.Error(w, "mitm: upstream: "+err.Error(), http.StatusBadGateway)
return
}
defer func() { _ = resp.Body.Close() }()
// Relay response headers, dropping hop-by-hop fields and the framing
// headers (Transfer-Encoding, Content-Length).
for k, vs := range resp.Header {
if isHopByHop(k) || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Content-Length") {
continue
}
for _, v := range vs {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
contentType := resp.Header.Get("Content-Type")
// SSE response for a recognised shape: run it through the streaming PII filter.
// NOTE(review): this branch does not consult piiDisabled, so opted-out
// hosts still get response filtering — confirm that is intended.
if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) {
d.streamWithPII(w, resp.Body, shape, correlationID)
return
}
// Any other SSE response: pass through unmodified, flushing after every
// read so events reach the client promptly.
if isSSE(contentType) {
flusher, _ := w.(http.Flusher)
buf := make([]byte, 32*1024)
for {
n, rErr := resp.Body.Read(buf)
if n > 0 {
if _, wErr := w.Write(buf[:n]); wErr != nil {
return
}
if flusher != nil {
flusher.Flush()
}
}
// Any read error, including io.EOF, ends the relay.
if rErr != nil {
return
}
}
}
// Non-streaming response: copy the body as-is; copy errors are ignored.
_, _ = io.Copy(w, resp.Body)
}
// requestShape identifies which provider request schema a request body
// is expected to follow.
type requestShape int

const (
	// shapeUnknown marks traffic whose body format is not recognised.
	shapeUnknown requestShape = iota
	// shapeOpenAIChat is an OpenAI chat-completions request.
	shapeOpenAIChat
	// shapeAnthropicMessages is an Anthropic messages request.
	shapeAnthropicMessages
)

// classifyRequestShape maps a destination host and URL path to the
// request schema used there. The host is compared case-insensitively
// and the path by suffix; anything else yields shapeUnknown.
func classifyRequestShape(host, path string) requestShape {
	switch strings.ToLower(host) {
	case "api.openai.com":
		if strings.HasSuffix(path, "/v1/chat/completions") {
			return shapeOpenAIChat
		}
	case "api.anthropic.com":
		if strings.HasSuffix(path, "/v1/messages") {
			return shapeAnthropicMessages
		}
	}
	return shapeUnknown
}
// redactRequest decodes body according to shape, runs every text
// fragment exposed by the matching adapter through the redactor, and
// returns the body that should be forwarded.
//
// Returns:
//   - the body to send upstream: the original bytes when the shape is
//     not handled, nothing was scanned, or no pattern matched; otherwise
//     the re-encoded request with redacted text applied;
//   - blocked, true when any redaction result was flagged as blocked;
//   - an error when body cannot be decoded or the redacted request
//     cannot be re-encoded (the returned body is nil in that case).
//
// A PII event is recorded for every matched span.
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
	var (
		parsed  any
		adapter pii.Adapter
	)
	switch shape {
	case shapeOpenAIChat:
		req := &schema.OpenAIRequest{}
		if err := json.Unmarshal(body, req); err != nil {
			return nil, false, fmt.Errorf("parse openai: %w", err)
		}
		parsed = req
		adapter = piiadapter.OpenAI()
	case shapeAnthropicMessages:
		req := &schema.AnthropicRequest{}
		if err := json.Unmarshal(body, req); err != nil {
			return nil, false, fmt.Errorf("parse anthropic: %w", err)
		}
		parsed = req
		adapter = piiadapter.Anthropic()
	default:
		return body, false, nil
	}

	texts := adapter.Scan(parsed)
	if len(texts) == 0 {
		return body, false, nil
	}

	updates := make([]pii.ScannedText, 0, len(texts))
	blocked := false
	for _, st := range texts {
		if st.Text == "" {
			continue
		}
		res := d.redactor.RedactWithOverrides(st.Text, nil)
		if len(res.Spans) == 0 {
			continue
		}
		d.recordEvents(res.Spans, correlationID)
		if res.Blocked {
			blocked = true
		}
		updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
	}

	// Nothing matched, so there is nothing to rewrite. Return the
	// caller's bytes untouched instead of round-tripping through the
	// schema struct: decoding into a struct discards any JSON member the
	// struct does not model, so re-encoding could silently alter a
	// request that contained no PII at all. blocked is necessarily false
	// here because it is only set alongside an appended update.
	if len(updates) == 0 {
		return body, false, nil
	}

	adapter.Apply(parsed, updates)
	out, err := json.Marshal(parsed)
	if err != nil {
		return nil, false, fmt.Errorf("re-marshal: %w", err)
	}
	return out, blocked, nil
}
// recordEvents writes one PII event per matched span to the event
// store. It does nothing when no store is configured. A failed write is
// logged at debug level and does not stop the remaining spans.
func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
	store := d.store
	if store == nil {
		return
	}
	// Events are written with a background context rather than a
	// request-scoped one.
	ctx := context.Background()
	for _, match := range spans {
		eventID := fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1))
		record := pii.PIIEvent{
			ID:            eventID,
			Kind:          pii.KindPII,
			CorrelationID: correlationID,
			Direction:     pii.DirectionIn,
			PatternID:     match.Pattern,
			ByteOffset:    match.Start,
			Length:        match.End - match.Start,
			HashPrefix:    match.HashPrefix,
			Action:        d.patternAction[match.Pattern],
			CreatedAt:     time.Now(),
		}
		err := store.Record(ctx, record)
		if err == nil {
			continue
		}
		xlog.Debug("mitm: failed to record pii event", "error", err, "pattern", match.Pattern)
	}
}
// streamWithPII relays an SSE response from src to w one event at a
// time, passing each data payload through a streaming PII filter.
//
// Events are flushed as they are written when w supports flushing.
// When a terminal marker is seen, text still held by the filter is
// emitted as a synthesised event ahead of the marker; the filter is
// drained once more after the stream ends. Write errors are ignored.
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
	flusher, canFlush := w.(http.Flusher)
	filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")

	provider := ssewire.OpenAI
	if shape == shapeAnthropicMessages {
		provider = ssewire.Anthropic
	}

	// send writes one chunk downstream and pushes it out immediately.
	send := func(chunk string) {
		_, _ = w.Write([]byte(chunk))
		if canFlush {
			flusher.Flush()
		}
	}
	// drain emits whatever text the filter is still holding back.
	drain := func() {
		if tail := filter.Drain(); tail != "" {
			send(ssewire.SynthResidualEvent(provider, tail))
		}
	}

	for sc := ssewire.NewScanner(src); sc.Scan(); {
		event := sc.Event()
		switch {
		case ssewire.IsTerminalMarker(event.DataLine, provider):
			drain()
			send(event.Raw)
		case event.DataLine == "":
			// No data payload: relay the event untouched.
			send(event.Raw)
		default:
			rewritten, drop := ssewire.RewritePayload(event.DataLine, provider, filter)
			if drop {
				continue
			}
			if rewritten == event.DataLine {
				send(event.Raw)
				continue
			}
			// Swap only the first occurrence of the payload inside the raw event.
			send(strings.Replace(event.Raw, event.DataLine, rewritten, 1))
		}
	}
	drain()
}
// writePIIBlocked answers the caller with a 400 JSON error explaining
// that the proxy blocked the request, echoing correlationID so the
// caller can match it to the audit trail. Encoding errors are ignored.
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
	// Field order mirrors the alphabetical key order of the JSON object.
	type blockedDetail struct {
		Message string `json:"message"`
		Type    string `json:"type"`
	}
	type blockedBody struct {
		CorrelationID string        `json:"correlation_id"`
		Error         blockedDetail `json:"error"`
	}

	payload := blockedBody{
		CorrelationID: correlationID,
		Error: blockedDetail{
			Message: "request blocked by LocalAI MITM proxy (sensitive data detected)",
			Type:    "pii_blocked",
		},
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	_ = json.NewEncoder(w).Encode(payload)
}
// isSSE reports whether contentType is a Server-Sent Events media type.
//
// Only the media type itself is compared; parameters after ';' (such as
// charset) are ignored. The comparison is case-insensitive because
// media types are, so a spelling like "Text/Event-Stream" is still
// routed through the SSE handling.
func isSSE(contentType string) bool {
	mediaType, _, _ := strings.Cut(contentType, ";")
	return strings.EqualFold(strings.TrimSpace(mediaType), "text/event-stream")
}
// hopByHopHeaders are not forwarded by the proxy (RFC 7230 §6.1).
// Keys are in canonical MIME header form.
var hopByHopHeaders = map[string]struct{}{
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	// The field defined by the HTTP specifications is "Trailer";
	// "Trailers" is a misspelling from the RFC 2616 hop-by-hop list
	// and is kept only so existing behaviour is a strict subset.
	"Trailer":           {},
	"Trailers":          {},
	"Transfer-Encoding": {},
	"Upgrade":           {},
}
// isHopByHop reports whether name, after conversion to canonical header
// form, is listed in hopByHopHeaders.
func isHopByHop(name string) bool {
	canonical := http.CanonicalHeaderKey(name)
	if _, listed := hopByHopHeaders[canonical]; listed {
		return true
	}
	return false
}
// countingResponseWriter wraps an http.ResponseWriter to track the
// total bytes written downstream and the status code. It implements
// http.Flusher because the SSE paths flush per event; without that
// the assertion `w.(http.Flusher)` would silently degrade to no-op.
type countingResponseWriter struct {
	http.ResponseWriter
	// bytes is the running total of body bytes accepted by the
	// wrapped writer.
	bytes int64
	// status is the response status that was recorded; 0 until a
	// header has been written.
	status int
}

// Compile-time check that the wrapper keeps satisfying http.Flusher.
var _ http.Flusher = (*countingResponseWriter)(nil)

// hasFinalStatus reports whether a non-informational status has
// already been recorded. A 1xx status other than 101 may still be
// followed by the real response status.
func (w *countingResponseWriter) hasFinalStatus() bool {
	if w.status == 0 {
		return false
	}
	informational := w.status >= 100 && w.status < 200 && w.status != http.StatusSwitchingProtocols
	return !informational
}

// Write forwards p to the wrapped writer and adds the number of bytes
// it accepted to the running total. Writing a body before any final
// status was set records the implicit 200.
func (w *countingResponseWriter) Write(p []byte) (int, error) {
	if !w.hasFinalStatus() {
		w.status = http.StatusOK
	}
	n, err := w.ResponseWriter.Write(p)
	w.bytes += int64(n)
	return n, err
}

// WriteHeader forwards code to the wrapped writer and records it,
// unless a final status was already recorded: net/http ignores such
// superfluous calls, so overwriting the recorded value would make the
// audit trail report a status the client never received.
func (w *countingResponseWriter) WriteHeader(code int) {
	if !w.hasFinalStatus() {
		w.status = code
	}
	w.ResponseWriter.WriteHeader(code)
}

// Flush flushes the wrapped writer when it supports flushing and is a
// no-op otherwise.
func (w *countingResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}
// recordTrafficEvent writes a single proxy-traffic audit event covering
// one proxied exchange: destination host, request body size (sent),
// bytes written back to the caller (received), response status, and
// the time elapsed since start. It does nothing when no store is
// configured; a failed write is only logged at debug level.
func (d *piiDispatcher) recordTrafficEvent(host, correlationID string, sent, received int64, status int, start time.Time) {
	store := d.store
	if store == nil {
		return
	}

	eventID := fmt.Sprintf("proxy_traffic_%s_%d", correlationID, d.eventSeq.Add(1))
	elapsedMS := time.Since(start).Milliseconds()
	record := pii.PIIEvent{
		ID:            eventID,
		Kind:          pii.KindProxyTraffic,
		CorrelationID: correlationID,
		Host:          host,
		BytesSent:     sent,
		BytesReceived: received,
		StatusCode:    status,
		DurationMS:    elapsedMS,
		CreatedAt:     time.Now(),
	}

	err := store.Record(context.Background(), record)
	if err == nil {
		return
	}
	xlog.Debug("mitm: failed to record proxy_traffic event", "error", err, "host", host)
}
// cloneHopByHopFiltered returns a deep copy of in without the headers a
// proxy must not forward: the fixed hop-by-hop set, plus every header
// named as a connection option in the incoming Connection field
// (RFC 7230 §6.1 requires both to be removed before forwarding).
// The value slices are copied, so the result shares no storage with in.
func cloneHopByHopFiltered(in http.Header) http.Header {
	// Collect the header names the sender marked as connection-specific,
	// e.g. "Connection: close, X-Internal-Token".
	var nominated map[string]struct{}
	for _, line := range in.Values("Connection") {
		for _, option := range strings.Split(line, ",") {
			option = strings.TrimSpace(option)
			if option == "" {
				continue
			}
			if nominated == nil {
				nominated = make(map[string]struct{})
			}
			nominated[http.CanonicalHeaderKey(option)] = struct{}{}
		}
	}

	out := make(http.Header, len(in))
	for k, vs := range in {
		if isHopByHop(k) {
			continue
		}
		if _, drop := nominated[http.CanonicalHeaderKey(k)]; drop {
			continue
		}
		copied := make([]string, len(vs))
		copy(copied, vs)
		out[k] = copied
	}
	return out
}