mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-29 11:07:18 -04:00
Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe <io@richiejp.com>
307 lines
7.8 KiB
Go
307 lines
7.8 KiB
Go
package mitm
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/services/routing/pii"
|
|
"github.com/mudler/xlog"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// Server is an HTTPS forward proxy that MITMs traffic for hosts
// in its intercept allowlist; non-allowlisted hosts get a plain
// TCP CONNECT tunnel.
type Server struct {
	// addr is the TCP listen address handed to net.Listen by Start.
	addr string
	// ca issues the per-host leaf certificate presented to the client
	// on an intercepted connection.
	ca *CA
	// interceptHosts is the allowlist, keyed by trimmed, lower-cased host.
	interceptHosts map[string]bool
	// handler serves every request read from an intercepted connection.
	handler InterceptHandler
	// connectTimeout bounds the client-facing TLS handshake on
	// intercepted connections.
	connectTimeout time.Duration
	// dialTimeout bounds the upstream TCP dial for tunneled connections.
	dialTimeout time.Duration
	// upstreamTLS is initialized by NewServer to advertise only http/1.1.
	// NOTE(review): not referenced in the part of the file visible here —
	// presumably used when connecting upstream; confirm.
	upstreamTLS *tls.Config
	// events optionally receives one proxy_connect event per CONNECT;
	// nil disables recording.
	events pii.EventStore
	// eventSeq supplies the numeric suffix of proxy_connect event IDs.
	eventSeq atomic.Uint64

	// listener is the bound socket; set by Start.
	listener net.Listener
	// srv is the HTTP server accepting CONNECT requests; set by Start.
	srv *http.Server

	// wg tracks the goroutine running srv.Serve.
	wg sync.WaitGroup
	// stopOnce makes Stop idempotent.
	stopOnce sync.Once
	// stopped is closed by Stop.
	// NOTE(review): no receiver is visible in this part of the file.
	stopped chan struct{}
}
|
|
|
|
// InterceptHandler runs after the proxy terminates TLS for an
// allowlisted host. It is responsible for forwarding the upstream
// response bytes back to w.
//
// The proxy sets r.URL.Scheme to "https" and fills an empty r.URL.Host
// from r.Host before calling the handler. upstreamHost is the
// lower-cased host taken from the original CONNECT request, without
// the port.
type InterceptHandler func(w http.ResponseWriter, r *http.Request, upstreamHost string)
|
|
|
|
// Config carries the settings NewServer uses to build a Server.
type Config struct {
	// Addr is the TCP address the proxy listens on.
	Addr string
	// CA issues leaf certificates for intercepted hosts. Required.
	CA *CA
	// InterceptHosts lists the hosts whose TLS is terminated by the
	// proxy. NewServer trims and lower-cases each entry. When empty,
	// every CONNECT is tunneled.
	InterceptHosts []string
	// Handler serves requests on intercepted connections. Required.
	Handler InterceptHandler
	// EventStore optionally receives a proxy_connect event for every
	// CONNECT, recording the destination host and whether the proxy
	// intercepted or tunneled it. nil disables connect-event recording.
	EventStore pii.EventStore
}
|
|
|
|
func NewServer(cfg Config) (*Server, error) {
|
|
if cfg.CA == nil {
|
|
return nil, errors.New("mitm: NewServer: CA is required")
|
|
}
|
|
if cfg.Handler == nil {
|
|
return nil, errors.New("mitm: NewServer: Handler is required")
|
|
}
|
|
hosts := make(map[string]bool, len(cfg.InterceptHosts))
|
|
for _, h := range cfg.InterceptHosts {
|
|
hosts[strings.ToLower(strings.TrimSpace(h))] = true
|
|
}
|
|
return &Server{
|
|
addr: cfg.Addr,
|
|
ca: cfg.CA,
|
|
interceptHosts: hosts,
|
|
handler: cfg.Handler,
|
|
connectTimeout: 30 * time.Second,
|
|
dialTimeout: 15 * time.Second,
|
|
upstreamTLS: &tls.Config{NextProtos: []string{"http/1.1"}},
|
|
events: cfg.EventStore,
|
|
stopped: make(chan struct{}),
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
ln, err := net.Listen("tcp", s.addr)
|
|
if err != nil {
|
|
return fmt.Errorf("mitm: listen %q: %w", s.addr, err)
|
|
}
|
|
s.listener = ln
|
|
s.srv = &http.Server{
|
|
Handler: http.HandlerFunc(s.handle),
|
|
ReadHeaderTimeout: 30 * time.Second,
|
|
}
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
err := s.srv.Serve(ln)
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
xlog.Error("mitm: serve error", "error", err)
|
|
}
|
|
}()
|
|
xlog.Info("mitm: listening", "addr", ln.Addr().String(), "intercept_hosts", len(s.interceptHosts))
|
|
return nil
|
|
}
|
|
|
|
// Addr returns the bound listener address. Useful when Start was
|
|
// called with ":0" — the kernel picks a port and tests need to
|
|
// discover which.
|
|
func (s *Server) Addr() string {
|
|
if s.listener == nil {
|
|
return s.addr
|
|
}
|
|
return s.listener.Addr().String()
|
|
}
|
|
|
|
// Stop is idempotent.
|
|
func (s *Server) Stop() {
|
|
s.stopOnce.Do(func() {
|
|
close(s.stopped)
|
|
if s.srv != nil {
|
|
_ = s.srv.Close()
|
|
}
|
|
s.wg.Wait()
|
|
})
|
|
}
|
|
|
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodConnect {
|
|
http.Error(w, "this proxy only supports HTTPS via CONNECT", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
host, _, err := net.SplitHostPort(r.Host)
|
|
if err != nil {
|
|
host = r.Host
|
|
}
|
|
host = strings.ToLower(host)
|
|
|
|
intercept := s.shouldIntercept(host)
|
|
s.recordConnectEvent(host, intercept)
|
|
if !intercept {
|
|
s.handleTunnel(w, r)
|
|
return
|
|
}
|
|
s.handleIntercept(w, r, host)
|
|
}
|
|
|
|
// recordConnectEvent writes a proxy_connect audit row. Best-effort —
|
|
// store errors are logged at debug only so a failing recorder cannot
|
|
// break a CONNECT.
|
|
func (s *Server) recordConnectEvent(host string, intercepted bool) {
|
|
if s.events == nil {
|
|
return
|
|
}
|
|
flag := intercepted
|
|
ev := pii.PIIEvent{
|
|
ID: fmt.Sprintf("proxy_connect_%d", s.eventSeq.Add(1)),
|
|
Kind: pii.KindProxyConnect,
|
|
Host: host,
|
|
Intercepted: &flag,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
if err := s.events.Record(context.Background(), ev); err != nil {
|
|
xlog.Debug("mitm: failed to record proxy_connect event", "error", err, "host", host)
|
|
}
|
|
}
|
|
|
|
// shouldIntercept reports whether host is in the allowlist. An
|
|
// empty allowlist tunnels everything.
|
|
func (s *Server) shouldIntercept(host string) bool {
|
|
if len(s.interceptHosts) == 0 {
|
|
return false
|
|
}
|
|
return s.interceptHosts[host]
|
|
}
|
|
|
|
func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) {
|
|
upstream, err := net.DialTimeout("tcp", normalizeHostPort(r.Host), s.dialTimeout)
|
|
if err != nil {
|
|
http.Error(w, "mitm: tunnel dial: "+err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer func() { _ = upstream.Close() }()
|
|
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
http.Error(w, "mitm: hijack failed: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer func() { _ = clientConn.Close() }()
|
|
|
|
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
|
|
return
|
|
}
|
|
|
|
pipe(clientConn, upstream)
|
|
}
|
|
|
|
// pipe relays bytes between a and b in both directions and returns as
// soon as either direction finishes. When a direction ends it sets an
// immediate deadline on its destination connection, which unblocks the
// opposite direction's pending read on that connection. The caller is
// expected to close both connections afterwards.
func pipe(a, b net.Conn) {
	// Buffered for both senders so the relay that finishes second never
	// blocks on a channel nobody is receiving from.
	finished := make(chan struct{}, 2)

	relay := func(dst, src net.Conn) {
		_, _ = io.Copy(dst, src)
		_ = dst.SetDeadline(time.Now())
		finished <- struct{}{}
	}

	go relay(a, b)
	go relay(b, a)

	<-finished
}
|
|
|
|
func (s *Server) handleIntercept(w http.ResponseWriter, r *http.Request, host string) {
|
|
leaf, err := s.ca.IssueLeaf(host)
|
|
if err != nil {
|
|
http.Error(w, "mitm: leaf issuance failed: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
http.Error(w, "mitm: hijack unsupported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
http.Error(w, "mitm: hijack failed: "+err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer func() { _ = clientConn.Close() }()
|
|
|
|
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
|
|
return
|
|
}
|
|
|
|
tlsConn := tls.Server(clientConn, &tls.Config{
|
|
Certificates: []tls.Certificate{*leaf},
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
})
|
|
defer func() { _ = tlsConn.Close() }()
|
|
|
|
// Deadline applies to the handshake only; cleared before the
|
|
// request loop so long-running streams don't get cut off. Fail
|
|
// closed if SetDeadline errors — better than handshaking without
|
|
// a deadline.
|
|
if err := tlsConn.SetDeadline(time.Now().Add(s.connectTimeout)); err != nil {
|
|
xlog.Debug("mitm: TLS handshake set-deadline failed", "host", host, "error", err)
|
|
return
|
|
}
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
xlog.Debug("mitm: TLS handshake failed", "host", host, "error", err)
|
|
return
|
|
}
|
|
_ = tlsConn.SetDeadline(time.Time{})
|
|
|
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
req.URL.Scheme = "https"
|
|
if req.URL.Host == "" {
|
|
req.URL.Host = req.Host
|
|
}
|
|
s.handler(rw, req, host)
|
|
})
|
|
|
|
switch tlsConn.ConnectionState().NegotiatedProtocol {
|
|
case "h2":
|
|
h2srv := &http2.Server{}
|
|
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
|
|
Handler: handler,
|
|
Context: r.Context(),
|
|
})
|
|
default:
|
|
s.serveHTTP1(tlsConn, handler, host)
|
|
}
|
|
}
|
|
|
|
func (s *Server) serveHTTP1(tlsConn *tls.Conn, handler http.Handler, host string) {
|
|
br := bufio.NewReader(tlsConn)
|
|
for {
|
|
req, err := http.ReadRequest(br)
|
|
if err != nil {
|
|
if !errors.Is(err, io.EOF) {
|
|
xlog.Debug("mitm: read request", "host", host, "error", err)
|
|
}
|
|
return
|
|
}
|
|
rw := newConnResponseWriter(tlsConn, req)
|
|
handler.ServeHTTP(rw, req)
|
|
rw.finish()
|
|
if req.Close || rw.closeAfter {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// normalizeHostPort returns host in dialable "host:port" form. A value
// that already parses as host:port is returned unchanged; otherwise port
// 443 is appended.
//
// IPv6 literals are handled whether or not they arrive bracketed: both
// "::1" and "[::1]" become "[::1]:443". Plain concatenation would turn
// the bare form into the undialable "::1:443".
func normalizeHostPort(host string) string {
	if _, _, err := net.SplitHostPort(host); err == nil {
		return host
	}
	// Unwrap an already bracketed literal so JoinHostPort, which adds
	// brackets itself when the host contains a colon, does not double them.
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, "443")
}
|