package mitm

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
	"github.com/mudler/LocalAI/core/services/routing/pii"
	"github.com/mudler/LocalAI/core/services/routing/piiadapter"
	"github.com/mudler/xlog"
	"golang.org/x/net/http2"
)

// PIIHandlerOptions configures NewPIIHandler.
type PIIHandlerOptions struct {
	// Redactor is the regex PII redactor. nil disables redaction.
	Redactor *pii.Redactor
	// EventStore receives PIIEvent rows. nil discards events.
	EventStore pii.EventStore
	// UpstreamTLS overrides the tls.Config used when dialing the
	// real upstream. Defaults to a system-trust HTTPS client. The
	// config is cloned before use, so the caller's value is never
	// modified; if its NextProtos is empty, h2 and http/1.1 are offered.
	UpstreamTLS *tls.Config
	// CorrelationIDHeader names the request header carrying a
	// caller-supplied correlation ID. Defaults to "X-Correlation-ID".
	CorrelationIDHeader string
	// DialHost optionally remaps the host used for the outbound
	// upstream URL. Identity by default; tests inject a httptest
	// listener address.
	DialHost func(host string) string
	// HostsWithPIIDisabled lists destination hosts whose request
	// bodies should NOT run through the redactor. TLS termination,
	// upstream forwarding, and audit events still happen — only the
	// regex pass is bypassed. Useful for telemetry/probe endpoints
	// whose bodies aren't PII-shaped.
	HostsWithPIIDisabled []string
}

// NewPIIHandler builds an InterceptHandler that forwards each intercepted
// request to its real upstream over HTTPS. Recognised OpenAI chat and
// Anthropic messages request bodies are run through the redactor (and
// may be blocked), and their SSE responses pass through the PII stream
// filter. When an EventStore is configured, every request records a
// proxy_traffic event.
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
	var tlsCfg *tls.Config
	if opts.UpstreamTLS != nil {
		// Clone: both the NextProtos default below and
		// http2.ConfigureTransport write to the config, and the caller's
		// value may be shared with other clients.
		tlsCfg = opts.UpstreamTLS.Clone()
	} else {
		tlsCfg = &tls.Config{}
	}
	if len(tlsCfg.NextProtos) == 0 {
		tlsCfg.NextProtos = []string{"h2", "http/1.1"}
	}

	transport := &http.Transport{
		TLSClientConfig:   tlsCfg,
		ForceAttemptHTTP2: true,
	}
	if err := http2.ConfigureTransport(transport); err != nil {
		// Best effort: the transport still works without explicit h2 setup.
		xlog.Debug("mitm: http2.ConfigureTransport failed", "error", err)
	}

	corrHeader := opts.CorrelationIDHeader
	if corrHeader == "" {
		corrHeader = "X-Correlation-ID"
	}

	dialHost := opts.DialHost
	if dialHost == nil {
		dialHost = func(h string) string { return h }
	}

	// Pattern ID -> configured action, copied onto each recorded event.
	patternAction := map[string]pii.Action{}
	if opts.Redactor != nil {
		for _, p := range opts.Redactor.Patterns() {
			patternAction[p.ID] = p.Action
		}
	}

	// Normalised (trimmed, lower-cased) so serve can look up strings.ToLower(host).
	piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
	for _, h := range opts.HostsWithPIIDisabled {
		piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
	}

	d := &piiDispatcher{
		client:        &http.Client{Transport: transport},
		redactor:      opts.Redactor,
		store:         opts.EventStore,
		patternAction: patternAction,
		corrHeader:    corrHeader,
		dialHost:      dialHost,
		piiDisabled:   piiDisabled,
	}
	return d.serve
}

// piiDispatcher holds the shared state behind the handler returned by
// NewPIIHandler.
type piiDispatcher struct {
	client        *http.Client
	redactor      *pii.Redactor // nil disables redaction and stream filtering
	store         pii.EventStore
	patternAction map[string]pii.Action
	corrHeader    string
	dialHost      func(host string) string
	piiDisabled   map[string]bool // lower-cased hosts that skip request redaction
	eventSeq      atomic.Uint64   // monotonic suffix for event IDs
}

// serve handles one intercepted request bound for host. It buffers the
// request body and optionally redacts or blocks it. It then forwards the
// body upstream and relays the response back: SSE is flushed as it
// arrives, and recognised shapes go through the PII stream filter. A
// proxy_traffic event is recorded on every exit path.
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
	start := time.Now()
	cw := &countingResponseWriter{ResponseWriter: w}
	w = cw

	var (
		correlationID string
		bytesSent     int64
	)
	// A closure, so the values are read when serve returns rather than now.
	defer func() {
		d.recordTrafficEvent(host, correlationID, bytesSent, cw.bytes, cw.status, start)
	}()

	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "mitm: read body: "+err.Error(), http.StatusBadGateway)
		return
	}
	_ = r.Body.Close()

	correlationID = r.Header.Get(d.corrHeader)
	if correlationID == "" {
		correlationID = r.Header.Get("x-request-id")
	}

	shape := classifyRequestShape(host, r.URL.Path)
	if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
		redacted, blocked, err := d.redactRequest(body, shape, correlationID)
		switch {
		case err != nil:
			// Fail open: an unparseable body is forwarded as-is.
			xlog.Debug("mitm: redact request failed; forwarding unchanged",
				"host", host, "path", r.URL.Path, "error", err)
		case blocked:
			writePIIBlocked(w, correlationID)
			return
		default:
			body = redacted
		}
	}

	upstreamURL := "https://" + d.dialHost(host) + r.URL.RequestURI()
	upstreamReq, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, bytes.NewReader(body))
	if err != nil {
		http.Error(w, "mitm: build upstream request: "+err.Error(), http.StatusBadGateway)
		return
	}
	upstreamReq.Header = cloneHopByHopFiltered(r.Header)
	// The body may have been rewritten, so the client's length is stale.
	upstreamReq.ContentLength = int64(len(body))
	upstreamReq.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))

	filterResponse := shape != shapeUnknown && d.redactor != nil
	if filterResponse {
		// The stream filter parses the response as SSE text. Without the
		// client's Accept-Encoding, the transport negotiates gzip itself and
		// transparently decompresses. Otherwise a compressed body would reach
		// the filter unreadable.
		upstreamReq.Header.Del("Accept-Encoding")
	}
	bytesSent = int64(len(body))

	resp, err := d.client.Do(upstreamReq)
	if err != nil {
		http.Error(w, "mitm: upstream: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer func() { _ = resp.Body.Close() }()

	// Framing headers are dropped: the downstream connection does its own
	// framing, and a filtered body's length can differ from upstream's.
	for k, vs := range resp.Header {
		if isHopByHop(k) || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Content-Length") {
			continue
		}
		for _, v := range vs {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)

	contentType := resp.Header.Get("Content-Type")
	switch {
	case filterResponse && isSSE(contentType):
		d.streamWithPII(w, resp.Body, shape, correlationID)
	case isSSE(contentType):
		// Unfiltered SSE: write and flush after every read, so events reach
		// the client as soon as they arrive from upstream.
		flusher, _ := w.(http.Flusher)
		buf := make([]byte, 32*1024)
		for {
			n, rErr := resp.Body.Read(buf)
			if n > 0 {
				if _, wErr := w.Write(buf[:n]); wErr != nil {
					return
				}
				if flusher != nil {
					flusher.Flush()
				}
			}
			if rErr != nil {
				return
			}
		}
	default:
		_, _ = io.Copy(w, resp.Body)
	}
}

// requestShape identifies which provider request schema a body follows.
type requestShape int

const (
	// shapeUnknown bodies are forwarded without PII processing.
	shapeUnknown requestShape = iota
	// shapeOpenAIChat is an api.openai.com chat completions request.
	shapeOpenAIChat
	// shapeAnthropicMessages is an api.anthropic.com messages request.
	shapeAnthropicMessages
)

// classifyRequestShape maps a destination host (case-insensitive) and URL
// path to the request schema the redactor should parse. The path is
// matched by suffix.
func classifyRequestShape(host, path string) requestShape {
	host = strings.ToLower(host)
	switch {
	case host == "api.openai.com" && strings.HasSuffix(path, "/v1/chat/completions"):
		return shapeOpenAIChat
	case host == "api.anthropic.com" && strings.HasSuffix(path, "/v1/messages"):
		return shapeAnthropicMessages
	}
	return shapeUnknown
}

// redactRequest parses body as the given shape and runs every text field
// reported by the matching adapter through the redactor. It returns the
// body to forward, whether any match demands blocking, and a parse or
// marshal error. When nothing matched, the original bytes are returned
// untouched. A PII event is recorded for every span found.
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
	var parsed any
	var adapter pii.Adapter
	switch shape {
	case shapeOpenAIChat:
		req := &schema.OpenAIRequest{}
		if err := json.Unmarshal(body, req); err != nil {
			return nil, false, fmt.Errorf("parse openai: %w", err)
		}
		parsed = req
		adapter = piiadapter.OpenAI()
	case shapeAnthropicMessages:
		req := &schema.AnthropicRequest{}
		if err := json.Unmarshal(body, req); err != nil {
			return nil, false, fmt.Errorf("parse anthropic: %w", err)
		}
		parsed = req
		adapter = piiadapter.Anthropic()
	default:
		return body, false, nil
	}

	texts := adapter.Scan(parsed)
	if len(texts) == 0 {
		return body, false, nil
	}

	updates := make([]pii.ScannedText, 0, len(texts))
	blocked := false
	for _, st := range texts {
		if st.Text == "" {
			continue
		}
		res := d.redactor.RedactWithOverrides(st.Text, nil)
		if len(res.Spans) == 0 {
			continue
		}
		d.recordEvents(res.Spans, correlationID)
		if res.Blocked {
			blocked = true
		}
		updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
	}

	if len(updates) == 0 {
		// Nothing matched, so forward the caller's exact bytes. A round trip
		// through the schema structs may drop or reshape fields they do not
		// model.
		return body, false, nil
	}
	adapter.Apply(parsed, updates)
	out, err := json.Marshal(parsed)
	if err != nil {
		return nil, false, fmt.Errorf("re-marshal: %w", err)
	}
	return out, blocked, nil
}

// recordEvents stores one KindPII event per span found in a request body.
// It is a no-op without an EventStore. Store failures are logged, not
// returned.
func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
	if d.store == nil {
		return
	}
	for _, span := range spans {
		ev := pii.PIIEvent{
			ID:            fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)),
			Kind:          pii.KindPII,
			CorrelationID: correlationID,
			Direction:     pii.DirectionIn,
			PatternID:     span.Pattern,
			ByteOffset:    span.Start,
			Length:        span.End - span.Start,
			HashPrefix:    span.HashPrefix,
			Action:        d.patternAction[span.Pattern],
			CreatedAt:     time.Now(),
		}
		if err := d.store.Record(context.Background(), ev); err != nil {
			xlog.Debug("mitm: failed to record pii event", "error", err, "pattern", span.Pattern)
		}
	}
}

// streamWithPII relays an SSE response event by event, rewriting data
// payloads through a PII stream filter. Text still held by the filter is
// emitted as a synthesized event before the provider's terminal marker,
// and again at end of stream. Every emitted event is flushed.
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
	flusher, _ := w.(http.Flusher)
	filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")
	provider := ssewire.OpenAI
	if shape == shapeAnthropicMessages {
		provider = ssewire.Anthropic
	}
	emit := func(s string) {
		_, _ = w.Write([]byte(s))
		if flusher != nil {
			flusher.Flush()
		}
	}

	scanner := ssewire.NewScanner(src)
	for scanner.Scan() {
		ev := scanner.Event()
		if ssewire.IsTerminalMarker(ev.DataLine, provider) {
			// Flush held-back text before the client sees end-of-stream.
			if residual := filter.Drain(); residual != "" {
				emit(ssewire.SynthResidualEvent(provider, residual))
			}
			emit(ev.Raw)
			continue
		}
		out := ev.Raw
		if ev.DataLine != "" {
			rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter)
			if drop {
				continue
			}
			if rewritten != ev.DataLine {
				out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
			}
		}
		emit(out)
	}
	if residual := filter.Drain(); residual != "" {
		emit(ssewire.SynthResidualEvent(provider, residual))
	}
}

// writePIIBlocked answers a blocked request with a 400 JSON error carrying
// the correlation ID.
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	resp := map[string]any{
		"error": map[string]string{
			"message": "request blocked by LocalAI MITM proxy (sensitive data detected)",
			"type":    "pii_blocked",
		},
		"correlation_id": correlationID,
	}
	_ = json.NewEncoder(w).Encode(resp)
}

// isSSE reports whether a Content-Type value denotes a server-sent event
// stream.
func isSSE(contentType string) bool {
	return strings.HasPrefix(strings.TrimSpace(contentType), "text/event-stream")
}

// hopByHopHeaders are not forwarded by the proxy (RFC 7230 §6.1).
var hopByHopHeaders = map[string]struct{}{ "Connection": {}, "Keep-Alive": {}, "Proxy-Authenticate": {}, "Proxy-Authorization": {}, "Te": {}, "Trailers": {}, "Transfer-Encoding": {}, "Upgrade": {}, } func isHopByHop(name string) bool { _, ok := hopByHopHeaders[http.CanonicalHeaderKey(name)] return ok } // countingResponseWriter wraps an http.ResponseWriter to track the // total bytes written downstream and the status code. It implements // http.Flusher because the SSE paths flush per event; without that // the assertion `w.(http.Flusher)` would silently degrade to no-op. type countingResponseWriter struct { http.ResponseWriter bytes int64 status int } func (w *countingResponseWriter) Write(p []byte) (int, error) { if w.status == 0 { w.status = http.StatusOK } n, err := w.ResponseWriter.Write(p) w.bytes += int64(n) return n, err } func (w *countingResponseWriter) WriteHeader(code int) { w.status = code w.ResponseWriter.WriteHeader(code) } func (w *countingResponseWriter) Flush() { if f, ok := w.ResponseWriter.(http.Flusher); ok { f.Flush() } } func (d *piiDispatcher) recordTrafficEvent(host, correlationID string, sent, received int64, status int, start time.Time) { if d.store == nil { return } ev := pii.PIIEvent{ ID: fmt.Sprintf("proxy_traffic_%s_%d", correlationID, d.eventSeq.Add(1)), Kind: pii.KindProxyTraffic, CorrelationID: correlationID, Host: host, BytesSent: sent, BytesReceived: received, StatusCode: status, DurationMS: time.Since(start).Milliseconds(), CreatedAt: time.Now(), } if err := d.store.Record(context.Background(), ev); err != nil { xlog.Debug("mitm: failed to record proxy_traffic event", "error", err, "host", host) } } func cloneHopByHopFiltered(in http.Header) http.Header { out := make(http.Header, len(in)) for k, vs := range in { if isHopByHop(k) { continue } copied := make([]string, len(vs)) copy(copied, vs) out[k] = copied } return out }