Files
LocalAI/core/cli/workerregistry/client.go
Richard Palethorpe 12d1f3a697 security(http): refuse redirects on outbound clients via hardened pkg/httpclient (#10087)
LocalAI's outbound HTTP clients used Go's default redirect policy, which
follows up to 10 redirects. On a cross-host redirect Go forwards custom
request headers — including credential headers such as Anthropic's
x-api-key — to the redirect target (Go strips Authorization, Cookie and
WWW-Authenticate cross-host, but NOT arbitrary custom headers). An
attacker able to elicit a redirect from an upstream (a hijacked or
spoofed upstream, DNS trickery, or a malicious upstream_url) then
harvests the operator's provider API key.

This was first reported against the cloud-proxy / MITM PII path
(GHSA-3mj3-57v2-4636); the same class affects every other outbound
client. Rather than patch each call site, add pkg/httpclient as the one
sanctioned constructor for outbound HTTP and route everything through it.

pkg/httpclient:
  - New(...)             refuses redirects, TLS 1.2 floor, no body
                         deadline (streaming/SSE safe)
  - NewWithTimeout(d)    simple request/response calls
  - WithFollowRedirects  opt-in following that still strips credential
                         headers on any cross-host hop; different
                         scheme/host/port == different origin, guarding
                         the curl CVE-2022-27774 port-confusion class
  - WithTransport(rt)    keep a custom transport (IP-pin, HTTP/2, a
                         credential-injecting RoundTripper)
  - HardenedTransport()  base transport with the TLS floor + bounded setup
  - Harden(c)            apply the policy to a library-supplied *http.Client
  - NoRedirect           the CheckRedirect policy; wraps ErrRedirectBlocked

Lint: a forbidigo rule flags http.DefaultClient and http.Get/Post/
PostForm/Head, pointing at pkg/httpclient (.golangci.yml,
.agents/coding-style.md). forbidigo cannot match the &http.Client{}
composite literal without also flagging legitimate *http.Client type
references, so that form is enforced by review.

Migrates every non-test outbound call site across core/, pkg/, cmd/, and
the Go backend (backend/go/cloud-proxy). Credential-bearing and
internal-RPC clients refuse redirects; download / CDN / registry clients
use WithFollowRedirects so they keep working while stripping secrets
cross-host. The only credential-bearing client that follows redirects is
the gated-download path (pkg/downloader/uri.go), which strips the token
on the cross-host hop to the CDN. Hardening this closes, in passing:
  - MCP remote-server bearer token leaking via a redirect (the
    RoundTripper re-injected Authorization on every hop)
  - agent multimedia/webhook clients leaking user-supplied auth headers
  - cors_proxy following redirects, bypassing its SSRF IP-pin
  - downloader's authorized read path leaking the token cross-host

Fixes: GHSA-3mj3-57v2-4636 (cloud-proxy leaks operator provider API key
(x-api-key) to attacker host on cross-host redirect)
Reported-by: tonghuaroot
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-30 12:04:10 +02:00

275 lines
7.9 KiB
Go

// Package workerregistry provides a shared HTTP client for worker node
// registration, heartbeating, draining, and deregistration against a
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
package workerregistry
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/pkg/httpclient"
)
// RegistrationClient talks to the frontend's /api/node/* endpoints.
//
// The underlying HTTP client is created lazily on first use and then shared
// by every method, so HTTPTimeout only has an effect when it is set before
// the first request is made. The struct contains a sync.Once and therefore
// must not be copied after first use.
type RegistrationClient struct {
	// FrontendURL is the base URL of the frontend; trailing slashes are
	// stripped before endpoint paths are appended.
	FrontendURL string
	// RegistrationToken, when non-empty, is sent as a Bearer token in the
	// Authorization header of each request.
	RegistrationToken string
	// HTTPTimeout is the timeout handed to the shared HTTP client when it is
	// built; values <= 0 fall back to 10s.
	HTTPTimeout time.Duration
	// client is the lazily initialized HTTP client shared by all calls.
	client *http.Client
	// clientOnce guards the one-time construction of client.
	clientOnce sync.Once
}
// httpTimeout reports the timeout to use for the HTTP client: HTTPTimeout
// when it is positive, otherwise a 10-second default.
func (c *RegistrationClient) httpTimeout() time.Duration {
	const fallback = 10 * time.Second
	if c.HTTPTimeout <= 0 {
		return fallback
	}
	return c.HTTPTimeout
}
// httpClient lazily builds the HTTP client on the first call, using
// httpTimeout() as its timeout, and hands back that same instance on every
// later call.
func (c *RegistrationClient) httpClient() *http.Client {
	build := func() {
		c.client = httpclient.NewWithTimeout(c.httpTimeout())
	}
	c.clientOnce.Do(build)
	return c.client
}
// baseURL yields FrontendURL without any trailing "/" characters, so endpoint
// paths can be appended with a single leading slash.
func (c *RegistrationClient) baseURL() string {
	trimmed := c.FrontendURL
	for strings.HasSuffix(trimmed, "/") {
		trimmed = strings.TrimSuffix(trimmed, "/")
	}
	return trimmed
}
// setAuth attaches "Authorization: Bearer <token>" to req when a registration
// token is configured; without a token the request is left untouched.
func (c *RegistrationClient) setAuth(req *http.Request) {
	token := c.RegistrationToken
	if token == "" {
		return
	}
	req.Header.Set("Authorization", "Bearer "+token)
}
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
	// ID is decoded from the "id" field; Register returns it as the node ID.
	ID string `json:"id"`
	// APIToken is decoded from the optional "api_token" field; Register
	// returns it as the auto-provisioned API token.
	APIToken string `json:"api_token,omitempty"`
}
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
//
// body is JSON-encoded and POSTed to <FrontendURL>/api/node/register, with the
// bearer token attached when one is configured. An error is returned when the
// body cannot be encoded, the request cannot be built or sent, the frontend
// answers with a non-2xx status, or the response body is not valid JSON. In
// every error case both returned strings are empty.
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
	// A map[string]any can hold values encoding/json rejects (channels,
	// funcs, NaN, ...); report that instead of POSTing an empty body.
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", "", fmt.Errorf("encoding registration body: %w", err)
	}

	url := c.baseURL() + "/api/node/register"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return "", "", fmt.Errorf("creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)

	resp, err := c.httpClient().Do(req)
	if err != nil {
		return "", "", fmt.Errorf("posting to %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
	}

	var result RegisterResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", "", fmt.Errorf("decoding response: %w", err)
	}
	return result.ID, result.APIToken, nil
}
// RegisterWithRetry calls Register until it succeeds, waiting between
// attempts with exponential backoff (2s, doubling, capped at 30s).
//
// maxRetries is the total number of attempts. Values below 1 are treated as 1
// so that at least one registration is always attempted; previously such a
// value skipped the loop and returned an empty node ID with a nil error.
//
// It returns the node ID and API token of the first successful attempt. When
// every attempt fails the last error is returned wrapped with the attempt
// count; when ctx is cancelled while waiting, ctx.Err() is returned.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
	const maxBackoff = 30 * time.Second
	backoff := 2 * time.Second

	maxRetries = max(maxRetries, 1)

	for attempt := 1; ; attempt++ {
		nodeID, apiToken, err := c.Register(ctx, body)
		if err == nil {
			return nodeID, apiToken, nil
		}
		if attempt >= maxRetries {
			return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
		}

		xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
		select {
		case <-ctx.Done():
			return "", "", ctx.Err()
		case <-time.After(backoff):
		}
		backoff = min(backoff*2, maxBackoff)
	}
}
// Heartbeat sends a single heartbeat POST with the given body to
// <FrontendURL>/api/node/<nodeID>/heartbeat.
//
// An error is returned when the body cannot be JSON-encoded, the request
// cannot be built or sent, or the frontend answers with a non-2xx status.
// Rejected heartbeats (e.g. 401 or 404) are therefore no longer reported as
// success.
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return fmt.Errorf("encoding heartbeat body: %w", err)
	}

	url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return fmt.Errorf("creating heartbeat request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	c.setAuth(req)

	resp, err := c.httpClient().Do(req)
	if err != nil {
		return fmt.Errorf("posting heartbeat to %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("heartbeat failed with status %d", resp.StatusCode)
	}
	return nil
}
// HeartbeatLoop sends one heartbeat per interval and blocks until ctx is
// cancelled. On every tick bodyFn is invoked to build the payload (e.g. VRAM
// stats); a failed heartbeat is logged as a warning and the loop carries on.
// interval must be positive, because time.NewTicker panics otherwise.
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
	tick := time.NewTicker(interval)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		hbErr := c.Heartbeat(ctx, nodeID, bodyFn())
		if hbErr == nil {
			continue
		}
		xlog.Warn("Heartbeat failed", "error", hbErr)
	}
}
// Drain asks the frontend to put the node into draining status by POSTing
// (without a body) to /api/node/:id/drain. Transport errors are returned
// as-is; any response other than 200 OK yields an error carrying the status
// code.
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
	endpoint := c.baseURL() + "/api/node/" + nodeID + "/drain"

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if reqErr != nil {
		return fmt.Errorf("creating drain request: %w", reqErr)
	}
	c.setAuth(req)

	res, doErr := c.httpClient().Do(req)
	if doErr != nil {
		return doErr
	}
	defer res.Body.Close()

	if code := res.StatusCode; code != http.StatusOK {
		return fmt.Errorf("drain failed with status %d", code)
	}
	return nil
}
// WaitForDrain polls GET /api/node/:id/models once per second until all
// models report 0 in-flight requests, or until timeout elapses or ctx is
// cancelled. It returns nothing; the outcome is only logged.
//
// A poll only counts when the frontend answers with a 2xx status and a
// decodable JSON array. Transport errors, error statuses and malformed bodies
// are logged and retried. Previously they left the model list empty, which
// was mistaken for "everything drained". A failure to build the request is
// not retried.
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
	url := c.baseURL() + "/api/node/" + nodeID + "/models"
	deadline := time.Now().Add(timeout)

	// pause waits one poll interval and reports false when ctx was cancelled
	// in the meantime.
	pause := func() bool {
		select {
		case <-ctx.Done():
			xlog.Warn("Drain wait cancelled")
			return false
		case <-time.After(1 * time.Second):
			return true
		}
	}

	for time.Now().Before(deadline) {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			// A malformed URL will not fix itself on the next iteration.
			xlog.Warn("Failed to create drain poll request", "error", err)
			return
		}
		c.setAuth(req)

		resp, err := c.httpClient().Do(req)
		if err != nil {
			xlog.Warn("Drain poll failed, will retry", "error", err)
			if !pause() {
				return
			}
			continue
		}

		var models []struct {
			InFlight int `json:"in_flight"`
		}
		statusOK := resp.StatusCode >= 200 && resp.StatusCode < 300
		var decodeErr error
		if statusOK {
			decodeErr = json.NewDecoder(resp.Body).Decode(&models)
		}
		resp.Body.Close()

		if !statusOK || decodeErr != nil {
			xlog.Warn("Drain poll returned an unusable response, will retry", "status", resp.StatusCode, "error", decodeErr)
			if !pause() {
				return
			}
			continue
		}

		total := 0
		for _, m := range models {
			total += m.InFlight
		}
		if total == 0 {
			xlog.Info("All in-flight requests drained")
			return
		}

		xlog.Info("Waiting for in-flight requests", "count", total)
		if !pause() {
			return
		}
	}
	xlog.Warn("Drain timeout reached, proceeding with shutdown")
}
// Deregister marks the node as offline by POSTing (without a body) to
// /api/node/:id/deregister. The node row is preserved in the database so
// re-registration restores approval status. Transport errors are returned
// as-is; any response other than 200 OK yields an error carrying the status
// code.
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
	endpoint := c.baseURL() + "/api/node/" + nodeID + "/deregister"

	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if reqErr != nil {
		return fmt.Errorf("creating deregister request: %w", reqErr)
	}
	c.setAuth(req)

	res, doErr := c.httpClient().Do(req)
	if doErr != nil {
		return doErr
	}
	defer res.Body.Close()

	if code := res.StatusCode; code != http.StatusOK {
		return fmt.Errorf("deregistration failed with status %d", code)
	}
	return nil
}
// GracefulDeregister runs the standard shutdown sequence for backend workers:
// drain, wait for in-flight requests, then deregister. The whole sequence
// shares one 60-second budget, of which the wait may use up to 30 seconds.
// The wait is skipped when the drain call fails; deregistration is attempted
// either way. Nothing happens when FrontendURL or nodeID is empty. Failures
// are only logged.
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
	if nodeID == "" || c.FrontendURL == "" {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	drainErr := c.Drain(ctx, nodeID)
	if drainErr == nil {
		c.WaitForDrain(ctx, nodeID, 30*time.Second)
	} else {
		xlog.Warn("Failed to set drain status", "error", drainErr)
	}

	deregErr := c.Deregister(ctx, nodeID)
	if deregErr != nil {
		xlog.Error("Failed to deregister", "error", deregErr)
		return
	}
	xlog.Info("Deregistered from frontend")
}