feat(distributed): Add NATS JWT authentication and TLS/mTLS options (#10159)

* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage

Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED
is set, and connect workers with scoped credentials from the register response.
Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside
tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs.

Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode.
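
As an illustration of what the CA/CERT/KEY options imply for a client, here is a minimal sketch using only the Go standard library. The helper name `buildTLSConfig` is hypothetical and takes PEM contents rather than file paths. The rule that the certificate and key must be supplied together is an assumption of this sketch, not something this commit states about `messaging.WithTLS`.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
)

// buildTLSConfig is a hypothetical helper (not the LocalAI API). Its inputs are
// the PEM contents of the files LOCALAI_NATS_TLS_CA/CERT/KEY would point at.
func buildTLSConfig(caPEM, certPEM, keyPEM []byte) (*tls.Config, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}

	// Private CA: verify the server certificate against the supplied roots.
	if len(caPEM) > 0 {
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(caPEM) {
			return nil, errors.New("CA bundle contains no valid PEM certificates")
		}
		cfg.RootCAs = pool
	}

	// mTLS: assumed here that a client certificate is useless without its key.
	if (len(certPEM) == 0) != (len(keyPEM) == 0) {
		return nil, errors.New("client certificate and key must be set together")
	}
	if len(certPEM) > 0 {
		pair, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			return nil, fmt.Errorf("loading client key pair: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair}
	}
	return cfg, nil
}

func main() {
	// With no inputs, RootCAs stays nil and Go uses the system trust store.
	cfg, err := buildTLSConfig(nil, nil, nil)
	fmt.Println(cfg != nil, err)
}
```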

Assisted-by: Grok:grok grok-build
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(distributed): correct NATS JWT scoping and harden client auth

The JWT-auth path added in 46467cc7 had several gaps that fail silently
under LOCALAI_NATS_REQUIRE_AUTH:

- JWTs minted for agent workers did not allow the subjects the agent worker
  actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop),
  so MCP-CI jobs and backend-stop session cleanup were silently dropped.
  Scope the agent permission set to those subjects.
- NATS subscription permission violations were swallowed (Subscribe
  returned a live-but-dead subscription). Confirm subscriptions with a
  server round-trip so a denial surfaces synchronously, and log async
  permission errors.
- The backend worker connected anonymously when given a JWT without its
  paired seed; reject the unpaired credential instead.
- The documented service-user permissions in nats-auth-setup.sh omitted
  prefixcache.>, which the frontend both publishes to and subscribes to;
  add it.
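
The unpaired-credential rejection in the list above can be sketched as a small check. `checkPair` is a hypothetical name. The XOR-style comparison mirrors the `(jwt == "") != (seed == "")` test that appears in the agent worker diff, but this is an illustration rather than the worker's actual code path.

```go
package main

import (
	"errors"
	"fmt"
)

var errUnpaired = errors.New("NATS JWT and seed must be set together")

// checkPair is a hypothetical helper. It reports whether the connection would
// be authenticated, and rejects a JWT or seed supplied without its partner
// instead of silently falling back to an anonymous connection.
func checkPair(jwt, seed string) (authenticated bool, err error) {
	if (jwt == "") != (seed == "") {
		return false, errUnpaired
	}
	return jwt != "", nil
}

func main() {
	for _, c := range [][2]string{{"", ""}, {"jwt", "seed"}, {"jwt", ""}} {
		ok, err := checkPair(c[0], c[1])
		fmt.Printf("jwt=%q seed=%q authenticated=%v err=%v\n", c[0], c[1], ok, err)
	}
}
```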

Also: add a credential-provider hook to the messaging client (consumed by
the follow-up credential-lifecycle change), drop the always-nil error from
NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct),
and gofmt the feature's files.

Tests: an agent-JWT e2e spec that connects to the enforcing NATS server
and exercises every subscription the agent worker makes, plus permission
allow-list coverage unit tests.
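
To make the allow-list idea concrete, here is a minimal sketch of NATS-style subject matching, where `*` matches exactly one token and a trailing `>` matches one or more tokens. `subjectAllowed` and `anyAllowed` are hypothetical helpers, and the subject `prefixcache.update` is invented for the example. The real permission checks are enforced by the NATS server, not by code like this.

```go
package main

import (
	"fmt"
	"strings"
)

// subjectAllowed reports whether subject matches a single allow-list pattern.
func subjectAllowed(pattern, subject string) bool {
	p := strings.Split(pattern, ".")
	s := strings.Split(subject, ".")
	for i, tok := range p {
		if tok == ">" {
			// ">" is only valid as the last token and needs at least one
			// remaining subject token to match.
			return i == len(p)-1 && len(s) > i
		}
		if i >= len(s) {
			return false
		}
		if tok != "*" && tok != s[i] {
			return false
		}
	}
	return len(p) == len(s)
}

// anyAllowed checks a subject against every pattern in an allow-list.
func anyAllowed(patterns []string, subject string) bool {
	for _, p := range patterns {
		if subjectAllowed(p, subject) {
			return true
		}
	}
	return false
}

func main() {
	allow := []string{"jobs.mcp-ci.new", "nodes.*.backend.stop", "prefixcache.>"}
	for _, subj := range []string{
		"jobs.mcp-ci.new",
		"nodes.abc.backend.stop",
		"nodes.abc.backend.start",
		"prefixcache.update", // hypothetical subject
	} {
		fmt.Println(subj, anyAllowed(allow, subj))
	}
}
```

A coverage test in this style would list every subject a worker subscribes to and assert that each one matches some entry in the permission set.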

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(distributed): acquire and auto-refresh worker NATS credentials

Workers fetched NATS credentials once at startup, which broke two cases
under JWT auth: a worker that registered while still pending admin
approval never received a minted JWT (it connected unauthenticated and
gave up), and a long-running worker's 24h JWT expired with no way to renew
it.

Introduce workerregistry.NATSCredentialManager, built on idempotent
re-registration (the frontend preserves the node row and mints a fresh JWT
each call):

- Acquire re-registers through admin approval until the node is approved
  and credentials are minted (or returns the first success when auth is
  not required, preserving anonymous-NATS behavior).
- RefreshLoop re-registers before the JWT expires (~75% of its lifetime),
  updating the credentials served to the connection.
- Both are bounded (default 100 attempts / consecutive failures) and
  return an error on exhaustion, so an unapprovable or unrenewable worker
  exits non-zero and surfaces the problem instead of hanging or drifting
  toward an expired credential.
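
The timing rules in the list above reduce to two small computations, sketched here under the defaults visible in the diff (2s initial backoff, 30s cap, refresh at 0.75 of the remaining lifetime). The function names are hypothetical. With these numbers, a freshly minted 24h JWT would be renewed about 18h in.

```go
package main

import (
	"fmt"
	"time"
)

// nextBackoff doubles the retry delay up to a cap, as used between
// registration attempts while waiting for approval.
func nextBackoff(cur, maxDelay time.Duration) time.Duration {
	return min(cur*2, maxDelay)
}

// refreshWait returns how long to sleep before renewing: the given fraction of
// the time remaining until expiry, clamped at zero for already-expired JWTs.
func refreshWait(remaining time.Duration, lead float64) time.Duration {
	return max(time.Duration(float64(remaining)*lead), 0)
}

func main() {
	// Print the first few delays of the capped exponential schedule.
	d := 2 * time.Second
	for i := 0; i < 6; i++ {
		fmt.Print(d, " ")
		d = nextBackoff(d, 30*time.Second)
	}
	fmt.Println()

	fmt.Println(refreshWait(24*time.Hour, 0.75))
}
```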

The messaging client gains WithUserJWTProvider, fetching credentials on
each (re)connect so the connection transparently adopts a refreshed JWT
when the server expires the old one. RegisterFull exposes the approval
status and full response; Register delegates to it.
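
The provider pattern described above can be sketched without any NATS dependency: a mutex-guarded store is updated by the refresher, and the connection reads it through a callback each time it (re)connects. `credStore` and `connect` are hypothetical stand-ins, not the `messaging` client API.

```go
package main

import (
	"fmt"
	"sync"
)

// credStore holds the latest JWT/seed pair and is safe for concurrent use.
type credStore struct {
	mu   sync.RWMutex
	jwt  string
	seed string
}

// Set replaces the stored pair, ignoring incomplete credentials so a bad
// refresh response never clobbers a working pair.
func (s *credStore) Set(jwt, seed string) {
	if jwt == "" || seed == "" {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.jwt, s.seed = jwt, seed
}

// Current has the shape of a provider callback: func() (jwt, seed string).
func (s *credStore) Current() (string, string) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.jwt, s.seed
}

// connect stands in for a client's (re)connect path: it asks the provider for
// credentials at connect time rather than capturing them once at startup.
func connect(provider func() (string, string)) string {
	jwt, _ := provider()
	return jwt
}

func main() {
	store := &credStore{}
	store.Set("jwt-1", "seed-1")
	fmt.Println(connect(store.Current))

	store.Set("jwt-2", "seed-2") // a refresh happened in the background
	fmt.Println(connect(store.Current))
}
```

Because the callback is evaluated per connect, no connection teardown logic is needed on refresh; the new JWT is simply picked up the next time the server forces a reconnect.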

Both the backend worker and the agent worker are wired to this: explicit
env credentials are used as-is, minted credentials are acquired-with-wait
and refreshed, and a permanent refresh failure shuts the worker down so it
restarts and re-acquires.
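
The precedence described above (explicit env override, then frontend-minted, then service fallback) can be sketched as one function. `creds` and `resolve` are hypothetical names, and the returned source labels are illustrative only. The ordering follows the switch in the agent worker diff.

```go
package main

import (
	"errors"
	"fmt"
)

// creds is a JWT and its paired seed; both empty means "not configured".
type creds struct{ JWT, Seed string }

func (c creds) empty() bool    { return c.JWT == "" && c.Seed == "" }
func (c creds) complete() bool { return c.JWT != "" && c.Seed != "" }

// resolve picks the credential source and rejects half-configured static
// sources. It returns the chosen pair and a label naming where it came from.
func resolve(env, minted, service creds, requireAuth bool) (creds, string, error) {
	switch {
	case !env.empty():
		if !env.complete() {
			return creds{}, "", errors.New("env JWT and seed must be set together")
		}
		return env, "env", nil
	case minted.complete():
		return minted, "minted", nil
	case !service.empty():
		if !service.complete() {
			return creds{}, "", errors.New("service JWT and seed must be set together")
		}
		return service, "service", nil
	case requireAuth:
		return creds{}, "", errors.New("NATS JWT and seed required but none configured")
	}
	return creds{}, "anonymous", nil
}

func main() {
	_, src, err := resolve(creds{}, creds{"minted-jwt", "minted-seed"}, creds{"svc-jwt", "svc-seed"}, true)
	fmt.Println(src, err)
}
```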

Tests cover Acquire (wait-through-pending, bounded give-up, context
cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry
exit) and jwtExpiry decoding. Docs updated in distributed-mode.md.
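
For readers unfamiliar with how an expiry can be read from a token, here is a standard-library sketch that decodes the `exp` claim from a JWT payload. It assumes the usual three-segment, unpadded base64url layout and does not verify the signature. The commit itself uses `natsauth.DecodeUserClaims`, so `expiryOf` and `unsignedToken` below are hypothetical illustrations only.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"
	"time"
)

// expiryOf returns the token's expiry; ok is false when the token is empty,
// malformed, or carries no exp claim. The signature is NOT checked.
func expiryOf(token string) (time.Time, bool) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return time.Time{}, false
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return time.Time{}, false
	}
	var claims struct {
		Exp int64 `json:"exp"`
	}
	if err := json.Unmarshal(payload, &claims); err != nil || claims.Exp == 0 {
		return time.Time{}, false
	}
	return time.Unix(claims.Exp, 0), true
}

// unsignedToken builds a demo-only token with an empty signature segment.
func unsignedToken(exp int64) string {
	enc := base64.RawURLEncoding.EncodeToString
	header := enc([]byte(`{"alg":"none"}`))
	payload := enc([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
	return header + "." + payload + "."
}

func main() {
	tok := unsignedToken(time.Now().Add(time.Hour).Unix())
	exp, ok := expiryOf(tok)
	fmt.Println(ok, time.Until(exp) > 0)
}
```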

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
Richard Palethorpe
2026-06-03 18:43:56 +01:00
committed by GitHub
parent 9d10418593
commit 3a932a9803
33 changed files with 1856 additions and 86 deletions

View File

@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
@echo 'Running e2e AIO tests'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
test-e2e-distributed: protogen-go
@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
# vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
# cpu-vllm backend from the current working tree, then drives a
# head + headless follower via testcontainers-go and asserts a chat
# completion. BuildKit caches both images, so re-runs only rebuild
# what changed. The test lives under tests/e2e/distributed and is
# selected by the VLLMMultinode label so it doesn't run alongside
# the other distributed-suite tests by default.
# test-e2e-distributed.
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
@echo 'Running e2e vLLM multi-node DP test'
LOCALAI_IMAGE=local-ai \

View File

@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
// Connect to NATS
natsClient, err := messaging.New(cfg.Distributed.NatsURL)
natsAuth := cfg.Distributed.NatsAuthConfig()
if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
}
natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
if err != nil {
return nil, fmt.Errorf("connecting to NATS: %w", err)
}

View File

@@ -52,6 +52,15 @@ type AgentWorkerCMD struct {
Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
// Timeouts
MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
}
@@ -81,15 +90,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
registrationBody["token"] = cmd.RegistrationToken
}
nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
// Context cancelled on shutdown — used by registration waits, heartbeat, and
// other background goroutines.
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
// Acquire credentials via (re)registration. When the bus requires auth and no
// static fallback is configured, wait through admin approval until the
// frontend mints credentials rather than starting unauthenticated.
credMgr := workerregistry.NewNATSCredentialManager(
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
return regClient.RegisterFull(ctx, registrationBody)
},
cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
)
res, err := credMgr.Acquire(shutdownCtx)
if err != nil {
return fmt.Errorf("registration failed: %w", err)
}
nodeID := res.ID
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
// Use provisioned API token if none was set
if cmd.APIToken == "" {
cmd.APIToken = apiToken
cmd.APIToken = res.APIToken
}
// Start heartbeat
@@ -98,14 +122,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
}
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
// Context cancelled on shutdown — used by heartbeat and other background goroutines
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
// Connect to NATS
natsClient, err := messaging.New(cmd.NatsURL)
// Resolve NATS credentials with precedence: explicit env override, then
// frontend-minted (auto-refreshed before expiry), then service fallback.
// Each static source must supply JWT and seed together.
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
var natsOpts []messaging.Option
switch {
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
case credMgr.HasCredentials():
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
go func() {
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
shutdownCancel()
}
}()
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
}
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
case cmd.NatsRequireAuth:
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
}
if natsTLS.Enabled() {
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
}
natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
if err != nil {
return fmt.Errorf("connecting to NATS: %w", err)
}
@@ -183,17 +233,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
// Wait for shutdown
// Wait for an OS signal or an internal fatal condition (e.g. NATS
// credentials became unrenewable), so the worker restarts and re-acquires
// rather than lingering unable to serve.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
var runErr error
select {
case <-sigCh:
case <-shutdownCtx.Done():
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
xlog.Error("Internal shutdown requested", "error", runErr)
}
xlog.Info("Shutting down agent worker")
shutdownCancel() // stop heartbeat loop immediately
dispatcher.Stop()
mcpTools.CloseAllMCPSessions()
regClient.GracefulDeregister(nodeID)
return nil
return runErr
}
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.

View File

@@ -159,6 +159,14 @@ type RunCMD struct {
DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
Version bool
@@ -283,6 +291,34 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
if r.RegistrationToken != "" {
opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
}
if r.NatsAccountSeed != "" {
opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
}
if r.NatsServiceJWT != "" {
opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
}
if r.NatsServiceSeed != "" {
opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
}
if r.NatsWorkerJWTTTL != "" {
d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
if err != nil {
return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
}
opts = append(opts, config.WithNatsWorkerJWTTTL(d))
}
if r.NatsRequireAuth {
opts = append(opts, config.EnableNatsRequireAuth)
}
if r.NatsTLSCA != "" {
opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
}
if r.NatsTLSCert != "" {
opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
}
if r.NatsTLSKey != "" {
opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
}
if r.AutoApproveNodes {
opts = append(opts, config.EnableAutoApproveNodes)
}

View File

@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
FrontendURL: r.RegisterTo,
RegistrationToken: r.RegistrationToken,
}
nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
if regErr != nil {
return fmt.Errorf("registering with frontend: %w", regErr)
}

View File

@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
// RegisterResponse is the JSON body returned by /api/node/register.
type RegisterResponse struct {
ID string `json:"id"`
APIToken string `json:"api_token,omitempty"`
ID string `json:"id"`
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
APIToken string `json:"api_token,omitempty"`
NatsJWT string `json:"nats_jwt,omitempty"`
NatsUserSeed string `json:"nats_user_seed,omitempty"`
}
// Register sends a single registration request and returns the node ID and
// (optionally) an auto-provisioned API token.
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
// RegisterFull sends a single registration request and returns the full
// response (node ID, approval status, and optional API token / NATS creds).
// Re-registration is idempotent: the frontend preserves the node row and mints
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
jsonBody, _ := json.Marshal(body)
url := c.baseURL() + "/api/node/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
if err != nil {
return "", "", fmt.Errorf("creating request: %w", err)
return nil, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
c.setAuth(req)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", "", fmt.Errorf("posting to %s: %w", url, err)
return nil, fmt.Errorf("posting to %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
}
var result RegisterResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", "", fmt.Errorf("decoding response: %w", err)
return nil, fmt.Errorf("decoding response: %w", err)
}
return result.ID, result.APIToken, nil
return &result, nil
}
// Register sends a single registration request and returns the node ID and
// optional credentials (API token for agent workers, NATS JWT when configured).
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
res, err := c.RegisterFull(ctx, body)
if err != nil {
return "", "", "", "", err
}
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
}
// RegisterWithRetry retries registration with exponential backoff.
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
backoff := 2 * time.Second
maxBackoff := 30 * time.Second
var nodeID, apiToken string
var err error
for attempt := 1; attempt <= maxRetries; attempt++ {
nodeID, apiToken, err = c.Register(ctx, body)
nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
if err == nil {
return nodeID, apiToken, nil
return nodeID, apiToken, natsJWT, natsSeed, nil
}
if attempt == maxRetries {
return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
}
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
select {
case <-ctx.Done():
return "", "", ctx.Err()
return "", "", "", "", ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, maxBackoff)
}
return nodeID, apiToken, err
return nodeID, apiToken, natsJWT, natsSeed, err
}
// Heartbeat sends a single heartbeat POST with the given body.

View File

@@ -0,0 +1,200 @@
package workerregistry
import (
"context"
"fmt"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
// imported so the lightweight registration client does not pull in the nodes
// package (and its gorm/DB dependencies).
const statusPending = "pending"
// defaultMaxAttempts bounds how many times Acquire registers (and how many
// consecutive times RefreshLoop may fail) before giving up. It is high enough
// to ride out a slow admin approval or a transient frontend outage, but finite
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
// non-zero exit and the resulting restart) rather than waiting forever.
const defaultMaxAttempts = 100
// RegisterFunc performs one idempotent registration round-trip.
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
// NATSCredentialManager acquires NATS credentials at startup — waiting through
// admin approval when required — and refreshes them before the minted JWT
// expires, by re-registering (which mints a fresh JWT). The live NATS
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
// for concurrent use.
//
// It addresses two failure modes: a worker that needs credentials but registers
// while still pending approval (it would otherwise give up and never connect),
// and a long-running worker whose 24h JWT expires with no way to renew it.
type NATSCredentialManager struct {
register RegisterFunc
requireCreds bool // block until credentials are present (frontend minting in use)
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
initialBackoff time.Duration
maxBackoff time.Duration
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
refreshRetry time.Duration
expiryOf func(jwt string) (time.Time, bool)
mu sync.RWMutex
jwt string
seed string
nodeID string
}
// NewNATSCredentialManager builds a manager over register. When requireCreds is
// true, Acquire blocks until the node is approved and credentials are minted.
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
return &NATSCredentialManager{
register: register,
requireCreds: requireCreds,
initialBackoff: 2 * time.Second,
maxBackoff: 30 * time.Second,
maxAttempts: defaultMaxAttempts,
refreshLead: 0.75,
refreshRetry: 30 * time.Second,
expiryOf: jwtExpiry,
}
}
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
func jwtExpiry(token string) (time.Time, bool) {
if token == "" {
return time.Time{}, false
}
uc, err := natsauth.DecodeUserClaims(token)
if err != nil || uc.Expires == 0 {
return time.Time{}, false
}
return time.Unix(uc.Expires, 0), true
}
func (m *NATSCredentialManager) store(res *RegisterResponse) {
m.mu.Lock()
defer m.mu.Unlock()
m.nodeID = res.ID
if res.NatsJWT != "" && res.NatsUserSeed != "" {
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
}
}
// Current returns the latest NATS credentials (both empty until acquired).
func (m *NATSCredentialManager) Current() (jwt, seed string) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.jwt, m.seed
}
// NodeID returns the node ID from the most recent registration.
func (m *NATSCredentialManager) NodeID() string {
m.mu.RLock()
defer m.mu.RUnlock()
return m.nodeID
}
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
// supplying the current credentials on each (re)connect.
func (m *NATSCredentialManager) Provider() func() (string, string) {
return m.Current
}
// HasCredentials reports whether complete NATS credentials have been obtained.
func (m *NATSCredentialManager) HasCredentials() bool {
jwt, seed := m.Current()
return jwt != "" && seed != ""
}
// Acquire registers and, when requireCreds is set, keeps re-registering with
// exponential backoff until the node is approved (status != pending) and
// credentials are minted. Without requireCreds it returns the first successful
// response (the historical one-shot behavior, preserved for anonymous NATS).
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
backoff := m.initialBackoff
var lastReason error
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
res, err := m.register(ctx)
switch {
case err != nil:
lastReason = err
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
case !m.requireCreds:
m.store(res)
return res, nil
case res.Status == statusPending:
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
case res.NatsJWT == "" || res.NatsUserSeed == "":
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
default:
m.store(res)
return res, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(backoff):
}
backoff = min(backoff*2, m.maxBackoff)
}
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
}
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
// updating the credentials returned by Current/Provider so the NATS connection
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
// when the current credential has no expiry (nothing to refresh), and a non-nil
// error after maxAttempts consecutive refresh failures — letting the caller
// exit the worker so it restarts and re-acquires (or surfaces the outage)
// rather than silently drifting toward an expired, unrenewable JWT.
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
failures := 0
for {
jwt, _ := m.Current()
exp, ok := m.expiryOf(jwt)
if !ok {
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
return nil
}
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
select {
case <-ctx.Done():
return nil
case <-time.After(wait):
}
res, err := m.register(ctx)
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
m.store(res)
failures = 0
xlog.Info("Refreshed NATS credentials", "node", res.ID)
continue
}
failures++
if err != nil {
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
} else {
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
}
if m.maxAttempts > 0 && failures >= m.maxAttempts {
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
}
// Back off before retrying so a persistent failure near expiry does not spin.
select {
case <-ctx.Done():
return nil
case <-time.After(m.refreshRetry):
}
}
}

View File

@@ -0,0 +1,198 @@
package workerregistry
import (
"context"
"sync"
"testing"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestWorkerRegistry(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "WorkerRegistry")
}
// fakeRegister returns a sequence of canned responses/errors, one per call, and
// records how many times it was invoked. The last entry repeats once exhausted.
type fakeRegister struct {
mu sync.Mutex
steps []step
calls int
}
type step struct {
res *RegisterResponse
err error
}
func (f *fakeRegister) fn() RegisterFunc {
return func(context.Context) (*RegisterResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
i := f.calls
f.calls++
if i >= len(f.steps) {
i = len(f.steps) - 1
}
return f.steps[i].res, f.steps[i].err
}
}
func (f *fakeRegister) count() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls
}
var _ = Describe("NATSCredentialManager", func() {
approved := func(jwt, seed string) *RegisterResponse {
return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
}
pending := &RegisterResponse{ID: "node-1", Status: "pending"}
Describe("Acquire (#4 — wait through admin approval)", func() {
It("keeps re-registering until the node is approved and credentials are minted", func() {
f := &fakeRegister{steps: []step{
{res: pending}, // not approved yet
{res: approved("", "")}, // approved but JWT not minted yet
{res: approved("jwt-1", "seed-1")}, // finally minted
}}
m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.ID).To(Equal("node-1"))
Expect(f.count()).To(Equal(3))
jwt, seed := m.Current()
Expect(jwt).To(Equal("jwt-1"))
Expect(seed).To(Equal("seed-1"))
Expect(m.HasCredentials()).To(BeTrue())
Expect(m.NodeID()).To(Equal("node-1"))
})
It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
res, err := m.Acquire(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(res.Status).To(Equal("pending"))
Expect(f.count()).To(Equal(1))
Expect(m.HasCredentials()).To(BeFalse())
})
It("aborts when the context is cancelled while waiting for approval", func() {
f := &fakeRegister{steps: []step{{res: pending}}}
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = 10 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.Acquire(ctx)
Expect(err).To(MatchError(context.Canceled))
})
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
m := NewNATSCredentialManager(f.fn(), true)
m.initialBackoff = time.Millisecond
m.maxBackoff = time.Millisecond
m.maxAttempts = 5
_, err := m.Acquire(context.Background())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
Expect(f.count()).To(Equal(5))
})
})
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
It("re-registers before expiry and updates the credentials served to new connections", func() {
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
m.expiryOf = func(jwt string) (time.Time, bool) {
switch jwt {
case "jwt-1":
return time.Now().Add(40 * time.Millisecond), true
case "jwt-2":
return time.Now().Add(time.Hour), true
default:
return time.Time{}, false
}
}
m.store(approved("jwt-1", "seed-1"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = m.RefreshLoop(ctx) }()
Eventually(func() string {
jwt, _ := m.Current()
return jwt
}, "2s", "10ms").Should(Equal("jwt-2"))
})
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
m := NewNATSCredentialManager(f.fn(), true)
m.refreshLead = 0.5
m.refreshRetry = time.Millisecond
m.maxAttempts = 3
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
m.store(approved("jwt-1", "seed-1"))
errCh := make(chan error, 1)
go func() { errCh <- m.RefreshLoop(context.Background()) }()
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
})
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
m := NewNATSCredentialManager(f.fn(), true)
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
m.store(approved("static", "seed"))
done := make(chan struct{})
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
Eventually(done, "1s").Should(BeClosed())
Expect(f.count()).To(Equal(0)) // never tried to re-register
})
})
Describe("jwtExpiry default", func() {
It("decodes the expiry of a real minted worker JWT", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
Expect(err).ToNot(HaveOccurred())
exp, ok := jwtExpiry(token)
Expect(ok).To(BeTrue())
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
})
It("reports no expiry for an empty or undecodable token", func() {
_, ok := jwtExpiry("")
Expect(ok).To(BeFalse())
_, ok = jwtExpiry("not-a-jwt")
Expect(ok).To(BeFalse())
})
})
})
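The `jwtExpiry default` specs above require that the expiry of a minted worker JWT can be read back, and that an empty or undecodable token reports no expiry. A minimal standard-library sketch of that contract follows. It assumes the token is a three-part JWT whose payload carries a numeric `exp` claim in Unix seconds. The names `expiryFromToken` and `unsignedToken` are hypothetical, no signature is verified, and the real `jwtExpiry` may decode through the NATS JWT library instead.

```go
package main

import (
	"encoding/base64"
	"encoding/json"
	"strings"
	"time"
)

// expiryFromToken is a hypothetical helper: it reads the "exp" claim (Unix
// seconds) from a JWT payload. It does NOT verify the signature.
func expiryFromToken(token string) (time.Time, bool) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return time.Time{}, false // empty or not token-shaped
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return time.Time{}, false
	}
	var claims struct {
		Exp int64 `json:"exp"`
	}
	if err := json.Unmarshal(payload, &claims); err != nil || claims.Exp == 0 {
		return time.Time{}, false // undecodable, or no expiry set
	}
	return time.Unix(claims.Exp, 0), true
}

// unsignedToken builds a token-shaped string for illustration only; the
// header and signature segments are placeholders, not real JWT material.
func unsignedToken(exp int64) string {
	payload, _ := json.Marshal(map[string]int64{"exp": exp})
	return "e30." + base64.RawURLEncoding.EncodeToString(payload) + ".sig"
}
```

Returning `false` for a missing `exp` matches the third `RefreshLoop` spec, where a credential without an expiry means there is nothing to refresh.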


@@ -5,6 +5,8 @@ import (
"fmt"
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/mudler/xlog"
)
@@ -18,6 +20,16 @@ type DistributedConfig struct {
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
// S3 configuration (used when StorageURL is set)
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
StorageRegion string // --storage-region / LOCALAI_STORAGE_REGION
@@ -80,6 +92,13 @@ func (c DistributedConfig) Validate() error {
if c.RegistrationToken == "" {
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
}
if err := c.NatsAuthConfig().Validate(); err != nil {
return err
}
if err := c.NatsTLSFiles().Validate(); err != nil {
return err
}
c.NatsAuthConfig().WarnIfInsecure(true)
// Check for negative durations
for name, d := range map[string]time.Duration{
FlagMCPToolTimeout: c.MCPToolTimeout,
@@ -123,6 +142,52 @@ func WithRegistrationToken(token string) AppOption {
}
}
func WithNatsAccountSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsAccountSeed = seed
}
}
func WithNatsServiceJWT(jwt string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceJWT = jwt
}
}
func WithNatsServiceSeed(seed string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsServiceSeed = seed
}
}
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsWorkerJWTTTL = d
}
}
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
o.Distributed.NatsRequireAuth = true
}
func WithNatsTLSCA(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCA = path
}
}
func WithNatsTLSCert(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSCert = path
}
}
func WithNatsTLSKey(path string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.NatsTLSKey = path
}
}
func WithStorageURL(url string) AppOption {
return func(o *ApplicationConfig) {
o.Distributed.StorageURL = url
@@ -217,6 +282,44 @@ const (
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
const DefaultMaxUploadSize int64 = 50 << 30
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
return messaging.TLSFiles{
CA: c.NatsTLSCA,
Cert: c.NatsTLSCert,
Key: c.NatsTLSKey,
}
}
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
// Explicit userJWT/userSeed (e.g. worker overrides) take precedence; when both are empty, the service JWT and seed from config are used.
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
var opts []messaging.Option
jwt, seed := userJWT, userSeed
if jwt == "" && seed == "" {
auth := c.NatsAuthConfig()
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
}
if jwt != "" && seed != "" {
opts = append(opts, messaging.WithUserJWT(jwt, seed))
}
if tls := c.NatsTLSFiles(); tls.Enabled() {
opts = append(opts, messaging.WithTLS(tls))
}
return opts
}
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
return natsauth.Config{
AccountSeed: c.NatsAccountSeed,
ServiceUserJWT: c.NatsServiceJWT,
ServiceUserSeed: c.NatsServiceSeed,
WorkerJWTTTL: c.NatsWorkerJWTTTL,
RequireAuth: c.NatsRequireAuth,
}
}
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)


@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
remoteUnloader = d.Router.Unloader()
}
}
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
natsCfg := distCfg.NatsAuthConfig()
routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
// Distributed SSE routes (job progress + agent events via NATS)
if d := application.Distributed(); d != nil {


@@ -28,6 +28,7 @@ import (
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/pkg/httpclient"
"github.com/mudler/LocalAI/pkg/natsauth"
)
// nodeError builds a schema.ErrorResponse for node endpoints.
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
// RegisterNodeEndpoint registers a new backend node.
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
return func(c echo.Context) error {
var req RegisterNodeRequest
if err := c.Bind(&req); err != nil {
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusCreated, response)
}
}
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
// For agent workers, it also provisions an API key so they can call the inference API.
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
id := c.Param("id")
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
}
}
attachNatsJWT(response, node, natsCfg)
return c.JSON(http.StatusOK, response)
}
}
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
return
}
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
if err != nil {
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
return
}
response["nats_jwt"] = jwt
response["nats_user_seed"] = seed
}
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
// Returns the plaintext API key on success.
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
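On the worker side, the two fields that `attachNatsJWT` adds can be read back from the register or approve response. The sketch below assumes only the `nats_jwt` and `nats_user_seed` keys shown in the diff. The `registerResponse` type and `natsCredsFromResponse` helper are hypothetical, and the rest of the response body is ignored.

```go
package main

import "encoding/json"

// registerResponse models only the two fields attachNatsJWT adds; any other
// fields in the real response are ignored by this sketch.
type registerResponse struct {
	NatsJWT      string `json:"nats_jwt"`
	NatsUserSeed string `json:"nats_user_seed"`
}

// natsCredsFromResponse is a hypothetical worker-side helper. The bool is
// false when the response carries no complete credential pair, for example
// while the node is still pending or when minting is disabled.
func natsCredsFromResponse(body []byte) (string, string, bool, error) {
	var r registerResponse
	if err := json.Unmarshal(body, &r); err != nil {
		return "", "", false, err
	}
	if r.NatsJWT == "" || r.NatsUserSeed == "" {
		return "", "", false, nil
	}
	return r.NatsJWT, r.NatsUserSeed, true, nil
}
```

Because `attachNatsJWT` returns early for pending nodes, a worker that is awaiting approval should expect the "no pair" case and register again once approved.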


@@ -12,6 +12,8 @@ import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/core/services/testutil"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
})
It("returns nats_jwt when account seed is configured", func() {
akp, err := nkeys.CreateAccount()
Expect(err).ToNot(HaveOccurred())
seed, err := akp.Seed()
Expect(err).ToNot(HaveOccurred())
e := echo.New()
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
natsCfg := natsauth.Config{AccountSeed: string(seed)}
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
var resp map[string]any
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
})
It("returns 400 when name is missing", func() {
e := echo.New()
body := `{"address":"10.0.0.1:50051"}`
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusBadRequest))
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
})
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := RegisterNodeEndpoint(registry, "", false, nil, "")
handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
Expect(handler(c)).To(Succeed())
Expect(rec.Code).To(Equal(http.StatusCreated))
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec1 := httptest.NewRecorder()
handler := RegisterNodeEndpoint(registry, "", true, nil, "")
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
Expect(rec1.Code).To(Equal(http.StatusCreated))


@@ -10,6 +10,7 @@ import (
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/services/galleryop"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/pkg/natsauth"
"gorm.io/gorm"
)
@@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc {
// token but do not verify per-node identity. A compromised worker can heartbeat/drain/
// deregister other nodes. Future: issue per-node JWT at registration, validate node
// identity on subsequent requests (compare :id param with token subject).
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) {
func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) {
if registry == nil {
return
}
@@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
tokenAuthMw := nodeTokenAuth(registrationToken)
node := e.Group("/api/node", readyMw, tokenAuthMw)
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret))
node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg))
node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry))
node.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
@@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
// backend install path (POST /:id/backends/install). That handler enqueues a
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) {
if registry == nil {
return
}
@@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret))
admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
// Backend management on workers
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))


@@ -2,15 +2,22 @@ package messaging
import (
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/mudler/LocalAI/pkg/sanitize"
"github.com/mudler/xlog"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
)
// subscribeConfirmTimeout bounds the server round-trip used to detect whether a
// subscription was rejected (e.g. by JWT permissions) before returning to the caller.
const subscribeConfirmTimeout = 5 * time.Second
// Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions.
type Client struct {
conn *nats.Conn
@@ -18,8 +25,13 @@ type Client struct {
}
// New creates a new NATS client with auto-reconnect.
func New(url string) (*Client, error) {
nc, err := nats.Connect(url,
func New(url string, opts ...Option) (*Client, error) {
var cfg connectConfig
for _, o := range opts {
o(&cfg)
}
natsOpts := []nats.Option{
nats.RetryOnFailedConnect(true),
nats.MaxReconnects(-1),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
@@ -33,7 +45,60 @@ func New(url string) (*Client, error) {
nats.ClosedHandler(func(_ *nats.Conn) {
xlog.Info("NATS connection closed")
}),
)
// Surface async errors (notably permission violations) that NATS would
// otherwise deliver silently. A subscription the server rejects for a
// JWT permission means the worker never receives those messages, so make
// it loud rather than letting the feature fail invisibly.
nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) {
subject := ""
if sub != nil {
subject = sub.Subject
}
if errors.Is(err, nats.ErrPermissionViolation) {
xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err)
return
}
xlog.Warn("NATS async error", "subject", subject, "error", err)
}),
}
switch {
case cfg.jwtProvider != nil:
// Fetch creds on every (re)connect so a refresh loop can rotate the JWT
// before expiry; the server expiring the old JWT triggers a reconnect
// that transparently picks up the new one.
natsOpts = append(natsOpts, nats.UserJWT(
func() (string, error) {
jwt, _ := cfg.jwtProvider()
if jwt == "" {
return "", fmt.Errorf("no NATS user JWT available")
}
return jwt, nil
},
func(nonce []byte) ([]byte, error) {
_, seed := cfg.jwtProvider()
kp, err := nkeys.FromSeed([]byte(seed))
if err != nil {
return nil, fmt.Errorf("loading NATS user seed: %w", err)
}
defer kp.Wipe()
return kp.Sign(nonce)
},
))
case cfg.userJWT != "" && cfg.userSeed != "":
natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed))
}
if cfg.tls.Enabled() {
if err := cfg.tls.Validate(); err != nil {
return nil, err
}
tlsOpts, err := cfg.tls.natsOptions()
if err != nil {
return nil, err
}
natsOpts = append(natsOpts, tlsOpts...)
}
nc, err := nats.Connect(url, natsOpts...)
if err != nil {
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
}
@@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error {
// Subscribe creates a subscription on the given subject. All subscribers receive every message.
func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data)
})
})
}
// QueueSubscribe creates a queue subscription. Within the same queue group,
// only one subscriber receives each message (load-balanced).
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data)
})
})
}
// confirmSubscription creates a subscription via mk and forces a server
// round-trip so that a permissions violation — which NATS otherwise reports
// only asynchronously — is returned to the caller synchronously. The server
// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG
// that satisfies the flush, so by the time FlushTimeout returns the violation
// is recorded as the connection's last error. Without this, a worker whose JWT
// lacks a subject gets a non-nil subscription that never receives a message,
// turning a permission misconfiguration into a silent failure.
func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) {
c.mu.RLock()
conn := c.conn
c.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject)
}
sub, err := mk(conn)
if err != nil {
return nil, err
}
// A failed flush here means we could not round-trip to the server (not yet
// connected, reconnecting, slow link). RetryOnFailedConnect intentionally
// buffers subscriptions across that gap, so do NOT fail — keep the
// subscription and let it replay on (re)connect; a later permission
// violation is still logged by the async error handler in New.
if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil {
xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err)
return sub, nil
}
// Flush succeeded, so any permission violation for this SUB has already been
// recorded as the connection's last error (the server emits it before the
// PONG). LastError is per-connection; match the exact quoted subject the
// server echoes ("Subscription to \"<subject>\"") so a stale violation for
// another subject can't be mis-attributed here.
if lerr := conn.LastError(); lerr != nil &&
errors.Is(lerr, nats.ErrPermissionViolation) &&
strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) {
_ = sub.Unsubscribe()
return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr)
}
return sub, nil
}
// Request sends a request and waits for a reply (request-reply pattern).
// Returns the raw reply data.
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
@@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]
// SubscribeReply creates a subscription that supports replying to requests.
// The handler receives the raw request data and the reply subject.
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.Subscribe(subject, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
}
}
}
})
})
})
}
@@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply
// QueueSubscribeReply creates a queue subscription that supports replying to requests.
// Load-balanced across subscribers in the same queue group, with request-reply support.
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
handler(msg.Data, func(replyData []byte) {
if msg.Reply != "" {
if err := msg.Respond(replyData); err != nil {
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
}
}
}
})
})
})
}
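`confirmSubscription` attributes a permission violation to a subscription by matching the quoted subject in the connection's last error. The sketch below isolates that string check to show why the closing quote matters. The helper name `deniedFor` is hypothetical, and the sample message text in the usage note is illustrative rather than a verbatim server error.

```go
package main

import "strings"

// deniedFor reports whether a permission-violation message names exactly the
// given subject. The closing quote in the needle prevents a prefix match,
// e.g. "jobs.new" must not match a violation recorded for "jobs.new.urgent".
func deniedFor(errText, subject string) bool {
	return strings.Contains(errText, `Subscription to "`+subject+`"`)
}
```

For a message such as `Permissions Violation for Subscription to "jobs.new.urgent"`, the check is true for `jobs.new.urgent` and false for `jobs.new`, so a stale violation left on the connection by another subject is not blamed on the current one.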


@@ -0,0 +1,34 @@
package messaging
// Option configures NATS client connection behavior.
type Option func(*connectConfig)
// CredentialProvider returns the NATS user JWT and signing seed to use for the
// next (re)connect. It is consulted on every connection attempt, so a refresh
// loop can rotate credentials before they expire and the connection picks them
// up automatically when the server expires the old JWT and triggers a reconnect.
type CredentialProvider func() (jwt, seed string)
type connectConfig struct {
userJWT string
userSeed string
jwtProvider CredentialProvider
tls TLSFiles
}
// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed).
func WithUserJWT(jwt, seed string) Option {
return func(c *connectConfig) {
c.userJWT = jwt
c.userSeed = seed
}
}
// WithUserJWTProvider connects using credentials fetched from provider on each
// (re)connect, enabling JWT rotation without dropping the client. Takes
// precedence over WithUserJWT when both are set.
func WithUserJWTProvider(provider CredentialProvider) Option {
return func(c *connectConfig) {
c.jwtProvider = provider
}
}
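The precedence described for `WithUserJWTProvider` can be shown without a NATS connection. The sketch below copies the option shapes from the diff and adds one hypothetical helper, `credsForConnect`, that returns the pair a single connection attempt would use. It mirrors the `switch` in `New` and leaves out TLS and the actual connection.

```go
package main

// CredentialProvider, connectConfig and Option mirror the shapes in the diff.
type CredentialProvider func() (jwt, seed string)

type connectConfig struct {
	userJWT, userSeed string
	jwtProvider       CredentialProvider
}

type Option func(*connectConfig)

func WithUserJWT(jwt, seed string) Option {
	return func(c *connectConfig) { c.userJWT, c.userSeed = jwt, seed }
}

func WithUserJWTProvider(p CredentialProvider) Option {
	return func(c *connectConfig) { c.jwtProvider = p }
}

// credsForConnect is hypothetical: it returns the credentials one connection
// attempt would use. The provider is consulted first and on every call, so a
// rotated JWT is picked up by the next (re)connect.
func credsForConnect(opts ...Option) (string, string) {
	var cfg connectConfig
	for _, o := range opts {
		o(&cfg)
	}
	if cfg.jwtProvider != nil {
		return cfg.jwtProvider()
	}
	if cfg.userJWT != "" && cfg.userSeed != "" {
		return cfg.userJWT, cfg.userSeed
	}
	return "", "" // no usable static pair: connect without JWT auth
}
```

Calling the provider per attempt, rather than capturing its result once, is what lets a refresh loop swap credentials without rebuilding the client.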


@@ -0,0 +1,68 @@
package messaging
import (
"fmt"
"os"
"github.com/nats-io/nats.go"
)
// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together.
// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras.
type TLSFiles struct {
CA string // LOCALAI_NATS_TLS_CA — private CA for server verification
Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS)
Key string // LOCALAI_NATS_TLS_KEY — client private key
}
// Enabled reports whether any TLS file path is configured.
func (f TLSFiles) Enabled() bool {
return f.CA != "" || f.Cert != "" || f.Key != ""
}
// Validate checks path pairing and that files exist.
func (f TLSFiles) Validate() error {
if f.Cert != "" && f.Key == "" {
return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set")
}
if f.Key != "" && f.Cert == "" {
return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set")
}
for _, path := range []struct {
name, path string
}{
{"LOCALAI_NATS_TLS_CA", f.CA},
{"LOCALAI_NATS_TLS_CERT", f.Cert},
{"LOCALAI_NATS_TLS_KEY", f.Key},
} {
if path.path == "" {
continue
}
if _, err := os.Stat(path.path); err != nil {
return fmt.Errorf("%s: %w", path.name, err)
}
}
return nil
}
// natsOptions builds nats-go TLS options. Call Validate first.
func (f TLSFiles) natsOptions() ([]nats.Option, error) {
if !f.Enabled() {
return nil, nil
}
opts := []nats.Option{nats.Secure()}
if f.CA != "" {
opts = append(opts, nats.RootCAs(f.CA))
}
if f.Cert != "" {
opts = append(opts, nats.ClientCert(f.Cert, f.Key))
}
return opts, nil
}
// WithTLS configures CA and/or client certificate paths for the NATS connection.
func WithTLS(files TLSFiles) Option {
return func(c *connectConfig) {
c.tls = files
}
}
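For readers who want to see what the three PEM paths amount to at the TLS layer, here is a rough standard-library counterpart. It is a sketch under assumptions, not the code path the client takes: `tlsConfigFromFiles` is a hypothetical name, and the nats-go options used above build their own configuration internally.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
)

// tlsConfigFromFiles is a hypothetical stdlib counterpart to the nats-go
// options above. Empty paths are skipped; cert and key must be paired.
func tlsConfigFromFiles(caPath, certPath, keyPath string) (*tls.Config, error) {
	if (certPath == "") != (keyPath == "") {
		return nil, errors.New("client cert and key must be set together")
	}
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}
	if caPath != "" {
		pem, err := os.ReadFile(caPath)
		if err != nil {
			return nil, fmt.Errorf("reading CA: %w", err)
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(pem) {
			return nil, fmt.Errorf("no PEM certificates found in %s", caPath)
		}
		cfg.RootCAs = pool // verify the server against the private CA
	}
	if certPath != "" {
		pair, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			return nil, fmt.Errorf("loading client key pair: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair} // present for mTLS
	}
	return cfg, nil
}
```

Note the difference from `TLSFiles.Validate`, which only checks pairing and that each file exists: parsing the PEM contents, as this sketch does, would also reject a file that exists but holds no certificate.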


@@ -0,0 +1,25 @@
package messaging_test
import (
"os"
"path/filepath"
"github.com/mudler/LocalAI/core/services/messaging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TLSFiles", func() {
It("requires cert and key together", func() {
Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred())
Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred())
})
It("validates files exist", func() {
dir := GinkgoT().TempDir()
ca := filepath.Join(dir, "ca.pem")
Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed())
Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed())
})
})


@@ -60,7 +60,13 @@ type Config struct {
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
// NATS (required)
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"`
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"`
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"`
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
// S3 storage for distributed file transfer
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`


@@ -0,0 +1,33 @@
package worker
import (
"fmt"
"github.com/mudler/LocalAI/core/services/messaging"
)
// connectNATS opens a NATS client using JWT+seed from env or registration (env wins).
func connectNATS(url, envJWT, envSeed, registerJWT, registerSeed string, requireAuth bool, tls messaging.TLSFiles) (*messaging.Client, error) {
// Env credentials take precedence, but only fall back to registration when
// the env supplied neither half — otherwise a JWT set without its seed (or
// vice-versa) would be silently completed from a different source.
jwt, seed := envJWT, envSeed
if jwt == "" && seed == "" {
jwt, seed = registerJWT, registerSeed
}
// A JWT without its paired seed (or vice-versa) is a misconfiguration: refuse
// rather than silently connecting anonymously, which would look authenticated.
if (jwt == "") != (seed == "") {
return nil, fmt.Errorf("NATS JWT and seed must be provided together (got JWT set=%t, seed set=%t)", jwt != "", seed != "")
}
var opts []messaging.Option
if jwt != "" && seed != "" {
opts = append(opts, messaging.WithUserJWT(jwt, seed))
} else if requireAuth {
return nil, fmt.Errorf("NATS JWT+seed required: set LOCALAI_NATS_JWT/LOCALAI_NATS_USER_SEED or enable frontend minting")
}
if tls.Enabled() {
opts = append(opts, messaging.WithTLS(tls))
}
return messaging.New(url, opts...)
}
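The credential selection in `connectNATS` can be read as a small decision table. The sketch below isolates it using only the standard library. `resolveCreds` is a hypothetical name introduced for illustration and is not part of the change; it assumes the same rules as the function above (env wins, no mixing of sources, unpaired halves are rejected).

```go
package main

import (
	"errors"
	"fmt"
)

// resolveCreds is a hypothetical stand-in for the selection logic in connectNATS.
// It returns the JWT/seed pair to use, or an error when the inputs are inconsistent.
func resolveCreds(envJWT, envSeed, regJWT, regSeed string, requireAuth bool) (jwt, seed string, err error) {
	jwt, seed = envJWT, envSeed
	// Fall back to registration only when the env supplied neither half.
	if jwt == "" && seed == "" {
		jwt, seed = regJWT, regSeed
	}
	// Exactly one half present is a misconfiguration.
	if (jwt == "") != (seed == "") {
		return "", "", errors.New("JWT and seed must be provided together")
	}
	// Nothing supplied at all: only acceptable when auth is optional.
	if jwt == "" && requireAuth {
		return "", "", errors.New("credentials required but none supplied")
	}
	return jwt, seed, nil
}

func main() {
	// Registration credentials are used when the env supplies nothing.
	jwt, seed, err := resolveCreds("", "", "reg-jwt", "reg-seed", true)
	fmt.Println(jwt, seed, err)

	// An env JWT without its seed is rejected rather than completed from registration.
	_, _, err = resolveCreds("env-jwt", "", "reg-jwt", "reg-seed", false)
	fmt.Println(err)
}
```

Keeping the pairing check separate from the `requireAuth` check means a half-configured credential fails even on a bus that still accepts anonymous clients.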


@@ -0,0 +1,29 @@
package worker
import (
"github.com/mudler/LocalAI/core/services/messaging"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("connectNATS", func() {
It("requires JWT when requireAuth is set and no credentials are provided", func() {
_, err := connectNATS("nats://127.0.0.1:4222", "", "", "", "", true, messaging.TLSFiles{})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("NATS JWT+seed required"))
})
// A JWT supplied without its paired seed (or vice-versa) is an operator
// misconfiguration. connectNATS must refuse it: silently dropping the unpaired
// credential and connecting anonymously would leave the operator believing the
// link is authenticated when it is not.
It("rejects a JWT supplied without a seed instead of connecting anonymously", func() {
client, err := connectNATS("nats://127.0.0.1:4222", "jwt-without-seed", "", "", "", false, messaging.TLSFiles{})
if client != nil {
client.Close()
}
Expect(err).To(HaveOccurred(),
"connectNATS should reject an unpaired JWT rather than silently connecting anonymously")
})
})


@@ -15,6 +15,7 @@ import (
"github.com/mudler/LocalAI/core/cli/workerregistry"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
grpc "github.com/mudler/LocalAI/pkg/grpc"
@@ -67,10 +68,63 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
RegistrationToken: cfg.RegistrationToken,
}
// Context cancelled on shutdown — used by registration waits, heartbeat, and
// other background goroutines.
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
defer shutdownCancel()
registrationBody := cfg.registrationBody()
-nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
-if err != nil {
-return fmt.Errorf("failed to register with frontend: %w", err)
natsTLS := messaging.TLSFiles{CA: cfg.NatsTLSCA, Cert: cfg.NatsTLSCert, Key: cfg.NatsTLSKey}
// Resolve how to connect to NATS. Static env credentials cannot be re-minted,
// so register once and use them directly. Otherwise the credential manager
// (re)registers to obtain credentials — waiting through admin approval — and
// refreshes them before the minted JWT expires, so the connection survives
// expiry via a transparent reconnect.
var (
nodeID string
connectNats func() (*messaging.Client, error)
)
if cfg.NatsJWT != "" || cfg.NatsUserSeed != "" {
nid, _, _, _, regErr := regClient.RegisterWithRetry(shutdownCtx, registrationBody, 10)
if regErr != nil {
return fmt.Errorf("failed to register with frontend: %w", regErr)
}
nodeID = nid
connectNats = func() (*messaging.Client, error) {
return connectNATS(cfg.NatsURL, cfg.NatsJWT, cfg.NatsUserSeed, "", "", cfg.NatsRequireAuth, natsTLS)
}
} else {
credMgr := workerregistry.NewNATSCredentialManager(
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
return regClient.RegisterFull(ctx, registrationBody)
},
cfg.NatsRequireAuth,
)
res, regErr := credMgr.Acquire(shutdownCtx)
if regErr != nil {
return fmt.Errorf("failed to register with frontend: %w", regErr)
}
nodeID = res.ID
connectNats = func() (*messaging.Client, error) {
var opts []messaging.Option
if credMgr.HasCredentials() {
opts = append(opts, messaging.WithUserJWTProvider(credMgr.Provider()))
}
if natsTLS.Enabled() {
opts = append(opts, messaging.WithTLS(natsTLS))
}
client, cerr := messaging.New(cfg.NatsURL, opts...)
if cerr == nil && credMgr.HasCredentials() {
go func() {
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
xlog.Error("NATS credential refresh permanently failed; shutting down worker", "error", err)
shutdownCancel()
}
}()
}
return client, cerr
}
}
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cfg.RegisterTo)
@@ -79,9 +133,6 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cfg.HeartbeatInterval, "error", err)
}
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
-// Context cancelled on shutdown — used by heartbeat and other background goroutines
-shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
-defer shutdownCancel()
// Start HTTP file transfer server
httpAddr := cfg.resolveHTTPAddr()
@@ -94,7 +145,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
// Connect to NATS
xlog.Info("Connecting to NATS", "url", sanitize.URL(cfg.NatsURL))
-natsClient, err := messaging.New(cfg.NatsURL)
+natsClient, err := connectNats()
if err != nil {
nodes.ShutdownFileTransferServer(httpServer)
return fmt.Errorf("connecting to NATS: %w", err)
@@ -154,12 +205,21 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
}
xlog.Info("Worker ready, waiting for backend.install events")
-<-sigCh
// Exit on an OS signal or on an internal fatal condition (e.g. NATS
// credentials became unrenewable), so the worker restarts and re-acquires
// rather than lingering unable to serve.
var runErr error
select {
case <-sigCh:
case <-shutdownCtx.Done():
runErr = fmt.Errorf("worker shutting down: NATS credentials unavailable")
xlog.Error("Internal shutdown requested", "error", runErr)
}
xlog.Info("Shutting down worker")
shutdownCancel() // stop heartbeat loop immediately
regClient.GracefulDeregister(nodeID)
supervisor.stopAllBackends()
nodes.ShutdownFileTransferServer(httpServer)
-return nil
+return runErr
}


@@ -71,6 +71,50 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). |
| `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Coverage spans the OpenAI-compatible endpoints (chat completions, completions, embeddings, audio transcriptions, audio speech / TTS, image generations, image inpainting), the Jina rerank endpoint (`/v1/rerank`), the VAD endpoints (`/v1/vad`, `/vad`), and the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. |
### NATS JWT authentication (recommended for production)
By default, NATS connections are anonymous: any client that can reach port `4222` may publish control-plane subjects such as `nodes.<id>.backend.install`. Enable JWT auth to scope workers to their own node subjects and give the frontend a dedicated service credential.
| Flag | Env Var | Description |
|------|---------|-------------|
| `--nats-account-seed` | `LOCALAI_NATS_ACCOUNT_SEED` | Account signing seed (`SA...`). The frontend mints a per-node user JWT at registration (`nats_jwt` in the register response). |
| `--nats-service-jwt` | `LOCALAI_NATS_SERVICE_JWT` | User JWT for the frontend (and optional fallback for agent workers) to publish install/upgrade and related subjects. |
| `--nats-service-seed` | `LOCALAI_NATS_SERVICE_SEED` | User signing seed (`SU...`) paired with the service JWT. |
| `--nats-worker-jwt-ttl` | `LOCALAI_NATS_WORKER_JWT_TTL` | Lifetime of minted worker JWTs (default `24h`). |
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | Fail startup if JWT credentials are missing when distributed mode is enabled. |
### NATS TLS / mTLS (optional)
Use `tls://` in `--nats-url` / `LOCALAI_NATS_URL` for encrypted transport. When the server uses a private CA or requires client certificates, set:
| Flag | Env Var | Description |
|------|---------|-------------|
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | PEM file to verify the NATS server (private CA) |
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | Client certificate for NATS mTLS |
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | Client private key (required with `--nats-tls-cert`) |
The same env vars apply to backend workers and `local-ai agent-worker`. If the server cert is already trusted by the OS, `tls://` alone is enough.
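To illustrate how the three TLS settings typically combine, here is a minimal sketch built on Go's `crypto/tls` and `crypto/x509`. `buildTLSConfig` is a hypothetical helper and takes PEM bytes rather than file paths. How `messaging.WithTLS` builds its configuration is not shown in this change, so treat the details (including the minimum TLS version) as assumptions.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
)

// buildTLSConfig (hypothetical) assembles a client TLS config from optional PEM inputs.
// An empty caPEM leaves RootCAs nil, which makes Go use the OS trust store.
func buildTLSConfig(caPEM, certPEM, keyPEM []byte) (*tls.Config, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12} // assumed floor, not taken from the source

	if len(caPEM) > 0 {
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(caPEM) {
			return nil, errors.New("no certificates found in CA PEM")
		}
		cfg.RootCAs = pool
	}

	// mTLS needs both halves of the client key pair.
	if (len(certPEM) == 0) != (len(keyPEM) == 0) {
		return nil, errors.New("client certificate and key must be provided together")
	}
	if len(certPEM) > 0 {
		pair, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			return nil, fmt.Errorf("load client key pair: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair}
	}
	return cfg, nil
}

func main() {
	// No inputs: rely on the OS trust store, no client certificate.
	cfg, err := buildTLSConfig(nil, nil, nil)
	fmt.Println(cfg != nil, err)

	// A certificate without its key is rejected.
	_, err = buildTLSConfig(nil, []byte("cert-only"), nil)
	fmt.Println(err)
}
```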
**Worker register response** (when minting is enabled and the node is approved):
```json
{
"id": "…",
"nats_jwt": "eyJ…",
"nats_user_seed": "SU…"
}
```
Workers connect with that JWT and seed automatically (shown once; store securely). Override with `LOCALAI_NATS_JWT` / `LOCALAI_NATS_USER_SEED` if needed. Set `LOCALAI_NATS_REQUIRE_AUTH=true` on workers when the bus requires credentials.
When `LOCALAI_NATS_REQUIRE_AUTH=true` and no static credentials are provided, a worker that registers while still **pending admin approval** keeps re-registering (with backoff) until an admin approves it and the frontend mints its JWT — it does not start unauthenticated. This retry is **bounded**: if the node is never approved (or no credentials are minted) after a large number of attempts, the worker exits non-zero so the failure is visible (a crash-looping or failed worker) rather than hanging silently. Minted worker JWTs are also **refreshed automatically** before they expire (the worker re-registers at ~75% of the JWT lifetime), so long-running workers survive past `LOCALAI_NATS_WORKER_JWT_TTL`; the NATS connection picks up the new JWT on its next reconnect. If refresh fails persistently, the worker exits (to restart and re-acquire) rather than drifting toward an expired, unrenewable JWT. Statically configured (`LOCALAI_NATS_JWT`) and service (`LOCALAI_NATS_SERVICE_JWT`) credentials are used as-is and not refreshed.
Generate operator/account material with [`scripts/nats-auth-setup.sh`](https://github.com/mudler/LocalAI/blob/master/scripts/nats-auth-setup.sh) (requires [nsc](https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)). Configure the NATS server with account resolver JWTs before enabling `LOCALAI_NATS_REQUIRE_AUTH`.
{{% notice note %}}
`LOCALAI_AUTH` (HTTP users/sessions) and NATS JWTs are separate: end-user API keys do not connect to NATS. HTTP registration still uses `LOCALAI_REGISTRATION_TOKEN`.
{{% /notice %}}
### Optional: S3 Object Storage
For multi-host deployments where workers don't share a filesystem, S3-compatible storage enables distributed file transfer (model files, configs):
@@ -134,6 +178,12 @@ local-ai worker \
| `--registration-token` | `LOCALAI_REGISTRATION_TOKEN` | *(empty)* | Token to authenticate with the frontend |
| `--heartbeat-interval` | `LOCALAI_HEARTBEAT_INTERVAL` | `10s` | Interval between heartbeat pings |
| `--nats-url` | `LOCALAI_NATS_URL` | *(required)* | NATS URL for backend installation and file staging |
| `--nats-jwt` | `LOCALAI_NATS_JWT` | *(empty)* | Optional override for the `nats_jwt` returned at registration |
| `--nats-user-seed` | `LOCALAI_NATS_USER_SEED` | *(empty)* | Optional override for `nats_user_seed` from registration |
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | `false` | Require NATS JWT+seed (from registration or env) |
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | *(empty)* | PEM file for NATS server CA |
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | *(empty)* | Client certificate for NATS mTLS |
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | *(empty)* | Client private key for NATS mTLS |
| `--backends-path` | `LOCALAI_BACKENDS_PATH` | `./backends` | Path to backend binaries |
| `--models-path` | `LOCALAI_MODELS_PATH` | `./models` | Path to model files |

go.mod

@@ -41,7 +41,9 @@ require (
github.com/mudler/go-processmanager v0.1.1
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8
github.com/mudler/xlog v0.0.6
github.com/nats-io/jwt/v2 v2.7.4
github.com/nats-io/nats.go v1.52.0
github.com/nats-io/nkeys v0.4.15
github.com/ollama/ollama v0.20.4
github.com/onsi/ginkgo/v2 v2.29.0
github.com/onsi/gomega v1.41.0
@@ -134,7 +136,6 @@ require (
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/moby/moby/api v1.54.2 // indirect
github.com/moby/moby/client v0.4.1 // indirect
github.com/nats-io/nkeys v0.4.15 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/secure-systems-lab/go-securesystemslib v0.9.1 // indirect

go.sum

@@ -1016,6 +1016,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A=
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI=
github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA=
github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc=
github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=

pkg/natsauth/config.go Normal file

@@ -0,0 +1,66 @@
package natsauth
import (
"fmt"
"time"
"github.com/mudler/xlog"
)
// DefaultWorkerJWTTTL is how long a worker may use a minted NATS user JWT before re-registering.
const DefaultWorkerJWTTTL = 24 * time.Hour
// Config holds NATS JWT authentication settings for distributed mode.
type Config struct {
// AccountSeed is the NATS account signing seed (SA...). Used to mint per-node worker JWTs.
AccountSeed string
// ServiceUserJWT is a pre-generated user JWT for frontends and agent workers (broad publish).
ServiceUserJWT string
// ServiceUserSeed is the signing seed (SU...) paired with ServiceUserJWT.
ServiceUserSeed string
// WorkerJWTTTL sets expiry on minted worker JWTs. Zero uses DefaultWorkerJWTTTL.
WorkerJWTTTL time.Duration
// RequireAuth rejects anonymous NATS when true (Validate then expects ServiceUserJWT, ServiceUserSeed, and AccountSeed).
RequireAuth bool
}
// Enabled reports whether any NATS credential material is configured.
func (c Config) Enabled() bool {
return c.AccountSeed != "" || c.ServiceUserJWT != ""
}
// CanMintWorkers reports whether per-node JWTs can be issued at registration.
func (c Config) CanMintWorkers() bool {
return c.AccountSeed != ""
}
// WorkerTTL returns the configured worker JWT lifetime.
func (c Config) WorkerTTL() time.Duration {
if c.WorkerJWTTTL > 0 {
return c.WorkerJWTTTL
}
return DefaultWorkerJWTTTL
}
// Validate checks consistency when distributed NATS auth is required.
func (c Config) Validate() error {
if !c.RequireAuth {
return nil
}
if c.ServiceUserJWT == "" || c.ServiceUserSeed == "" {
return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
}
if c.AccountSeed == "" {
return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH is set but LOCALAI_NATS_ACCOUNT_SEED is empty")
}
return nil
}
// WarnIfInsecure logs when distributed NATS is reachable without credentials.
func (c Config) WarnIfInsecure(distributed bool) {
if !distributed || c.Enabled() {
return
}
xlog.Warn("NATS is used without JWT credentials — any client on the bus can publish backend.install. " +
"Set LOCALAI_NATS_ACCOUNT_SEED + LOCALAI_NATS_SERVICE_JWT (see docs/features/distributed-mode.md).")
}

pkg/natsauth/decode.go Normal file

@@ -0,0 +1,16 @@
package natsauth
import (
"fmt"
"github.com/nats-io/jwt/v2"
)
// DecodeUserClaims decodes a minted worker JWT for tests and diagnostics.
func DecodeUserClaims(token string) (*jwt.UserClaims, error) {
uc, err := jwt.DecodeUserClaims(token)
if err != nil {
return nil, fmt.Errorf("natsauth: decode user JWT: %w", err)
}
return uc, nil
}

pkg/natsauth/mint.go Normal file

@@ -0,0 +1,59 @@
package natsauth
import (
"fmt"
"time"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
)
// MintWorkerJWT creates a signed NATS user JWT and user seed scoped to nodeID and nodeType.
// The seed is returned once at registration so the worker can sign NATS connections.
func (c Config) MintWorkerJWT(nodeID, nodeType string) (userJWT, userSeed string, err error) {
if c.AccountSeed == "" {
return "", "", fmt.Errorf("natsauth: account seed not configured")
}
if nodeID == "" {
return "", "", fmt.Errorf("natsauth: node ID is required")
}
accountKP, err := nkeys.FromSeed([]byte(c.AccountSeed))
if err != nil {
return "", "", fmt.Errorf("natsauth: invalid account seed: %w", err)
}
userKP, err := nkeys.CreateUser()
if err != nil {
return "", "", fmt.Errorf("natsauth: create user key: %w", err)
}
seedBytes, err := userKP.Seed()
if err != nil {
return "", "", fmt.Errorf("natsauth: user seed: %w", err)
}
accountPub, err := accountKP.PublicKey()
if err != nil {
return "", "", fmt.Errorf("natsauth: account public key: %w", err)
}
userPub, err := userKP.PublicKey()
if err != nil {
return "", "", fmt.Errorf("natsauth: user public key: %w", err)
}
pubAllow, subAllow := WorkerPermissions(nodeID, nodeType)
uc := jwt.NewUserClaims(userPub)
uc.Name = fmt.Sprintf("localai-%s-%s", nodeType, workerSubjectToken(nodeID))
uc.IssuerAccount = accountPub
uc.Expires = time.Now().Add(c.WorkerTTL()).Unix()
uc.Permissions.Pub.Allow = pubAllow
uc.Permissions.Sub.Allow = subAllow
token, err := uc.Encode(accountKP)
if err != nil {
return "", "", fmt.Errorf("natsauth: encode user JWT: %w", err)
}
return token, string(seedBytes), nil
}

pkg/natsauth/mint_test.go Normal file

@@ -0,0 +1,60 @@
package natsauth_test
import (
"testing"
"time"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestNatsAuth(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "NatsAuth")
}
var _ = Describe("MintWorkerJWT", func() {
var accountSeed string
BeforeEach(func() {
akp, err := nkeys.CreateAccount()
Expect(err).NotTo(HaveOccurred())
seed, err := akp.Seed()
Expect(err).NotTo(HaveOccurred())
accountSeed = string(seed)
})
It("mints a JWT with backend worker permissions", func() {
cfg := natsauth.Config{AccountSeed: accountSeed, WorkerJWTTTL: time.Hour}
token, seed, err := cfg.MintWorkerJWT("550e8400-e29b-41d4-a716-446655440000", "backend")
Expect(err).NotTo(HaveOccurred())
Expect(token).NotTo(BeEmpty())
Expect(seed).NotTo(BeEmpty())
uc, err := jwt.DecodeUserClaims(token)
Expect(err).NotTo(HaveOccurred())
Expect(uc.Permissions.Sub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.>"))
Expect(uc.Permissions.Pub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.backend.install.*.progress"))
})
It("mints agent permissions without backend install subscribe", func() {
cfg := natsauth.Config{AccountSeed: accountSeed}
token, _, err := cfg.MintWorkerJWT("node-1", "agent")
Expect(err).NotTo(HaveOccurred())
uc, err := jwt.DecodeUserClaims(token)
Expect(err).NotTo(HaveOccurred())
Expect(uc.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
for _, subj := range uc.Permissions.Sub.Allow {
Expect(subj).NotTo(ContainSubstring("backend.install"))
}
})
It("rejects mint without account seed", func() {
_, _, err := (natsauth.Config{}).MintWorkerJWT("id", "backend")
Expect(err).To(HaveOccurred())
})
})


@@ -0,0 +1,49 @@
package natsauth
import "strings"
// workerSubjectToken mirrors messaging.sanitizeSubjectToken without importing unexported logic.
func workerSubjectToken(nodeID string) string {
r := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-")
return r.Replace(nodeID)
}
// WorkerPermissions returns NATS pub/sub allow lists for a registered node.
func WorkerPermissions(nodeID, nodeType string) (pubAllow, subAllow []string) {
tok := workerSubjectToken(nodeID)
prefix := "nodes." + tok
switch nodeType {
case "agent":
// Agent workers consume queue workloads; they must not handle backend.install.
// Keep this list in sync with the subscriptions in core/cli/agent_worker.go.
subAllow = []string{
"agent.execute",
"jobs.*.cancel",
"jobs.*.progress",
"jobs.*.result",
"jobs.mcp-ci.new", // MCP CI jobs dispatched to agent workers
"mcp.tools.execute",
"mcp.discovery",
prefix + ".backend.stop", // stop events drive MCP session cleanup
"_INBOX.>",
}
pubAllow = []string{
"agent.>",
"jobs.>",
"_INBOX.>",
}
default:
// Backend worker: lifecycle + file staging on this node only.
subAllow = []string{
prefix + ".>",
"_INBOX.>",
}
pubAllow = []string{
prefix + ".backend.install.*.progress",
prefix + ".files.>",
"_INBOX.>",
}
}
return pubAllow, subAllow
}


@@ -0,0 +1,134 @@
package natsauth_test
import (
"os"
"regexp"
"strings"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// subjectMatches implements NATS subject-token matching: "*" matches exactly one
// token and ">" matches one or more trailing tokens. It lets these tests assert
// that a permission allow-list (which uses wildcards) actually covers a concrete
// subject a component publishes/subscribes — the same check the NATS server makes.
func subjectMatches(pattern, subject string) bool {
p := strings.Split(pattern, ".")
s := strings.Split(subject, ".")
for i, tok := range p {
if tok == ">" {
return i < len(s) // ">" must match at least one remaining token
}
if i >= len(s) {
return false
}
if tok != "*" && tok != s[i] {
return false
}
}
return len(p) == len(s)
}
func anyAllows(allow []string, subject string) bool {
for _, p := range allow {
if subjectMatches(p, subject) {
return true
}
}
return false
}
var _ = Describe("WorkerPermissions subject coverage", func() {
// A node ID containing NATS-reserved characters exercises the (duplicated)
// sanitizer in pkg/natsauth against the canonical one in core/services/messaging.
// If the two ever diverge, the minted prefix stops matching the real subject
// and these assertions fail — guarding the copy noted in the review.
const nodeID = "host.a 1*b"
Context("backend worker", func() {
pub, sub := natsauth.WorkerPermissions(nodeID, "backend")
// Every subject core/services/worker/{lifecycle,file_staging}.go subscribes to.
subscribed := []string{
messaging.SubjectNodeBackendInstall(nodeID),
messaging.SubjectNodeBackendUpgrade(nodeID),
messaging.SubjectNodeBackendStop(nodeID),
messaging.SubjectNodeBackendDelete(nodeID),
messaging.SubjectNodeBackendList(nodeID),
messaging.SubjectNodeModelUnload(nodeID),
messaging.SubjectNodeModelDelete(nodeID),
messaging.SubjectNodeStop(nodeID),
messaging.SubjectNodeFilesEnsure(nodeID),
messaging.SubjectNodeFilesStage(nodeID),
messaging.SubjectNodeFilesTemp(nodeID),
messaging.SubjectNodeFilesListDir(nodeID),
}
for _, subject := range subscribed {
It("allows subscribing to "+subject, func() {
Expect(anyAllows(sub, subject)).To(BeTrue(),
"backend JWT sub allow-list %v does not cover %s", sub, subject)
})
}
It("allows publishing backend.install progress", func() {
subject := messaging.SubjectNodeBackendInstallProgress(nodeID, "op-123")
Expect(anyAllows(pub, subject)).To(BeTrue(),
"backend JWT pub allow-list %v does not cover %s", pub, subject)
})
})
Context("agent worker", func() {
// node_type "agent"; subjects from core/cli/agent_worker.go.
pub, sub := natsauth.WorkerPermissions(nodeID, "agent")
_ = pub
subscribed := []string{
messaging.SubjectAgentExecute, // dispatcher (default --agent-subject)
messaging.SubjectMCPToolExecute, // QueueSubscribeReply
messaging.SubjectMCPDiscovery, // QueueSubscribeReply
messaging.SubjectMCPCIJobsNew, // QueueSubscribe — jobs.mcp-ci.new
messaging.SubjectNodeBackendStop(nodeID), // Subscribe — MCP session cleanup
}
for _, subject := range subscribed {
It("allows subscribing to "+subject, func() {
Expect(anyAllows(sub, subject)).To(BeTrue(),
"agent JWT sub allow-list %v does not cover %s — the agent worker subscribes to it", sub, subject)
})
}
})
})
var allowPubRe = regexp.MustCompile(`--allow-pub "([^"]*)"`)
var _ = Describe("Documented NATS service-user permissions", func() {
// scripts/nats-auth-setup.sh ships the recommended service (frontend) JWT
// permissions. They must cover every subject the frontend actually publishes,
// or prefix-cache sync (and friends) break once LOCALAI_NATS_REQUIRE_AUTH is on.
const scriptPath = "../../scripts/nats-auth-setup.sh"
// Representative subjects the frontend publishes on the control plane.
// prefixcache.* is emitted by prefixcache.Sync in core/application/distributed.go.
frontendPublishes := []string{
messaging.SubjectPrefixCacheObserve,
messaging.SubjectPrefixCacheInvalidate,
messaging.SubjectNodeBackendInstall("node-1"),
messaging.SubjectGalleryProgress("op-1"),
}
It("cover every subject the frontend publishes", func() {
raw, err := os.ReadFile(scriptPath)
Expect(err).ToNot(HaveOccurred(), "cannot read %s", scriptPath)
m := allowPubRe.FindStringSubmatch(string(raw))
Expect(m).To(HaveLen(2), "no --allow-pub list found in %s", scriptPath)
allow := strings.Split(m[1], ",")
for _, subject := range frontendPublishes {
Expect(anyAllows(allow, subject)).To(BeTrue(),
"service-user --allow-pub %v does not cover %s (frontend publishes it)", allow, subject)
}
})
})

scripts/nats-auth-setup.sh Executable file

@@ -0,0 +1,49 @@
#!/usr/bin/env bash
# Generate NATS account + service user JWTs for LocalAI distributed mode.
#
# Requires: jq, and nsc (https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)
#
# Usage:
# ./scripts/nats-auth-setup.sh
#
# Prints the account seed and a service user JWT suitable for:
# LOCALAI_NATS_ACCOUNT_SEED
# LOCALAI_NATS_SERVICE_JWT
#
# Per-node worker JWTs are minted automatically by the frontend at registration
# when LOCALAI_NATS_ACCOUNT_SEED is set.
set -euo pipefail
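# Both tools are used below: nsc to create the key material, jq to extract the account seed.
for tool in nsc jq; do
if ! command -v "$tool" >/dev/null 2>&1; then
echo "$tool is required (nsc releases: https://github.com/nats-io/nsc/releases)" >&2
exit 1
fi
done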
OPERATOR="${NATS_OPERATOR_NAME:-localai-operator}"
ACCOUNT="${NATS_ACCOUNT_NAME:-localai}"
SERVICE_USER="${NATS_SERVICE_USER:-localai-frontend}"
nsc add operator -n "$OPERATOR" --generate-signing-key
nsc add account -n "$ACCOUNT"
nsc add user -n "$SERVICE_USER" --account "$ACCOUNT"
# Broad publish for frontend control plane (tighten with custom claims in production).
nsc edit user -n "$SERVICE_USER" --account "$ACCOUNT" \
--allow-pub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,finetune.>" \
--allow-sub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,_INBOX.>"
KEYS_DIR="${NATS_KEYS_DIR:-./nats-keys}"
mkdir -p "$KEYS_DIR"
nsc generate creds -a "$ACCOUNT" -n "$SERVICE_USER" -o "$KEYS_DIR"
ACCOUNT_SEED=$(nsc describe account "$ACCOUNT" -o json | jq -r '.nats.private_key')
SERVICE_JWT=$(cat "$KEYS_DIR/${ACCOUNT}/${SERVICE_USER}.jwt" 2>/dev/null || cat "$KEYS_DIR/${SERVICE_USER}.jwt")
echo ""
echo "=== LocalAI NATS auth material ==="
echo "LOCALAI_NATS_ACCOUNT_SEED=${ACCOUNT_SEED}"
echo "LOCALAI_NATS_SERVICE_JWT=${SERVICE_JWT}"
echo ""
echo "Configure the NATS server with the generated operator/account JWTs under $KEYS_DIR"
echo "and set LOCALAI_NATS_REQUIRE_AUTH=true on frontends and workers in production."


@@ -0,0 +1,156 @@
package distributed_test
import (
"bytes"
"context"
"fmt"
"strings"
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nkeys"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/testcontainers/testcontainers-go"
tcnats "github.com/testcontainers/testcontainers-go/modules/nats"
)
// JWTTestInfra holds a NATS server configured with JWT auth and minted worker credentials.
type JWTTestInfra struct {
*TestInfra
AccountSeed string
NodeID string
WorkerJWT string
WorkerSeed string
}
// SetupJWTInfra starts NATS with an in-memory JWT resolver and returns worker credentials
// minted the same way as node registration (pkg/natsauth).
func SetupJWTInfra() *JWTTestInfra {
GinkgoHelper()
infra := &JWTTestInfra{TestInfra: &TestInfra{Ctx: context.Background()}}
operatorJWT, accountJWT, accountSeed, err := jwtResolverMaterial()
Expect(err).ToNot(HaveOccurred())
infra.AccountSeed = accountSeed
conf := fmt.Sprintf(`listen: 0.0.0.0:4222
operator: %s
resolver: MEMORY
resolver_preload: {
%s: %s
}
`, operatorJWT, accountPublicKeyFromSeed(accountSeed), accountJWT)
var natsContainer *tcnats.NATSContainer
// Override default testcontainers -js: JetStream fails without a system account in JWT mode.
natsContainer, err = tcnats.Run(infra.Ctx, "nats:2-alpine",
tcnats.WithConfigFile(bytes.NewBufferString(conf)),
testcontainers.WithCmd("-c", "/etc/nats.conf"),
)
Expect(err).ToNot(HaveOccurred())
infra.NATSContainer = natsContainer
infra.NatsURL, err = infra.NATSContainer.ConnectionString(infra.Ctx)
Expect(err).ToNot(HaveOccurred())
infra.NodeID = "550e8400-e29b-41d4-a716-446655440000"
cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
infra.WorkerJWT, infra.WorkerSeed, err = cfg.MintWorkerJWT(infra.NodeID, "backend")
Expect(err).ToNot(HaveOccurred())
infra.NC, err = messaging.New(infra.NatsURL, messaging.WithUserJWT(infra.WorkerJWT, infra.WorkerSeed))
Expect(err).ToNot(HaveOccurred())
DeferCleanup(func() {
if infra.NC != nil {
infra.NC.Close()
}
if infra.NATSContainer != nil {
_ = infra.NATSContainer.Terminate(context.Background())
}
})
return infra
}
// jwtResolverMaterial builds operator + account JWTs for a MEMORY resolver.
// Follows the NATS JWT tutorial: self-signed account, then operator re-sign, with the
// account identity key listed as a signing key so MintWorkerJWT can use the account seed.
func jwtResolverMaterial() (operatorJWT, accountJWT, accountSeed string, err error) {
okp, err := nkeys.CreateOperator()
if err != nil {
return "", "", "", err
}
opk, err := okp.PublicKey()
if err != nil {
return "", "", "", err
}
oc := jwt.NewOperatorClaims(opk)
oc.Name = "localai-test-operator"
oskp, err := nkeys.CreateOperator()
if err != nil {
return "", "", "", err
}
ospk, err := oskp.PublicKey()
if err != nil {
return "", "", "", err
}
oc.SigningKeys.Add(ospk)
operatorJWT, err = oc.Encode(okp)
if err != nil {
return "", "", "", err
}
akp, err := nkeys.CreateAccount()
if err != nil {
return "", "", "", err
}
seed, err := akp.Seed()
if err != nil {
return "", "", "", err
}
accountSeed = string(seed)
apk, err := akp.PublicKey()
if err != nil {
return "", "", "", err
}
ac := jwt.NewAccountClaims(apk)
ac.Name = "localai-test-account"
ac.SigningKeys.Add(apk)
accountJWT, err = ac.Encode(akp)
if err != nil {
return "", "", "", err
}
ac, err = jwt.DecodeAccountClaims(accountJWT)
if err != nil {
return "", "", "", err
}
accountJWT, err = ac.Encode(oskp)
if err != nil {
return "", "", "", err
}
return operatorJWT, accountJWT, accountSeed, nil
}
func accountPublicKeyFromSeed(accountSeed string) string {
akp, err := nkeys.FromSeed([]byte(accountSeed))
Expect(err).ToNot(HaveOccurred())
pk, err := akp.PublicKey()
Expect(err).ToNot(HaveOccurred())
return pk
}
// nodeSubjectPrefix returns the sanitized nodes.* prefix for a node ID.
func nodeSubjectPrefix(nodeID string) string {
tok := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-").Replace(nodeID)
return "nodes." + tok
}


@@ -0,0 +1,99 @@
package distributed_test
import (
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/pkg/natsauth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("NATS JWT Auth", Label("Distributed", "NatsJWT"), func() {
var infra *JWTTestInfra
BeforeEach(func() {
infra = SetupJWTInfra()
})
It("connects with a minted backend worker JWT and publishes on allowed subjects", func() {
// Backend workers may publish under nodes.<id>.files.> (see pkg/natsauth permissions).
subject := nodeSubjectPrefix(infra.NodeID) + ".files.in"
Expect(infra.NC.Publish(subject, map[string]string{"path": "/tmp/model"})).To(Succeed())
Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
})
It("allows backend subscribe on the node prefix", func() {
wild := nodeSubjectPrefix(infra.NodeID) + ".>"
sub, err := infra.NC.Subscribe(wild, func(_ []byte) {})
Expect(err).ToNot(HaveOccurred())
defer func() { _ = sub.Unsubscribe() }()
Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
})
It("rejects anonymous publish on the JWT-enabled server", func() {
anon, err := messaging.New(infra.NatsURL)
Expect(err).ToNot(HaveOccurred())
defer anon.Close()
err = anon.Publish("nodes.any.files.x", map[string]string{"x": "1"})
Expect(err).ToNot(HaveOccurred())
Expect(anon.Conn().FlushTimeout(2 * time.Second)).To(HaveOccurred())
})
It("denies backend publish to another node's subjects", func() {
other := nodeSubjectPrefix("other-node-id") + ".files.stage"
Expect(infra.NC.Publish(other, map[string]string{"stage": "nope"})).To(Succeed())
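// Publish only buffers the message locally; the server reports the
// permissions violation asynchronously, so poll LastError until it appears.
Eventually(func() error {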
_ = infra.NC.Conn().FlushTimeout(500 * time.Millisecond)
return infra.NC.Conn().LastError()
}, "3s", "50ms").Should(HaveOccurred())
})
It("mints agent JWT without backend.install in claims", func() {
cfg := natsauth.Config{AccountSeed: infra.AccountSeed}
token, _, err := cfg.MintWorkerJWT("agent-node-1", "agent")
Expect(err).ToNot(HaveOccurred())
claims, err := natsauth.DecodeUserClaims(token)
Expect(err).ToNot(HaveOccurred())
Expect(claims.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
for _, subj := range claims.Permissions.Sub.Allow {
Expect(subj).NotTo(ContainSubstring("backend.install"))
}
})
// Regression guard for the silent permission gaps: decoding the JWT claims
// (above) only proves the agent JWT is *restrictive*, not that it is
// *sufficient*. Stand up a real agent connection against the enforcing
// server and exercise every subscription core/cli/agent_worker.go actually
// makes. A denied SUB now surfaces synchronously via confirmSubscription,
// so a missing allow rule fails this test instead of silently dropping
// backend.stop / MCP-CI deliveries at runtime.
It("lets an agent-minted JWT establish all the subscriptions the agent worker uses", func() {
const nodeID = "agent-node-subs"
cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
token, seed, err := cfg.MintWorkerJWT(nodeID, "agent")
Expect(err).ToNot(HaveOccurred())
nc, err := messaging.New(infra.NatsURL, messaging.WithUserJWT(token, seed))
Expect(err).ToNot(HaveOccurred())
DeferCleanup(nc.Close)
// Mirror core/cli/agent_worker.go exactly.
_, err = nc.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPToolExecute)
_, err = nc.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPDiscovery)
_, err = nc.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func([]byte) {})
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP CI jobs)", messaging.SubjectMCPCIJobsNew)
_, err = nc.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func([]byte) {})
Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP session cleanup)", messaging.SubjectNodeBackendStop(nodeID))
})
})