Files
LocalAI/core/services/cloudproxy/mitm/proxy.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

307 lines
7.8 KiB
Go

package mitm
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/xlog"
"golang.org/x/net/http2"
)
// Server is an HTTPS forward proxy that MITMs traffic for hosts
// in its intercept allowlist; non-allowlisted hosts get a plain
// TCP CONNECT tunnel.
//
// Build one with NewServer, call Start to begin accepting CONNECT
// requests, and Stop to shut the listener down.
type Server struct {
// addr is the TCP address Start passes to net.Listen.
addr string
// ca issues the per-host leaf certificate presented to the client
// when a CONNECT is intercepted.
ca *CA
// interceptHosts is the allowlist, keyed by trimmed, lower-cased host.
interceptHosts map[string]bool
// handler receives every decrypted request on an intercepted connection.
handler InterceptHandler
// connectTimeout bounds the client-facing TLS handshake only;
// NewServer sets it to 30 seconds.
connectTimeout time.Duration
// dialTimeout bounds the upstream TCP dial made for a plain tunnel;
// NewServer sets it to 15 seconds.
dialTimeout time.Duration
// upstreamTLS is initialised by NewServer with ALPN restricted to
// "http/1.1". NOTE(review): it is not read anywhere in this section of
// the file — presumably consumed by the upstream forwarding code.
upstreamTLS *tls.Config
// events optionally receives one proxy_connect record per CONNECT;
// nil disables recording.
events pii.EventStore
// eventSeq supplies the numeric suffix of proxy_connect event IDs.
eventSeq atomic.Uint64
// listener is the bound socket; nil until Start succeeds.
listener net.Listener
// srv is the HTTP server that accepts CONNECT requests; nil until Start.
srv *http.Server
// wg tracks the goroutine running srv.Serve so Stop can wait for it.
wg sync.WaitGroup
// stopOnce makes Stop idempotent.
stopOnce sync.Once
// stopped is closed by Stop. NOTE(review): nothing in this section
// receives from it.
stopped chan struct{}
}
// InterceptHandler runs after the proxy terminates TLS for an
// allowlisted host. It is responsible for forwarding the upstream
// response bytes back to w.
//
// r is the decrypted client request; before the handler is invoked its
// URL.Scheme is set to "https" and an empty URL.Host is filled from
// r.Host. upstreamHost is the lower-cased host taken from the CONNECT
// request, with any port removed.
type InterceptHandler func(w http.ResponseWriter, r *http.Request, upstreamHost string)
// Config carries the settings NewServer uses to build a Server.
type Config struct {
// Addr is the TCP listen address handed to net.Listen by Start.
Addr string
// CA issues leaf certificates for intercepted hosts. Required.
CA *CA
// InterceptHosts lists the hosts whose TLS is terminated by the proxy.
// Entries are trimmed and lower-cased; an empty list tunnels everything.
InterceptHosts []string
// Handler serves the decrypted requests of intercepted hosts. Required.
Handler InterceptHandler
// EventStore optionally receives a proxy_connect event for every
// CONNECT, recording the destination host and whether the proxy
// intercepted or tunneled it. nil disables connect-event recording.
EventStore pii.EventStore
}
// NewServer validates cfg and returns a Server that has not started
// listening yet. It fails when cfg.CA or cfg.Handler is nil.
//
// Allowlist entries are trimmed and lower-cased so they compare equal to
// the lower-cased CONNECT host. The handshake timeout is fixed at 30s and
// the tunnel dial timeout at 15s.
func NewServer(cfg Config) (*Server, error) {
	switch {
	case cfg.CA == nil:
		return nil, errors.New("mitm: NewServer: CA is required")
	case cfg.Handler == nil:
		return nil, errors.New("mitm: NewServer: Handler is required")
	}

	allow := make(map[string]bool, len(cfg.InterceptHosts))
	for _, entry := range cfg.InterceptHosts {
		key := strings.TrimSpace(entry)
		key = strings.ToLower(key)
		allow[key] = true
	}

	server := &Server{
		addr:           cfg.Addr,
		ca:             cfg.CA,
		interceptHosts: allow,
		handler:        cfg.Handler,
		connectTimeout: 30 * time.Second,
		dialTimeout:    15 * time.Second,
		upstreamTLS:    &tls.Config{NextProtos: []string{"http/1.1"}},
		events:         cfg.EventStore,
		stopped:        make(chan struct{}),
	}
	return server, nil
}
// Start binds the configured address and serves CONNECT requests on a
// background goroutine. It returns as soon as the listener is bound; a
// listen failure is returned wrapped, while later serve errors (other than
// the normal shutdown error) are only logged.
func (s *Server) Start() error {
	bound, listenErr := net.Listen("tcp", s.addr)
	if listenErr != nil {
		return fmt.Errorf("mitm: listen %q: %w", s.addr, listenErr)
	}

	s.listener = bound
	s.srv = &http.Server{
		Handler:           http.HandlerFunc(s.handle),
		ReadHeaderTimeout: 30 * time.Second,
	}

	// Stop waits on wg, so the goroutine must be registered before launch.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		serveErr := s.srv.Serve(bound)
		if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
			return
		}
		xlog.Error("mitm: serve error", "error", serveErr)
	}()

	hostCount := len(s.interceptHosts)
	xlog.Info("mitm: listening", "addr", bound.Addr().String(), "intercept_hosts", hostCount)
	return nil
}
// Addr reports the address the proxy is actually bound to. Before Start
// has created a listener it falls back to the configured address. This
// matters when the configured port is ":0": the kernel chooses the port
// and callers such as tests need to learn which one it picked.
func (s *Server) Addr() string {
	if ln := s.listener; ln != nil {
		return ln.Addr().String()
	}
	return s.addr
}
// Stop shuts the proxy down and is safe to call more than once; only the
// first call does any work. It closes the stopped channel, closes the HTTP
// server if Start created one, and then blocks until the serve goroutine
// has exited.
func (s *Server) Stop() {
	shutdown := func() {
		close(s.stopped)
		if httpSrv := s.srv; httpSrv != nil {
			_ = httpSrv.Close()
		}
		s.wg.Wait()
	}
	s.stopOnce.Do(shutdown)
}
// handle is the entry point for every request hitting the proxy port.
// Anything other than CONNECT is rejected with 405. For CONNECT, the
// target host (port stripped when present, lower-cased) is checked against
// the allowlist, a connect event is recorded, and the request is routed to
// either the intercepting path or the plain tunnel.
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodConnect {
		http.Error(w, "this proxy only supports HTTPS via CONNECT", http.StatusMethodNotAllowed)
		return
	}

	// Fall back to the raw Host value when it carries no port.
	target := r.Host
	if name, _, splitErr := net.SplitHostPort(r.Host); splitErr == nil {
		target = name
	}
	target = strings.ToLower(target)

	mitm := s.shouldIntercept(target)
	s.recordConnectEvent(target, mitm)

	if mitm {
		s.handleIntercept(w, r, target)
		return
	}
	s.handleTunnel(w, r)
}
// recordConnectEvent stores one proxy_connect audit record describing the
// CONNECT target and whether it was intercepted. It does nothing when no
// event store is configured. Recording is best-effort: a store failure is
// logged at debug level and never propagated, so a broken recorder cannot
// fail a CONNECT.
func (s *Server) recordConnectEvent(host string, intercepted bool) {
	store := s.events
	if store == nil {
		return
	}

	seq := s.eventSeq.Add(1)
	event := pii.PIIEvent{
		ID:   fmt.Sprintf("proxy_connect_%d", seq),
		Kind: pii.KindProxyConnect,
		Host: host,
		// The parameter is already a private copy, so its address is safe to keep.
		Intercepted: &intercepted,
		CreatedAt:   time.Now(),
	}

	recordErr := store.Record(context.Background(), event)
	if recordErr == nil {
		return
	}
	xlog.Debug("mitm: failed to record proxy_connect event", "error", recordErr, "host", host)
}
// shouldIntercept reports whether host is on the intercept allowlist.
// With an empty allowlist every lookup misses, so everything is tunneled.
func (s *Server) shouldIntercept(host string) bool {
	allowed, present := s.interceptHosts[host]
	return present && allowed
}
// handleTunnel serves a CONNECT for a non-intercepted host by relaying raw
// bytes between the client and the upstream without touching TLS.
//
// The upstream is dialed first, so a dial failure can still be reported to
// the client as a 502. Only then is the client connection hijacked, the
// "200 Connection established" line written, and both directions piped.
// Both connections are closed when the function returns.
func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) {
	target := normalizeHostPort(r.Host)
	remote, dialErr := net.DialTimeout("tcp", target, s.dialTimeout)
	if dialErr != nil {
		http.Error(w, "mitm: tunnel dial: "+dialErr.Error(), http.StatusBadGateway)
		return
	}
	defer func() { _ = remote.Close() }()

	hj, canHijack := w.(http.Hijacker)
	if !canHijack {
		http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError)
		return
	}

	client, _, hijackErr := hj.Hijack()
	if hijackErr != nil {
		http.Error(w, "mitm: hijack failed: "+hijackErr.Error(), http.StatusInternalServerError)
		return
	}
	defer func() { _ = client.Close() }()

	const established = "HTTP/1.1 200 Connection established\r\n\r\n"
	if _, writeErr := client.Write([]byte(established)); writeErr != nil {
		return
	}

	pipe(client, remote)
}
// pipe relays bytes between a and b in both directions and returns as soon
// as either direction finishes. When one direction ends, its destination
// connection gets an already-expired deadline, which unblocks the opposite
// copy reading from that connection. pipe does not close either
// connection; that is left to the caller.
func pipe(a, b net.Conn) {
	// Buffered for both senders so the slower relay never blocks on send
	// after pipe has already returned.
	finished := make(chan struct{}, 2)

	relay := func(dst, src net.Conn) {
		_, _ = io.Copy(dst, src)
		_ = dst.SetDeadline(time.Now())
		finished <- struct{}{}
	}

	go relay(a, b)
	go relay(b, a)

	<-finished
}
// handleIntercept serves a CONNECT for an allowlisted host by terminating
// the client's TLS with a leaf certificate issued for host and passing the
// decrypted requests to the configured InterceptHandler.
//
// Failures before the hijack are reported to the client as HTTP 500;
// failures after it simply drop the connection. Both "h2" and "http/1.1"
// are offered via ALPN: an h2 connection is served by an http2.Server,
// anything else by serveHTTP1.
func (s *Server) handleIntercept(w http.ResponseWriter, r *http.Request, host string) {
	cert, issueErr := s.ca.IssueLeaf(host)
	if issueErr != nil {
		http.Error(w, "mitm: leaf issuance failed: "+issueErr.Error(), http.StatusInternalServerError)
		return
	}

	hj, canHijack := w.(http.Hijacker)
	if !canHijack {
		http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError)
		return
	}

	rawConn, _, hijackErr := hj.Hijack()
	if hijackErr != nil {
		http.Error(w, "mitm: hijack failed: "+hijackErr.Error(), http.StatusInternalServerError)
		return
	}
	defer func() { _ = rawConn.Close() }()

	const established = "HTTP/1.1 200 Connection established\r\n\r\n"
	if _, writeErr := rawConn.Write([]byte(established)); writeErr != nil {
		return
	}

	serverCfg := &tls.Config{
		Certificates: []tls.Certificate{*cert},
		NextProtos:   []string{"h2", "http/1.1"},
	}
	secure := tls.Server(rawConn, serverCfg)
	defer func() { _ = secure.Close() }()

	// The deadline covers the handshake only and is lifted afterwards so
	// long-lived streams are not cut off. If the deadline cannot be set,
	// give up rather than handshake with no time limit.
	handshakeBy := time.Now().Add(s.connectTimeout)
	if deadlineErr := secure.SetDeadline(handshakeBy); deadlineErr != nil {
		xlog.Debug("mitm: TLS handshake set-deadline failed", "host", host, "error", deadlineErr)
		return
	}
	if handshakeErr := secure.Handshake(); handshakeErr != nil {
		xlog.Debug("mitm: TLS handshake failed", "host", host, "error", handshakeErr)
		return
	}
	_ = secure.SetDeadline(time.Time{})

	// Requests read off the decrypted stream lack scheme and URL host;
	// fill them in so the handler sees a complete https URL.
	inner := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		req.URL.Scheme = "https"
		if req.URL.Host == "" {
			req.URL.Host = req.Host
		}
		s.handler(rw, req, host)
	})

	if secure.ConnectionState().NegotiatedProtocol == "h2" {
		h2 := &http2.Server{}
		h2.ServeConn(secure, &http2.ServeConnOpts{
			Handler: inner,
			Context: r.Context(),
		})
		return
	}
	s.serveHTTP1(secure, inner, host)
}
// serveHTTP1 runs a keep-alive HTTP/1.x request loop over an established
// TLS connection. Each request is read, handed to handler through a
// connection-backed response writer, and finished before the next one is
// read. The loop ends on a read error (EOF is silent, anything else is
// logged at debug level), or when either the request or the response
// writer asks for the connection to be closed.
func (s *Server) serveHTTP1(tlsConn *tls.Conn, handler http.Handler, host string) {
	reader := bufio.NewReader(tlsConn)
	for {
		req, readErr := http.ReadRequest(reader)
		switch {
		case readErr == nil:
			// Got a request; serve it below.
		case errors.Is(readErr, io.EOF):
			return
		default:
			xlog.Debug("mitm: read request", "host", host, "error", readErr)
			return
		}

		resp := newConnResponseWriter(tlsConn, req)
		handler.ServeHTTP(resp, req)
		resp.finish()

		keepAlive := !req.Close && !resp.closeAfter
		if !keepAlive {
			return
		}
	}
}
// normalizeHostPort returns host in "host:port" form suitable for dialing,
// defaulting the port to 443 when none is given.
//
// A value that already carries a non-empty port is returned unchanged. An
// empty port ("example.com:") is replaced by the default. A value with no
// port is joined with the default via net.JoinHostPort, so an IPv6 literal
// is bracketed correctly whether it arrived as "::1" or "[::1]"; plain
// string concatenation would turn "::1" into the undialable "::1:443".
func normalizeHostPort(host string) string {
	const defaultPort = "443"

	if name, port, err := net.SplitHostPort(host); err == nil {
		if port != "" {
			return host
		}
		return net.JoinHostPort(name, defaultPort)
	}

	// No port present. Strip one surrounding bracket pair so JoinHostPort
	// can re-add brackets itself where the host requires them.
	bare := host
	if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
		bare = bare[1 : len(bare)-1]
	}
	return net.JoinHostPort(bare, defaultPort)
}