mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 15:26:14 -04:00
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED is set, and connect workers with scoped credentials from the register response. Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs. Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode. Assisted-by: Grok:grok grok-build Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(distributed): correct NATS JWT scoping and harden client auth The JWT-auth path added in 46467cc7 had several gaps that fail silently under LOCALAI_NATS_REQUIRE_AUTH: - Agent-worker minted JWTs did not allow the subjects the agent worker actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop), so MCP-CI jobs and backend-stop session cleanup were silently dropped. Scope the agent permission set to those subjects. - NATS subscription permission violations were swallowed (Subscribe returned a live-but-dead subscription). Confirm subscriptions with a server round-trip so a denial surfaces synchronously, and log async permission errors. - The backend worker connected anonymously when given a JWT without its paired seed; reject the unpaired credential instead. - The documented service-user permissions in nats-auth-setup.sh omitted prefixcache.>, which the frontend publishes and subscribes; add it. Also: add a credential-provider hook to the messaging client (consumed by the follow-up credential-lifecycle change), drop the always-nil error from NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct), and gofmt the feature's files. Tests: an agent-JWT e2e spec that connects to the enforcing NATS server and exercises every subscription the agent worker makes, plus permission allow-list coverage unit tests. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(distributed): acquire and auto-refresh worker NATS credentials Workers fetched NATS credentials once at startup, which broke two cases under JWT auth: a worker that registered while still pending admin approval never received a minted JWT (it connected unauthenticated and gave up), and a long-running worker's 24h JWT expired with no way to renew it. Introduce workerregistry.NATSCredentialManager, built on idempotent re-registration (the frontend preserves the node row and mints a fresh JWT each call): - Acquire re-registers through admin approval until the node is approved and credentials are minted (or returns the first success when auth is not required, preserving anonymous-NATS behavior). - RefreshLoop re-registers before the JWT expires (~75% of its lifetime), updating the credentials served to the connection. - Both are bounded (default 100 attempts / consecutive failures) and return an error on exhaustion, so an unapprovable or unrenewable worker exits non-zero and surfaces the problem instead of hanging or drifting toward an expired credential. The messaging client gains WithUserJWTProvider, fetching credentials on each (re)connect so the connection transparently adopts a refreshed JWT when the server expires the old one. RegisterFull exposes the approval status and full response; Register delegates to it. Both the backend worker and the agent worker are wired to this: explicit env credentials are used as-is, minted credentials are acquired-with-wait and refreshed, and a permanent refresh failure shuts the worker down so it restarts and re-acquires. Tests cover Acquire (wait-through-pending, bounded give-up, context cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry exit) and jwtExpiry decoding. Docs updated in distributed-mode.md. 
Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
287 lines
8.8 KiB
Go
287 lines
8.8 KiB
Go
// Package workerregistry provides a shared HTTP client for worker node
|
|
// registration, heartbeating, draining, and deregistration against a
|
|
// LocalAI frontend. Both the backend worker (WorkerCMD) and the agent
|
|
// worker (AgentWorkerCMD) use this instead of duplicating the logic.
|
|
package workerregistry
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/mudler/xlog"
|
|
|
|
"github.com/mudler/LocalAI/pkg/httpclient"
|
|
)
|
|
|
|
// RegistrationClient is the HTTP client workers use to reach the frontend's
// /api/node/* endpoints (register, heartbeat, drain, models, deregister).
//
// The exported fields are set by the caller. The underlying *http.Client is
// created lazily on the first request, so HTTPTimeout only takes effect when
// it is set before that first request. The struct holds a sync.Once and must
// therefore not be copied once it has been used.
type RegistrationClient struct {
	// FrontendURL is the base URL of the frontend; trailing slashes are
	// stripped before endpoint paths are appended.
	FrontendURL string
	// RegistrationToken, when non-empty, is sent as a Bearer token in the
	// Authorization header of every request.
	RegistrationToken string
	// HTTPTimeout is the timeout handed to the shared HTTP client; values
	// that are not positive fall back to 10s.
	HTTPTimeout time.Duration

	client     *http.Client // shared client, built on first use
	clientOnce sync.Once    // guards the one-time construction of client
}
|
|
|
|
// httpTimeout returns the configured timeout or a sensible default.
|
|
func (c *RegistrationClient) httpTimeout() time.Duration {
|
|
if c.HTTPTimeout > 0 {
|
|
return c.HTTPTimeout
|
|
}
|
|
return 10 * time.Second
|
|
}
|
|
|
|
// httpClient returns the shared HTTP client, initializing it on first use.
|
|
func (c *RegistrationClient) httpClient() *http.Client {
|
|
c.clientOnce.Do(func() {
|
|
c.client = httpclient.NewWithTimeout(c.httpTimeout())
|
|
})
|
|
return c.client
|
|
}
|
|
|
|
// baseURL returns FrontendURL with any trailing slash stripped.
|
|
func (c *RegistrationClient) baseURL() string {
|
|
return strings.TrimRight(c.FrontendURL, "/")
|
|
}
|
|
|
|
// setAuth adds an Authorization header when a token is configured.
|
|
func (c *RegistrationClient) setAuth(req *http.Request) {
|
|
if c.RegistrationToken != "" {
|
|
req.Header.Set("Authorization", "Bearer "+c.RegistrationToken)
|
|
}
|
|
}
|
|
|
|
// RegisterResponse mirrors the JSON object returned by /api/node/register.
// Every field except ID is optional in the payload.
type RegisterResponse struct {
	// ID is the node identifier assigned by the frontend.
	ID string `json:"id"`
	// Status is "pending" until an admin approves the node.
	Status string `json:"status,omitempty"`
	// APIToken is the optional API token returned with the registration.
	APIToken string `json:"api_token,omitempty"`
	// NatsJWT is the optional NATS user JWT returned with the registration.
	NatsJWT string `json:"nats_jwt,omitempty"`
	// NatsUserSeed is the optional NATS user seed returned alongside NatsJWT.
	NatsUserSeed string `json:"nats_user_seed,omitempty"`
}
|
|
|
|
// RegisterFull sends a single registration request and returns the full
|
|
// response (node ID, approval status, and optional API token / NATS creds).
|
|
// Re-registration is idempotent: the frontend preserves the node row and mints
|
|
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
|
|
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
|
|
jsonBody, _ := json.Marshal(body)
|
|
url := c.baseURL() + "/api/node/register"
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
c.setAuth(req)
|
|
|
|
resp, err := c.httpClient().Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("posting to %s: %w", url, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
|
|
}
|
|
|
|
var result RegisterResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, fmt.Errorf("decoding response: %w", err)
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
// Register sends a single registration request and returns the node ID and
|
|
// optional credentials (API token for agent workers, NATS JWT when configured).
|
|
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
|
res, err := c.RegisterFull(ctx, body)
|
|
if err != nil {
|
|
return "", "", "", "", err
|
|
}
|
|
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
|
|
}
|
|
|
|
// RegisterWithRetry retries registration with exponential backoff.
|
|
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
|
|
backoff := 2 * time.Second
|
|
maxBackoff := 30 * time.Second
|
|
|
|
for attempt := 1; attempt <= maxRetries; attempt++ {
|
|
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
|
|
if err == nil {
|
|
return nodeID, apiToken, natsJWT, natsSeed, nil
|
|
}
|
|
if attempt == maxRetries {
|
|
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
|
|
}
|
|
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
|
|
select {
|
|
case <-ctx.Done():
|
|
return "", "", "", "", ctx.Err()
|
|
case <-time.After(backoff):
|
|
}
|
|
backoff = min(backoff*2, maxBackoff)
|
|
}
|
|
return nodeID, apiToken, natsJWT, natsSeed, err
|
|
}
|
|
|
|
// Heartbeat sends a single heartbeat POST with the given body.
|
|
func (c *RegistrationClient) Heartbeat(ctx context.Context, nodeID string, body map[string]any) error {
|
|
jsonBody, _ := json.Marshal(body)
|
|
url := c.baseURL() + "/api/node/" + nodeID + "/heartbeat"
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return fmt.Errorf("creating heartbeat request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
c.setAuth(req)
|
|
|
|
resp, err := c.httpClient().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
return nil
|
|
}
|
|
|
|
// HeartbeatLoop runs heartbeats at the given interval until ctx is cancelled.
|
|
// bodyFn is called each tick to build the heartbeat payload (e.g. VRAM stats).
|
|
func (c *RegistrationClient) HeartbeatLoop(ctx context.Context, nodeID string, interval time.Duration, bodyFn func() map[string]any) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
body := bodyFn()
|
|
if err := c.Heartbeat(ctx, nodeID, body); err != nil {
|
|
xlog.Warn("Heartbeat failed", "error", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Drain sets the node to draining status via POST /api/node/:id/drain.
|
|
func (c *RegistrationClient) Drain(ctx context.Context, nodeID string) error {
|
|
url := c.baseURL() + "/api/node/" + nodeID + "/drain"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("creating drain request: %w", err)
|
|
}
|
|
c.setAuth(req)
|
|
|
|
resp, err := c.httpClient().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("drain failed with status %d", resp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// WaitForDrain polls GET /api/node/:id/models until all models report 0
|
|
// in-flight requests, or until timeout elapses.
|
|
func (c *RegistrationClient) WaitForDrain(ctx context.Context, nodeID string, timeout time.Duration) {
|
|
url := c.baseURL() + "/api/node/" + nodeID + "/models"
|
|
|
|
deadline := time.Now().Add(timeout)
|
|
for time.Now().Before(deadline) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
xlog.Warn("Failed to create drain poll request", "error", err)
|
|
return
|
|
}
|
|
c.setAuth(req)
|
|
|
|
resp, err := c.httpClient().Do(req)
|
|
if err != nil {
|
|
xlog.Warn("Drain poll failed, will retry", "error", err)
|
|
select {
|
|
case <-ctx.Done():
|
|
xlog.Warn("Drain wait cancelled")
|
|
return
|
|
case <-time.After(1 * time.Second):
|
|
}
|
|
continue
|
|
}
|
|
var models []struct {
|
|
InFlight int `json:"in_flight"`
|
|
}
|
|
json.NewDecoder(resp.Body).Decode(&models)
|
|
resp.Body.Close()
|
|
|
|
total := 0
|
|
for _, m := range models {
|
|
total += m.InFlight
|
|
}
|
|
if total == 0 {
|
|
xlog.Info("All in-flight requests drained")
|
|
return
|
|
}
|
|
xlog.Info("Waiting for in-flight requests", "count", total)
|
|
select {
|
|
case <-ctx.Done():
|
|
xlog.Warn("Drain wait cancelled")
|
|
return
|
|
case <-time.After(1 * time.Second):
|
|
}
|
|
}
|
|
xlog.Warn("Drain timeout reached, proceeding with shutdown")
|
|
}
|
|
|
|
// Deregister marks the node as offline via POST /api/node/:id/deregister.
|
|
// The node row is preserved in the database so re-registration restores
|
|
// approval status.
|
|
func (c *RegistrationClient) Deregister(ctx context.Context, nodeID string) error {
|
|
url := c.baseURL() + "/api/node/" + nodeID + "/deregister"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("creating deregister request: %w", err)
|
|
}
|
|
c.setAuth(req)
|
|
|
|
resp, err := c.httpClient().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("deregistration failed with status %d", resp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GracefulDeregister performs drain -> wait -> deregister in sequence.
|
|
// This is the standard shutdown sequence for backend workers.
|
|
func (c *RegistrationClient) GracefulDeregister(nodeID string) {
|
|
if c.FrontendURL == "" || nodeID == "" {
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
if err := c.Drain(ctx, nodeID); err != nil {
|
|
xlog.Warn("Failed to set drain status", "error", err)
|
|
} else {
|
|
c.WaitForDrain(ctx, nodeID, 30*time.Second)
|
|
}
|
|
|
|
if err := c.Deregister(ctx, nodeID); err != nil {
|
|
xlog.Error("Failed to deregister", "error", err)
|
|
} else {
|
|
xlog.Info("Deregistered from frontend")
|
|
}
|
|
}
|