diff --git a/Makefile b/Makefile index 5ea1db5bd..9503bf7b4 100644 --- a/Makefile +++ b/Makefile @@ -309,13 +309,20 @@ run-e2e-aio: protogen-go @echo 'Running e2e AIO tests' $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio +# Distributed architecture e2e (PostgreSQL + NATS via testcontainers). +# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker. +# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that. +test-e2e-distributed: protogen-go + @echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)' + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed + # vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the # cpu-vllm backend from the current working tree, then drives a # head + headless follower via testcontainers-go and asserts a chat # completion. BuildKit caches both images, so re-runs only rebuild # what changed. The test lives under tests/e2e/distributed and is # selected by the VLLMMultinode label so it doesn't run alongside -# the other distributed-suite tests by default. +# test-e2e-distributed. 
test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go @echo 'Running e2e vLLM multi-node DP test' LOCALAI_IMAGE=local-ai \ diff --git a/core/application/distributed.go b/core/application/distributed.go index 64e8dc12e..3c8c6ec32 100644 --- a/core/application/distributed.go +++ b/core/application/distributed.go @@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID) // Connect to NATS - natsClient, err := messaging.New(cfg.Distributed.NatsURL) + natsAuth := cfg.Distributed.NatsAuthConfig() + if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") { + return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED") + } + natsOpts := cfg.Distributed.NatsMessagingOptions("", "") + natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...) if err != nil { return nil, fmt.Errorf("connecting to NATS: %w", err) } diff --git a/core/cli/agent_worker.go b/core/cli/agent_worker.go index 2fdf7dd0c..a6ceb3daf 100644 --- a/core/cli/agent_worker.go +++ b/core/cli/agent_worker.go @@ -52,6 +52,15 @@ type AgentWorkerCMD struct { Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"` Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"` + NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"` + NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"` + NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"` + NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" 
help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"` + NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"` + NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"` + NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"` + NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"` + // Timeouts MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"` } @@ -81,15 +90,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error { registrationBody["token"] = cmd.RegistrationToken } - nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10) + // Context cancelled on shutdown — used by registration waits, heartbeat, and + // other background goroutines. + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + + // Acquire credentials via (re)registration. When the bus requires auth and no + // static fallback is configured, wait through admin approval until the + // frontend mints credentials rather than starting unauthenticated. 
+ credMgr := workerregistry.NewNATSCredentialManager( + func(ctx context.Context) (*workerregistry.RegisterResponse, error) { + return regClient.RegisterFull(ctx, registrationBody) + }, + cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "", + ) + res, err := credMgr.Acquire(shutdownCtx) if err != nil { return fmt.Errorf("registration failed: %w", err) } + nodeID := res.ID xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo) // Use provisioned API token if none was set if cmd.APIToken == "" { - cmd.APIToken = apiToken + cmd.APIToken = res.APIToken } // Start heartbeat @@ -98,14 +122,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error { xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err) } heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second) - // Context cancelled on shutdown — used by heartbeat and other background goroutines - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) - defer shutdownCancel() go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} }) - // Connect to NATS - natsClient, err := messaging.New(cmd.NatsURL) + // Resolve NATS credentials with precedence: explicit env override, then + // frontend-minted (auto-refreshed before expiry), then service fallback. + // Each static source must supply JWT and seed together. 
+ natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey} + var natsOpts []messaging.Option + switch { + case cmd.NatsJWT != "" || cmd.NatsUserSeed != "": + if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") { + return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together") + } + natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed)) + case credMgr.HasCredentials(): + natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider())) + go func() { + if err := credMgr.RefreshLoop(shutdownCtx); err != nil { + xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err) + shutdownCancel() + } + }() + case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "": + if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") { + return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together") + } + natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed)) + case cmd.NatsRequireAuth: + return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars") + } + if natsTLS.Enabled() { + natsOpts = append(natsOpts, messaging.WithTLS(natsTLS)) + } + natsClient, err := messaging.New(cmd.NatsURL, natsOpts...) if err != nil { return fmt.Errorf("connecting to NATS: %w", err) } @@ -183,17 +233,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error { xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue) - // Wait for shutdown + // Wait for an OS signal or an internal fatal condition (e.g. NATS + // credentials became unrenewable), so the worker restarts and re-acquires + // rather than lingering unable to serve. 
sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh + var runErr error + select { + case <-sigCh: + case <-shutdownCtx.Done(): + runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable") + xlog.Error("Internal shutdown requested", "error", runErr) + } xlog.Info("Shutting down agent worker") shutdownCancel() // stop heartbeat loop immediately dispatcher.Stop() mcpTools.CloseAllMCPSessions() regClient.GracefulDeregister(nodeID) - return nil + return runErr } // handleMCPToolRequest handles a NATS request-reply for MCP tool execution. diff --git a/core/cli/run.go b/core/cli/run.go index 8bbc2b20c..9d23d38d6 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -159,6 +159,14 @@ type RunCMD struct { DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"` BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"` BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"` + NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"` + NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"` + NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) 
paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"` + NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"` + NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"` + NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"` + NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"` + NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"` ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." 
group:"distributed"` Version bool @@ -283,6 +291,34 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { if r.RegistrationToken != "" { opts = append(opts, config.WithRegistrationToken(r.RegistrationToken)) } + if r.NatsAccountSeed != "" { + opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed)) + } + if r.NatsServiceJWT != "" { + opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT)) + } + if r.NatsServiceSeed != "" { + opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed)) + } + if r.NatsWorkerJWTTTL != "" { + d, err := time.ParseDuration(r.NatsWorkerJWTTTL) + if err != nil { + return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err) + } + opts = append(opts, config.WithNatsWorkerJWTTTL(d)) + } + if r.NatsRequireAuth { + opts = append(opts, config.EnableNatsRequireAuth) + } + if r.NatsTLSCA != "" { + opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA)) + } + if r.NatsTLSCert != "" { + opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert)) + } + if r.NatsTLSKey != "" { + opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey)) + } if r.AutoApproveNodes { opts = append(opts, config.EnableAutoApproveNodes) } diff --git a/core/cli/worker/worker_vllm.go b/core/cli/worker/worker_vllm.go index 1471d780f..8596546b5 100644 --- a/core/cli/worker/worker_vllm.go +++ b/core/cli/worker/worker_vllm.go @@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error { FrontendURL: r.RegisterTo, RegistrationToken: r.RegistrationToken, } - nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10) + nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10) if regErr != nil { return fmt.Errorf("registering with frontend: %w", regErr) } diff --git a/core/cli/workerregistry/client.go b/core/cli/workerregistry/client.go index 0af102787..cf46455c9 100644 --- a/core/cli/workerregistry/client.go +++ 
b/core/cli/workerregistry/client.go @@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) { // RegisterResponse is the JSON body returned by /api/node/register. type RegisterResponse struct { - ID string `json:"id"` - APIToken string `json:"api_token,omitempty"` + ID string `json:"id"` + Status string `json:"status,omitempty"` // "pending" until an admin approves the node + APIToken string `json:"api_token,omitempty"` + NatsJWT string `json:"nats_jwt,omitempty"` + NatsUserSeed string `json:"nats_user_seed,omitempty"` } -// Register sends a single registration request and returns the node ID and -// (optionally) an auto-provisioned API token. -func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) { +// RegisterFull sends a single registration request and returns the full +// response (node ID, approval status, and optional API token / NATS creds). +// Re-registration is idempotent: the frontend preserves the node row and mints +// a fresh NATS JWT each call, so this doubles as the credential-refresh call. 
+func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) { jsonBody, _ := json.Marshal(body) url := c.baseURL() + "/api/node/register" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { - return "", "", fmt.Errorf("creating request: %w", err) + return nil, fmt.Errorf("creating request: %w", err) } req.Header.Set("Content-Type", "application/json") c.setAuth(req) resp, err := c.httpClient().Do(req) if err != nil { - return "", "", fmt.Errorf("posting to %s: %w", url, err) + return nil, fmt.Errorf("posting to %s: %w", url, err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode) + return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode) } var result RegisterResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", "", fmt.Errorf("decoding response: %w", err) + return nil, fmt.Errorf("decoding response: %w", err) } - return result.ID, result.APIToken, nil + return &result, nil +} + +// Register sends a single registration request and returns the node ID and +// optional credentials (API token for agent workers, NATS JWT when configured). +func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) { + res, err := c.RegisterFull(ctx, body) + if err != nil { + return "", "", "", "", err + } + return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil } // RegisterWithRetry retries registration with exponential backoff. 
-func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) { +func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) { backoff := 2 * time.Second maxBackoff := 30 * time.Second - var nodeID, apiToken string - var err error - for attempt := 1; attempt <= maxRetries; attempt++ { - nodeID, apiToken, err = c.Register(ctx, body) + nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body) if err == nil { - return nodeID, apiToken, nil + return nodeID, apiToken, natsJWT, natsSeed, nil } if attempt == maxRetries { - return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err) + return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err) } xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err) select { case <-ctx.Done(): - return "", "", ctx.Err() + return "", "", "", "", ctx.Err() case <-time.After(backoff): } backoff = min(backoff*2, maxBackoff) } - return nodeID, apiToken, err + return nodeID, apiToken, natsJWT, natsSeed, err } // Heartbeat sends a single heartbeat POST with the given body. diff --git a/core/cli/workerregistry/credentials.go b/core/cli/workerregistry/credentials.go new file mode 100644 index 000000000..24dd6f3c8 --- /dev/null +++ b/core/cli/workerregistry/credentials.go @@ -0,0 +1,200 @@ +package workerregistry + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/mudler/LocalAI/pkg/natsauth" + "github.com/mudler/xlog" +) + +// statusPending mirrors nodes.StatusPending. It is duplicated rather than +// imported so the lightweight registration client does not pull in the nodes +// package (and its gorm/DB dependencies). 
+const statusPending = "pending" + +// defaultMaxAttempts bounds how many times Acquire registers (and how many +// consecutive times RefreshLoop may fail) before giving up. It is high enough +// to ride out a slow admin approval or a transient frontend outage, but finite +// so an unauthorized/unapprovable worker exits and surfaces the problem (via a +// non-zero exit and the resulting restart) rather than waiting forever. +const defaultMaxAttempts = 100 + +// RegisterFunc performs one idempotent registration round-trip. +type RegisterFunc func(ctx context.Context) (*RegisterResponse, error) + +// NATSCredentialManager acquires NATS credentials at startup — waiting through +// admin approval when required — and refreshes them before the minted JWT +// expires, by re-registering (which mints a fresh JWT). The live NATS +// connection adopts a refreshed JWT on its next reconnect via Provider. Safe +// for concurrent use. +// +// It addresses two failure modes: a worker that needs credentials but registers +// while still pending approval (it would otherwise give up and never connect), +// and a long-running worker whose 24h JWT expires with no way to renew it. +type NATSCredentialManager struct { + register RegisterFunc + requireCreds bool // block until credentials are present (frontend minting in use) + + // Tunables; defaults set by NewNATSCredentialManager, overridable in tests. + initialBackoff time.Duration + maxBackoff time.Duration + maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited) + refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed + refreshRetry time.Duration + expiryOf func(jwt string) (time.Time, bool) + + mu sync.RWMutex + jwt string + seed string + nodeID string +} + +// NewNATSCredentialManager builds a manager over register. When requireCreds is +// true, Acquire blocks until the node is approved and credentials are minted. 
+func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager { + return &NATSCredentialManager{ + register: register, + requireCreds: requireCreds, + initialBackoff: 2 * time.Second, + maxBackoff: 30 * time.Second, + maxAttempts: defaultMaxAttempts, + refreshLead: 0.75, + refreshRetry: 30 * time.Second, + expiryOf: jwtExpiry, + } +} + +// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token +// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT). +func jwtExpiry(token string) (time.Time, bool) { + if token == "" { + return time.Time{}, false + } + uc, err := natsauth.DecodeUserClaims(token) + if err != nil || uc.Expires == 0 { + return time.Time{}, false + } + return time.Unix(uc.Expires, 0), true +} + +func (m *NATSCredentialManager) store(res *RegisterResponse) { + m.mu.Lock() + defer m.mu.Unlock() + m.nodeID = res.ID + if res.NatsJWT != "" && res.NatsUserSeed != "" { + m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed + } +} + +// Current returns the latest NATS credentials (both empty until acquired). +func (m *NATSCredentialManager) Current() (jwt, seed string) { + m.mu.RLock() + defer m.mu.RUnlock() + return m.jwt, m.seed +} + +// NodeID returns the node ID from the most recent registration. +func (m *NATSCredentialManager) NodeID() string { + m.mu.RLock() + defer m.mu.RUnlock() + return m.nodeID +} + +// Provider returns a callback compatible with messaging.WithUserJWTProvider, +// supplying the current credentials on each (re)connect. +func (m *NATSCredentialManager) Provider() func() (string, string) { + return m.Current +} + +// HasCredentials reports whether complete NATS credentials have been obtained. 
+func (m *NATSCredentialManager) HasCredentials() bool { + jwt, seed := m.Current() + return jwt != "" && seed != "" +} + +// Acquire registers and, when requireCreds is set, keeps re-registering with +// exponential backoff until the node is approved (status != pending) and +// credentials are minted. Without requireCreds it returns the first successful +// response (the historical one-shot behavior, preserved for anonymous NATS). +func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) { + backoff := m.initialBackoff + var lastReason error + for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ { + res, err := m.register(ctx) + switch { + case err != nil: + lastReason = err + xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err) + case !m.requireCreds: + m.store(res) + return res, nil + case res.Status == statusPending: + lastReason = fmt.Errorf("node %s still pending admin approval", res.ID) + xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff) + case res.NatsJWT == "" || res.NatsUserSeed == "": + lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID) + xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff) + default: + m.store(res) + return res, nil + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(backoff): + } + backoff = min(backoff*2, m.maxBackoff) + } + return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason) +} + +// RefreshLoop re-registers to mint a fresh JWT before the current one expires, +// updating the credentials returned by Current/Provider so the NATS connection +// adopts them on its next reconnect. 
It returns nil when ctx is cancelled or +// when the current credential has no expiry (nothing to refresh), and a non-nil +// error after maxAttempts consecutive refresh failures — letting the caller +// exit the worker so it restarts and re-acquires (or surfaces the outage) +// rather than silently drifting toward an expired, unrenewable JWT. +func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error { + failures := 0 + for { + jwt, _ := m.Current() + exp, ok := m.expiryOf(jwt) + if !ok { + xlog.Debug("NATS credential has no expiry; refresh loop exiting") + return nil + } + wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0) + select { + case <-ctx.Done(): + return nil + case <-time.After(wait): + } + + res, err := m.register(ctx) + if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" { + m.store(res) + failures = 0 + xlog.Info("Refreshed NATS credentials", "node", res.ID) + continue + } + failures++ + if err != nil { + xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err) + } else { + xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures) + } + if m.maxAttempts > 0 && failures >= m.maxAttempts { + return fmt.Errorf("NATS credential refresh failed %d times in a row", failures) + } + // Back off before retrying so a persistent failure near expiry does not spin. + select { + case <-ctx.Done(): + return nil + case <-time.After(m.refreshRetry): + } + } +} diff --git a/core/cli/workerregistry/credentials_test.go b/core/cli/workerregistry/credentials_test.go new file mode 100644 index 000000000..02ed413c4 --- /dev/null +++ b/core/cli/workerregistry/credentials_test.go @@ -0,0 +1,198 @@ +package workerregistry + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/mudler/LocalAI/pkg/natsauth" + "github.com/nats-io/nkeys" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestWorkerRegistry(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "WorkerRegistry") +} + +// fakeRegister returns a sequence of canned responses/errors, one per call, and +// records how many times it was invoked. The last entry repeats once exhausted. +type fakeRegister struct { + mu sync.Mutex + steps []step + calls int +} + +type step struct { + res *RegisterResponse + err error +} + +func (f *fakeRegister) fn() RegisterFunc { + return func(context.Context) (*RegisterResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + i := f.calls + f.calls++ + if i >= len(f.steps) { + i = len(f.steps) - 1 + } + return f.steps[i].res, f.steps[i].err + } +} + +func (f *fakeRegister) count() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls +} + +var _ = Describe("NATSCredentialManager", func() { + approved := func(jwt, seed string) *RegisterResponse { + return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed} + } + pending := &RegisterResponse{ID: "node-1", Status: "pending"} + + Describe("Acquire (#4 — wait through admin approval)", func() { + It("keeps re-registering until the node is approved and credentials are minted", func() { + f := &fakeRegister{steps: []step{ + {res: pending}, // not approved yet + {res: approved("", "")}, // approved but JWT not minted yet + {res: approved("jwt-1", "seed-1")}, // finally minted + }} + m := NewNATSCredentialManager(f.fn(), true /* requireCreds */) + m.initialBackoff = time.Millisecond + m.maxBackoff = time.Millisecond + + res, err := m.Acquire(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(res.ID).To(Equal("node-1")) + Expect(f.count()).To(Equal(3)) + + jwt, seed := m.Current() + Expect(jwt).To(Equal("jwt-1")) + Expect(seed).To(Equal("seed-1")) + Expect(m.HasCredentials()).To(BeTrue()) + Expect(m.NodeID()).To(Equal("node-1")) + }) + + It("returns immediately on the first success when credentials are not required (anonymous 
NATS)", func() { + f := &fakeRegister{steps: []step{{res: pending}}} + m := NewNATSCredentialManager(f.fn(), false /* requireCreds */) + + res, err := m.Acquire(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Status).To(Equal("pending")) + Expect(f.count()).To(Equal(1)) + Expect(m.HasCredentials()).To(BeFalse()) + }) + + It("aborts when the context is cancelled while waiting for approval", func() { + f := &fakeRegister{steps: []step{{res: pending}}} + m := NewNATSCredentialManager(f.fn(), true) + m.initialBackoff = 10 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := m.Acquire(ctx) + Expect(err).To(MatchError(context.Canceled)) + }) + + It("gives up after a bounded number of attempts so the worker exits and alerts", func() { + f := &fakeRegister{steps: []step{{res: pending}}} // never approved + m := NewNATSCredentialManager(f.fn(), true) + m.initialBackoff = time.Millisecond + m.maxBackoff = time.Millisecond + m.maxAttempts = 5 + + _, err := m.Acquire(context.Background()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("after 5 attempts")) + Expect(err.Error()).To(ContainSubstring("pending admin approval")) + Expect(f.count()).To(Equal(5)) + }) + }) + + Describe("RefreshLoop (#5 — renew before the JWT expires)", func() { + It("re-registers before expiry and updates the credentials served to new connections", func() { + f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}} + m := NewNATSCredentialManager(f.fn(), true) + m.refreshLead = 0.5 + m.refreshRetry = time.Millisecond + // jwt-1 expires soon; jwt-2 is long-lived so the loop then idles. 
+ m.expiryOf = func(jwt string) (time.Time, bool) { + switch jwt { + case "jwt-1": + return time.Now().Add(40 * time.Millisecond), true + case "jwt-2": + return time.Now().Add(time.Hour), true + default: + return time.Time{}, false + } + } + m.store(approved("jwt-1", "seed-1")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { _ = m.RefreshLoop(ctx) }() + + Eventually(func() string { + jwt, _ := m.Current() + return jwt + }, "2s", "10ms").Should(Equal("jwt-2")) + }) + + It("returns an error after the bounded number of consecutive failures so the caller can exit", func() { + f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails + m := NewNATSCredentialManager(f.fn(), true) + m.refreshLead = 0.5 + m.refreshRetry = time.Millisecond + m.maxAttempts = 3 + m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true } + m.store(approved("jwt-1", "seed-1")) + + errCh := make(chan error, 1) + go func() { errCh <- m.RefreshLoop(context.Background()) }() + Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row")))) + }) + + It("exits promptly when the current credential has no expiry (nothing to refresh)", func() { + f := &fakeRegister{steps: []step{{res: approved("x", "y")}}} + m := NewNATSCredentialManager(f.fn(), true) + m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false } + m.store(approved("static", "seed")) + + done := make(chan struct{}) + go func() { _ = m.RefreshLoop(context.Background()); close(done) }() + Eventually(done, "1s").Should(BeClosed()) + Expect(f.count()).To(Equal(0)) // never tried to re-register + }) + }) + + Describe("jwtExpiry default", func() { + It("decodes the expiry of a real minted worker JWT", func() { + akp, err := nkeys.CreateAccount() + Expect(err).ToNot(HaveOccurred()) + seed, err := akp.Seed() + Expect(err).ToNot(HaveOccurred()) + + cfg := natsauth.Config{AccountSeed: string(seed), 
WorkerJWTTTL: time.Hour} + token, _, err := cfg.MintWorkerJWT("node-1", "backend") + Expect(err).ToNot(HaveOccurred()) + + exp, ok := jwtExpiry(token) + Expect(ok).To(BeTrue()) + Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute)) + }) + + It("reports no expiry for an empty or undecodable token", func() { + _, ok := jwtExpiry("") + Expect(ok).To(BeFalse()) + _, ok = jwtExpiry("not-a-jwt") + Expect(ok).To(BeFalse()) + }) + }) +}) diff --git a/core/config/distributed_config.go b/core/config/distributed_config.go index e0e0454d9..7e74f5d61 100644 --- a/core/config/distributed_config.go +++ b/core/config/distributed_config.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/natsauth" "github.com/mudler/xlog" ) @@ -18,6 +20,16 @@ type DistributedConfig struct { RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration) AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers) + // NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md) + NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs + NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers + NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT + NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h) + NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing + NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify) + NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS + NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert + // S3 configuration (used when StorageURL is set) StorageBucket string // 
--storage-bucket / LOCALAI_STORAGE_BUCKET StorageRegion string // --storage-region / LOCALAI_STORAGE_REGION @@ -80,6 +92,13 @@ func (c DistributedConfig) Validate() error { if c.RegistrationToken == "" { xlog.Warn("distributed mode running without registration token — node endpoints are unprotected") } + if err := c.NatsAuthConfig().Validate(); err != nil { + return err + } + if err := c.NatsTLSFiles().Validate(); err != nil { + return err + } + c.NatsAuthConfig().WarnIfInsecure(true) // Check for negative durations for name, d := range map[string]time.Duration{ FlagMCPToolTimeout: c.MCPToolTimeout, @@ -123,6 +142,52 @@ func WithRegistrationToken(token string) AppOption { } } +func WithNatsAccountSeed(seed string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsAccountSeed = seed + } +} + +func WithNatsServiceJWT(jwt string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsServiceJWT = jwt + } +} + +func WithNatsServiceSeed(seed string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsServiceSeed = seed + } +} + +func WithNatsWorkerJWTTTL(d time.Duration) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsWorkerJWTTTL = d + } +} + +var EnableNatsRequireAuth = func(o *ApplicationConfig) { + o.Distributed.NatsRequireAuth = true +} + +func WithNatsTLSCA(path string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsTLSCA = path + } +} + +func WithNatsTLSCert(path string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsTLSCert = path + } +} + +func WithNatsTLSKey(path string) AppOption { + return func(o *ApplicationConfig) { + o.Distributed.NatsTLSKey = path + } +} + func WithStorageURL(url string) AppOption { return func(o *ApplicationConfig) { o.Distributed.StorageURL = url @@ -217,6 +282,44 @@ const ( // DefaultMaxUploadSize is the default maximum upload body size (50 GB). 
const DefaultMaxUploadSize int64 = 50 << 30 +// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client. +func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles { + return messaging.TLSFiles{ + CA: c.NatsTLSCA, + Cert: c.NatsTLSCert, + Key: c.NatsTLSKey, + } +} + +// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components. +// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config. +func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option { + var opts []messaging.Option + jwt, seed := userJWT, userSeed + if jwt == "" && seed == "" { + auth := c.NatsAuthConfig() + jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed + } + if jwt != "" && seed != "" { + opts = append(opts, messaging.WithUserJWT(jwt, seed)) + } + if tls := c.NatsTLSFiles(); tls.Enabled() { + opts = append(opts, messaging.WithTLS(tls)) + } + return opts +} + +// NatsAuthConfig builds pkg/natsauth settings from distributed configuration. +func (c DistributedConfig) NatsAuthConfig() natsauth.Config { + return natsauth.Config{ + AccountSeed: c.NatsAccountSeed, + ServiceUserJWT: c.NatsServiceJWT, + ServiceUserSeed: c.NatsServiceSeed, + WorkerJWTTTL: c.NatsWorkerJWTTTL, + RequireAuth: c.NatsRequireAuth, + } +} + // BackendInstallTimeoutOrDefault returns the configured timeout or the default. 
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration { return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout) diff --git a/core/http/app.go b/core/http/app.go index 79a1067b3..9ec0711fb 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) { remoteUnloader = d.Router.Unloader() } } - routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret) - routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken) + natsCfg := distCfg.NatsAuthConfig() + routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg) + routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg) // Distributed SSE routes (job progress + agent events via NATS) if d := application.Distributed(); d != nil { diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go index 90dbe70b0..930070506 100644 --- a/core/http/endpoints/localai/nodes.go +++ b/core/http/endpoints/localai/nodes.go @@ -28,6 +28,7 @@ import ( "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/nodes/prefixcache" "github.com/mudler/LocalAI/pkg/httpclient" + "github.com/mudler/LocalAI/pkg/natsauth" ) // nodeError builds a schema.ErrorResponse for 
node endpoints. @@ -89,7 +90,7 @@ type RegisterNodeRequest struct { // RegisterNodeEndpoint registers a new backend node. // expectedToken is the registration token configured on the frontend (may be empty to disable auth). // autoApprove controls whether new nodes go directly to "healthy" or require admin approval. -func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc { +func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc { return func(c echo.Context) error { var req RegisterNodeRequest if err := c.Bind(&req); err != nil { @@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au } } + attachNatsJWT(response, node, natsCfg) + return c.JSON(http.StatusCreated, response) } } // ApproveNodeEndpoint approves a pending node, setting its status to healthy. // For agent workers, it also provisions an API key so they can call the inference API. -func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc { +func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc { return func(c echo.Context) error { ctx := c.Request().Context() id := c.Param("id") @@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr } } + attachNatsJWT(response, node, natsCfg) + return c.JSON(http.StatusOK, response) } } +// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled. 
+func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) { + if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending { + return + } + jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType) + if err != nil { + xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err) + return + } + response["nats_jwt"] = jwt + response["nats_user_seed"] = seed +} + // provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node. // Returns the plaintext API key on success. func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) { diff --git a/core/http/endpoints/localai/nodes_test.go b/core/http/endpoints/localai/nodes_test.go index fdb29987d..bca6f42bf 100644 --- a/core/http/endpoints/localai/nodes_test.go +++ b/core/http/endpoints/localai/nodes_test.go @@ -12,6 +12,8 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/testutil" + "github.com/mudler/LocalAI/pkg/natsauth" + "github.com/nats-io/nkeys" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusCreated)) @@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() { Expect(resp["status"]).To(Equal(nodes.StatusHealthy)) }) + It("returns nats_jwt when account seed is configured", func() { + akp, err := nkeys.CreateAccount() + Expect(err).ToNot(HaveOccurred()) + seed, err := akp.Seed() + Expect(err).ToNot(HaveOccurred()) + + e := echo.New() + body := `{"name":"worker-nats","address":"10.0.0.2:50051"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + natsCfg := natsauth.Config{AccountSeed: string(seed)} + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg) + Expect(handler(c)).To(Succeed()) + Expect(rec.Code).To(Equal(http.StatusCreated)) + + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp["nats_jwt"]).ToNot(BeEmpty()) + }) + It("returns 400 when name is missing", func() { e := echo.New() body := `{"address":"10.0.0.1:50051"}` @@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusBadRequest)) @@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := 
RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusBadRequest)) @@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusBadRequest)) @@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusBadRequest)) @@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "") + handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) @@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := RegisterNodeEndpoint(registry, "", false, nil, "") + handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{}) Expect(handler(c)).To(Succeed()) Expect(rec.Code).To(Equal(http.StatusCreated)) @@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() { req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1)) req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec1 := httptest.NewRecorder() - handler := RegisterNodeEndpoint(registry, "", true, nil, "") + handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{}) 
Expect(handler(e.NewContext(req1, rec1))).To(Succeed()) Expect(rec1.Code).To(Equal(http.StatusCreated)) diff --git a/core/http/routes/nodes.go b/core/http/routes/nodes.go index bbf574c6b..d6f5b8dab 100644 --- a/core/http/routes/nodes.go +++ b/core/http/routes/nodes.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/nodes" + "github.com/mudler/LocalAI/pkg/natsauth" "gorm.io/gorm" ) @@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc { // token but do not verify per-node identity. A compromised worker can heartbeat/drain/ // deregister other nodes. Future: issue per-node JWT at registration, validate node // identity on subsequent requests (compare :id param with token subject). -func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) { +func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) { if registry == nil { return } @@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r tokenAuthMw := nodeTokenAuth(registrationToken) node := e.Group("/api/node", readyMw, tokenAuthMw) - node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret)) + node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg)) node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry)) node.POST("/:id/drain", localai.DrainNodeEndpoint(registry)) node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry)) @@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r // backend install path (POST 
/:id/backends/install). That handler enqueues a // ManagementOp on the gallery channel rather than blocking on a NATS reply, so // the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes. -func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) { +func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) { if registry == nil { return } @@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry)) admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry)) admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry)) - admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret)) + admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg)) // Backend management on workers admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader)) diff --git a/core/services/messaging/client.go b/core/services/messaging/client.go index cd5f2363c..31257f1fd 100644 --- a/core/services/messaging/client.go +++ b/core/services/messaging/client.go @@ -2,15 +2,22 @@ package messaging import ( "encoding/json" + "errors" "fmt" + "strings" "sync" "time" "github.com/mudler/LocalAI/pkg/sanitize" "github.com/mudler/xlog" "github.com/nats-io/nats.go" + "github.com/nats-io/nkeys" ) +// subscribeConfirmTimeout bounds the server round-trip used to detect whether a +// subscription was 
rejected (e.g. by JWT permissions) before returning to the caller. +const subscribeConfirmTimeout = 5 * time.Second + // Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions. type Client struct { conn *nats.Conn @@ -18,8 +25,13 @@ type Client struct { } // New creates a new NATS client with auto-reconnect. -func New(url string) (*Client, error) { - nc, err := nats.Connect(url, +func New(url string, opts ...Option) (*Client, error) { + var cfg connectConfig + for _, o := range opts { + o(&cfg) + } + + natsOpts := []nats.Option{ nats.RetryOnFailedConnect(true), nats.MaxReconnects(-1), nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { @@ -33,7 +45,60 @@ func New(url string) (*Client, error) { nats.ClosedHandler(func(_ *nats.Conn) { xlog.Info("NATS connection closed") }), - ) + // Surface async errors (notably permission violations) that NATS would + // otherwise deliver silently. A subscription the server rejects for a + // JWT permission means the worker never receives those messages, so make + // it loud rather than letting the feature fail invisibly. + nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) { + subject := "" + if sub != nil { + subject = sub.Subject + } + if errors.Is(err, nats.ErrPermissionViolation) { + xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err) + return + } + xlog.Warn("NATS async error", "subject", subject, "error", err) + }), + } + switch { + case cfg.jwtProvider != nil: + // Fetch creds on every (re)connect so a refresh loop can rotate the JWT + // before expiry; the server expiring the old JWT triggers a reconnect + // that transparently picks up the new one. 
+ natsOpts = append(natsOpts, nats.UserJWT( + func() (string, error) { + jwt, _ := cfg.jwtProvider() + if jwt == "" { + return "", fmt.Errorf("no NATS user JWT available") + } + return jwt, nil + }, + func(nonce []byte) ([]byte, error) { + _, seed := cfg.jwtProvider() + kp, err := nkeys.FromSeed([]byte(seed)) + if err != nil { + return nil, fmt.Errorf("loading NATS user seed: %w", err) + } + defer kp.Wipe() + return kp.Sign(nonce) + }, + )) + case cfg.userJWT != "" && cfg.userSeed != "": + natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed)) + } + if cfg.tls.Enabled() { + if err := cfg.tls.Validate(); err != nil { + return nil, err + } + tlsOpts, err := cfg.tls.natsOptions() + if err != nil { + return nil, err + } + natsOpts = append(natsOpts, tlsOpts...) + } + + nc, err := nats.Connect(url, natsOpts...) if err != nil { return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err) } @@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error { // Subscribe creates a subscription on the given subject. All subscribers receive every message. func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.conn.Subscribe(subject, func(msg *nats.Msg) { - handler(msg.Data) + return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) { + return conn.Subscribe(subject, func(msg *nats.Msg) { + handler(msg.Data) + }) }) } // QueueSubscribe creates a queue subscription. Within the same queue group, // only one subscriber receives each message (load-balanced). 
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) { - handler(msg.Data) + return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) { + return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) { + handler(msg.Data) + }) }) } +// confirmSubscription creates a subscription via mk and forces a server +// round-trip so that a permissions violation — which NATS otherwise reports +// only asynchronously — is returned to the caller synchronously. The server +// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG +// that satisfies the flush, so by the time FlushTimeout returns the violation +// is recorded as the connection's last error. Without this, a worker whose JWT +// lacks a subject gets a non-nil subscription that never receives a message, +// turning a permission misconfiguration into a silent failure. +func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) { + c.mu.RLock() + conn := c.conn + c.mu.RUnlock() + if conn == nil { + return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject) + } + + sub, err := mk(conn) + if err != nil { + return nil, err + } + + // A failed flush here means we could not round-trip to the server (not yet + // connected, reconnecting, slow link). RetryOnFailedConnect intentionally + // buffers subscriptions across that gap, so do NOT fail — keep the + // subscription and let it replay on (re)connect; a later permission + // violation is still logged by the async error handler in New. 
+ if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil { + xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err) + return sub, nil + } + // Flush succeeded, so any permission violation for this SUB has already been + // recorded as the connection's last error (the server emits it before the + // PONG). LastError is per-connection; match the exact quoted subject the + // server echoes ("Subscription to \"{subject}\"") so a stale violation for + // another subject can't be mis-attributed here. + if lerr := conn.LastError(); lerr != nil && + errors.Is(lerr, nats.ErrPermissionViolation) && + strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) { + _ = sub.Unsubscribe() + return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr) + } + return sub, nil +} + // Request sends a request and waits for a reply (request-reply pattern). // Returns the raw reply data. func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) { @@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([] // SubscribeReply creates a subscription that supports replying to requests. // The handler receives the raw request data and the reply subject. 
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.conn.Subscribe(subject, func(msg *nats.Msg) { - handler(msg.Data, func(replyData []byte) { - if msg.Reply != "" { - if err := msg.Respond(replyData); err != nil { - xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err) + return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) { + return conn.Subscribe(subject, func(msg *nats.Msg) { + handler(msg.Data, func(replyData []byte) { + if msg.Reply != "" { + if err := msg.Respond(replyData); err != nil { + xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err) + } } - } + }) }) }) } @@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply // QueueSubscribeReply creates a queue subscription that supports replying to requests. // Load-balanced across subscribers in the same queue group, with request-reply support. 
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) { - c.mu.RLock() - defer c.mu.RUnlock() - return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) { - handler(msg.Data, func(replyData []byte) { - if msg.Reply != "" { - if err := msg.Respond(replyData); err != nil { - xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err) + return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) { + return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) { + handler(msg.Data, func(replyData []byte) { + if msg.Reply != "" { + if err := msg.Respond(replyData); err != nil { + xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err) + } } - } + }) }) }) } diff --git a/core/services/messaging/options.go b/core/services/messaging/options.go new file mode 100644 index 000000000..b4c67b16c --- /dev/null +++ b/core/services/messaging/options.go @@ -0,0 +1,34 @@ +package messaging + +// Option configures NATS client connection behavior. +type Option func(*connectConfig) + +// CredentialProvider returns the NATS user JWT and signing seed to use for the +// next (re)connect. It is consulted on every connection attempt, so a refresh +// loop can rotate credentials before they expire and the connection picks them +// up automatically when the server expires the old JWT and triggers a reconnect. +type CredentialProvider func() (jwt, seed string) + +type connectConfig struct { + userJWT string + userSeed string + jwtProvider CredentialProvider + tls TLSFiles +} + +// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed). +func WithUserJWT(jwt, seed string) Option { + return func(c *connectConfig) { + c.userJWT = jwt + c.userSeed = seed + } +} + +// WithUserJWTProvider connects using credentials fetched from provider on each +// (re)connect, enabling JWT rotation without dropping the client. 
Takes +// precedence over WithUserJWT when both are set. +func WithUserJWTProvider(provider CredentialProvider) Option { + return func(c *connectConfig) { + c.jwtProvider = provider + } +} diff --git a/core/services/messaging/tls.go b/core/services/messaging/tls.go new file mode 100644 index 000000000..b594845cc --- /dev/null +++ b/core/services/messaging/tls.go @@ -0,0 +1,68 @@ +package messaging + +import ( + "fmt" + "os" + + "github.com/nats-io/nats.go" +) + +// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together. +// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras. +type TLSFiles struct { + CA string // LOCALAI_NATS_TLS_CA — private CA for server verification + Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS) + Key string // LOCALAI_NATS_TLS_KEY — client private key +} + +// Enabled reports whether any TLS file path is configured. +func (f TLSFiles) Enabled() bool { + return f.CA != "" || f.Cert != "" || f.Key != "" +} + +// Validate checks path pairing and that files exist. +func (f TLSFiles) Validate() error { + if f.Cert != "" && f.Key == "" { + return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set") + } + if f.Key != "" && f.Cert == "" { + return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set") + } + for _, path := range []struct { + name, path string + }{ + {"LOCALAI_NATS_TLS_CA", f.CA}, + {"LOCALAI_NATS_TLS_CERT", f.Cert}, + {"LOCALAI_NATS_TLS_KEY", f.Key}, + } { + if path.path == "" { + continue + } + if _, err := os.Stat(path.path); err != nil { + return fmt.Errorf("%s: %w", path.name, err) + } + } + return nil +} + +// natsOptions builds nats-go TLS options. Call Validate first. 
+func (f TLSFiles) natsOptions() ([]nats.Option, error) { + if !f.Enabled() { + return nil, nil + } + opts := []nats.Option{nats.Secure()} + if f.CA != "" { + opts = append(opts, nats.RootCAs(f.CA)) + } + if f.Cert != "" { + opts = append(opts, nats.ClientCert(f.Cert, f.Key)) + } + return opts, nil +} + +// WithTLS configures CA and/or client certificate paths for the NATS connection. +func WithTLS(files TLSFiles) Option { + return func(c *connectConfig) { + c.tls = files + } +} diff --git a/core/services/messaging/tls_test.go b/core/services/messaging/tls_test.go new file mode 100644 index 000000000..9fdb3e816 --- /dev/null +++ b/core/services/messaging/tls_test.go @@ -0,0 +1,25 @@ +package messaging_test + +import ( + "os" + "path/filepath" + + "github.com/mudler/LocalAI/core/services/messaging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TLSFiles", func() { + It("requires cert and key together", func() { + Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred()) + Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred()) + }) + + It("validates files exist", func() { + dir := GinkgoT().TempDir() + ca := filepath.Join(dir, "ca.pem") + Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed()) + Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed()) + }) +}) diff --git a/core/services/worker/config.go b/core/services/worker/config.go index 97d2a6582..890137f79 100644 --- a/core/services/worker/config.go +++ b/core/services/worker/config.go @@ -60,7 +60,13 @@ type Config struct { MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." 
group:"registration"` // NATS (required) - NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` + NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` + NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"` + NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"` + NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"` + NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"` + NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"` + NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"` // S3 storage for distributed file transfer StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"` diff --git a/core/services/worker/nats_connect.go b/core/services/worker/nats_connect.go new file mode 100644 index 000000000..25485701d --- /dev/null +++ b/core/services/worker/nats_connect.go @@ -0,0 +1,33 @@ +package worker + +import ( + "fmt" + + "github.com/mudler/LocalAI/core/services/messaging" +) + +// connectNATS opens a NATS client using JWT+seed from env or registration (env wins). +func connectNATS(url, envJWT, envSeed, registerJWT, registerSeed string, requireAuth bool, tls messaging.TLSFiles) (*messaging.Client, error) { + // Env credentials take precedence, but only fall back to registration when + // the env supplied neither half — otherwise a JWT set without its seed (or + // vice-versa) would be silently completed from a different source. 
+ jwt, seed := envJWT, envSeed + if jwt == "" && seed == "" { + jwt, seed = registerJWT, registerSeed + } + // A JWT without its paired seed (or vice-versa) is a misconfiguration: refuse + // rather than silently connecting anonymously, which would look authenticated. + if (jwt == "") != (seed == "") { + return nil, fmt.Errorf("NATS JWT and seed must be provided together (got JWT set=%t, seed set=%t)", jwt != "", seed != "") + } + var opts []messaging.Option + if jwt != "" && seed != "" { + opts = append(opts, messaging.WithUserJWT(jwt, seed)) + } else if requireAuth { + return nil, fmt.Errorf("NATS JWT+seed required: set LOCALAI_NATS_JWT/LOCALAI_NATS_USER_SEED or enable frontend minting") + } + if tls.Enabled() { + opts = append(opts, messaging.WithTLS(tls)) + } + return messaging.New(url, opts...) +} diff --git a/core/services/worker/nats_connect_test.go b/core/services/worker/nats_connect_test.go new file mode 100644 index 000000000..8f554de4e --- /dev/null +++ b/core/services/worker/nats_connect_test.go @@ -0,0 +1,29 @@ +package worker + +import ( + "github.com/mudler/LocalAI/core/services/messaging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("connectNATS", func() { + It("requires JWT when requireAuth is set and no credentials are provided", func() { + _, err := connectNATS("nats://127.0.0.1:4222", "", "", "", "", true, messaging.TLSFiles{}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("NATS JWT+seed required")) + }) + + // A JWT supplied without its paired seed (or vice-versa) is an operator + // misconfiguration. connectNATS must refuse it: silently dropping the + // unpaired credential and connecting anonymously would leave the operator + // believing the link is authenticated when it is not. 
+ It("rejects a JWT supplied without a seed instead of connecting anonymously", func() { + client, err := connectNATS("nats://127.0.0.1:4222", "jwt-without-seed", "", "", "", false, messaging.TLSFiles{}) + if client != nil { + client.Close() + } + Expect(err).To(HaveOccurred(), + "connectNATS should reject an unpaired JWT rather than silently connecting anonymously") + }) +}) diff --git a/core/services/worker/worker.go b/core/services/worker/worker.go index ff03d7b55..2b869b372 100644 --- a/core/services/worker/worker.go +++ b/core/services/worker/worker.go @@ -15,6 +15,7 @@ import ( "github.com/mudler/LocalAI/core/cli/workerregistry" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/services/messaging" "github.com/mudler/LocalAI/core/services/nodes" grpc "github.com/mudler/LocalAI/pkg/grpc" @@ -67,10 +68,63 @@ func Run(ctx *cliContext.Context, cfg *Config) error { RegistrationToken: cfg.RegistrationToken, } + // Context cancelled on shutdown — used by registration waits, heartbeat, and + // other background goroutines. + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + defer shutdownCancel() + registrationBody := cfg.registrationBody() - nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10) - if err != nil { - return fmt.Errorf("failed to register with frontend: %w", err) + natsTLS := messaging.TLSFiles{CA: cfg.NatsTLSCA, Cert: cfg.NatsTLSCert, Key: cfg.NatsTLSKey} + + // Resolve how to connect to NATS. Static env credentials cannot be re-minted, + // so register once and use them directly. Otherwise the credential manager + // (re)registers to obtain credentials — waiting through admin approval — and + // refreshes them before the minted JWT expires, so the connection survives + // expiry via a transparent reconnect. 
+ var ( + nodeID string + connectNats func() (*messaging.Client, error) + ) + if cfg.NatsJWT != "" || cfg.NatsUserSeed != "" { + nid, _, _, _, regErr := regClient.RegisterWithRetry(shutdownCtx, registrationBody, 10) + if regErr != nil { + return fmt.Errorf("failed to register with frontend: %w", regErr) + } + nodeID = nid + connectNats = func() (*messaging.Client, error) { + return connectNATS(cfg.NatsURL, cfg.NatsJWT, cfg.NatsUserSeed, "", "", cfg.NatsRequireAuth, natsTLS) + } + } else { + credMgr := workerregistry.NewNATSCredentialManager( + func(ctx context.Context) (*workerregistry.RegisterResponse, error) { + return regClient.RegisterFull(ctx, registrationBody) + }, + cfg.NatsRequireAuth, + ) + res, regErr := credMgr.Acquire(shutdownCtx) + if regErr != nil { + return fmt.Errorf("failed to register with frontend: %w", regErr) + } + nodeID = res.ID + connectNats = func() (*messaging.Client, error) { + var opts []messaging.Option + if credMgr.HasCredentials() { + opts = append(opts, messaging.WithUserJWTProvider(credMgr.Provider())) + } + if natsTLS.Enabled() { + opts = append(opts, messaging.WithTLS(natsTLS)) + } + client, cerr := messaging.New(cfg.NatsURL, opts...) 
+ if cerr == nil && credMgr.HasCredentials() { + go func() { + if err := credMgr.RefreshLoop(shutdownCtx); err != nil { + xlog.Error("NATS credential refresh permanently failed; shutting down worker", "error", err) + shutdownCancel() + } + }() + } + return client, cerr + } } xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cfg.RegisterTo) @@ -79,9 +133,6 @@ func Run(ctx *cliContext.Context, cfg *Config) error { xlog.Warn("invalid heartbeat interval, using default 10s", "input", cfg.HeartbeatInterval, "error", err) } heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second) - // Context cancelled on shutdown — used by heartbeat and other background goroutines - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) - defer shutdownCancel() // Start HTTP file transfer server httpAddr := cfg.resolveHTTPAddr() @@ -94,7 +145,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error { // Connect to NATS xlog.Info("Connecting to NATS", "url", sanitize.URL(cfg.NatsURL)) - natsClient, err := messaging.New(cfg.NatsURL) + natsClient, err := connectNats() if err != nil { nodes.ShutdownFileTransferServer(httpServer) return fmt.Errorf("connecting to NATS: %w", err) @@ -154,12 +205,21 @@ func Run(ctx *cliContext.Context, cfg *Config) error { } xlog.Info("Worker ready, waiting for backend.install events") - <-sigCh + // Exit on an OS signal or on an internal fatal condition (e.g. NATS + // credentials became unrenewable), so the worker restarts and re-acquires + // rather than lingering unable to serve. 
+ var runErr error + select { + case <-sigCh: + case <-shutdownCtx.Done(): + runErr = fmt.Errorf("worker shutting down: NATS credentials unavailable") + xlog.Error("Internal shutdown requested", "error", runErr) + } xlog.Info("Shutting down worker") shutdownCancel() // stop heartbeat loop immediately regClient.GracefulDeregister(nodeID) supervisor.stopAllBackends() nodes.ShutdownFileTransferServer(httpServer) - return nil + return runErr } diff --git a/docs/content/features/distributed-mode.md b/docs/content/features/distributed-mode.md index de50cba3e..43ab5c146 100644 --- a/docs/content/features/distributed-mode.md +++ b/docs/content/features/distributed-mode.md @@ -71,6 +71,50 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These | `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). | | `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Coverage spans the OpenAI-compatible endpoints (chat completions, completions, embeddings, audio transcriptions, audio speech / TTS, image generations, image inpainting), the Jina rerank endpoint (`/v1/rerank`), the VAD endpoints (`/v1/vad`, `/vad`), and the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. 
| +### NATS JWT authentication (recommended for production) + +By default, NATS connections are anonymous: any client that can reach port `4222` may publish control-plane subjects such as `nodes.{node-id}.backend.install`. Enable JWT auth to scope workers to their own node subjects and give the frontend a dedicated service credential. + +| Flag | Env Var | Description | +|------|---------|-------------| +| `--nats-account-seed` | `LOCALAI_NATS_ACCOUNT_SEED` | Account signing seed (`SA...`). The frontend mints a per-node user JWT at registration (`nats_jwt` in the register response). | +| `--nats-service-jwt` | `LOCALAI_NATS_SERVICE_JWT` | User JWT for the frontend (and optional fallback for agent workers) to publish install/upgrade and related subjects. | +| `--nats-service-seed` | `LOCALAI_NATS_SERVICE_SEED` | User signing seed (`SU...`) paired with the service JWT. | +| `--nats-worker-jwt-ttl` | `LOCALAI_NATS_WORKER_JWT_TTL` | Lifetime of minted worker JWTs (default `24h`). | +| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | Fail startup if JWT credentials are missing when distributed mode is enabled. | + +### NATS TLS / mTLS (optional) + +Use `tls://` in `--nats-url` / `LOCALAI_NATS_URL` for encrypted transport. When the server uses a private CA or requires client certificates, set: + +| Flag | Env Var | Description | +|------|---------|-------------| +| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | PEM file to verify the NATS server (private CA) | +| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | Client certificate for NATS mTLS | +| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | Client private key (required with `--nats-tls-cert`) | + +The same env vars apply to backend workers and `local-ai agent-worker`. If the server cert is already trusted by the OS, `tls://` alone is enough.
+ +**Worker register response** (when minting is enabled and the node is approved): + +```json +{ + "id": "…", + "nats_jwt": "eyJ…", + "nats_user_seed": "SU…" +} +``` + +Workers connect with that JWT and seed automatically (shown once; store securely). Override with `LOCALAI_NATS_JWT` / `LOCALAI_NATS_USER_SEED` if needed. Set `LOCALAI_NATS_REQUIRE_AUTH=true` on workers when the bus requires credentials. + +When `LOCALAI_NATS_REQUIRE_AUTH=true` and no static credentials are provided, a worker that registers while still **pending admin approval** keeps re-registering (with backoff) until an admin approves it and the frontend mints its JWT — it does not start unauthenticated. This retry is **bounded**: if the node is never approved (or no credentials are minted) after a large number of attempts, the worker exits non-zero so the failure is visible (a crash-looping or failed worker) rather than hanging silently. Minted worker JWTs are also **refreshed automatically** before they expire (the worker re-registers at ~75% of the JWT lifetime), so long-running workers survive past `LOCALAI_NATS_WORKER_JWT_TTL`; the NATS connection picks up the new JWT on its next reconnect. If refresh fails persistently, the worker exits (to restart and re-acquire) rather than drifting toward an expired, unrenewable JWT. Statically configured (`LOCALAI_NATS_JWT`) and service (`LOCALAI_NATS_SERVICE_JWT`) credentials are used as-is and not refreshed. + +Generate operator/account material with [`scripts/nats-auth-setup.sh`](https://github.com/mudler/LocalAI/blob/master/scripts/nats-auth-setup.sh) (requires [nsc](https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)). Configure the NATS server with account resolver JWTs before enabling `LOCALAI_NATS_REQUIRE_AUTH`. + +{{% notice note %}} +`LOCALAI_AUTH` (HTTP users/sessions) and NATS JWTs are separate: end-user API keys do not connect to NATS. HTTP registration still uses `LOCALAI_REGISTRATION_TOKEN`. 
+{{% /notice %}} + ### Optional: S3 Object Storage For multi-host deployments where workers don't share a filesystem, S3-compatible storage enables distributed file transfer (model files, configs): @@ -134,6 +178,12 @@ local-ai worker \ | `--registration-token` | `LOCALAI_REGISTRATION_TOKEN` | *(empty)* | Token to authenticate with the frontend | | `--heartbeat-interval` | `LOCALAI_HEARTBEAT_INTERVAL` | `10s` | Interval between heartbeat pings | | `--nats-url` | `LOCALAI_NATS_URL` | *(required)* | NATS URL for backend installation and file staging | +| `--nats-jwt` | `LOCALAI_NATS_JWT` | *(empty)* | Optional override for the `nats_jwt` returned at registration | +| `--nats-user-seed` | `LOCALAI_NATS_USER_SEED` | *(empty)* | Optional override for `nats_user_seed` from registration | +| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | `false` | Require NATS JWT+seed (from registration or env) | +| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | *(empty)* | PEM file for NATS server CA | +| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | *(empty)* | Client certificate for NATS mTLS | +| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | *(empty)* | Client private key for NATS mTLS | | `--backends-path` | `LOCALAI_BACKENDS_PATH` | `./backends` | Path to backend binaries | | `--models-path` | `LOCALAI_MODELS_PATH` | `./models` | Path to model files | diff --git a/go.mod b/go.mod index 1daaf9e4c..a7d395acd 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,9 @@ require ( github.com/mudler/go-processmanager v0.1.1 github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8 github.com/mudler/xlog v0.0.6 + github.com/nats-io/jwt/v2 v2.7.4 github.com/nats-io/nats.go v1.52.0 + github.com/nats-io/nkeys v0.4.15 github.com/ollama/ollama v0.20.4 github.com/onsi/ginkgo/v2 v2.29.0 github.com/onsi/gomega v1.41.0 @@ -134,7 +136,6 @@ require ( github.com/mattn/go-sqlite3 v1.14.28 // indirect github.com/moby/moby/api v1.54.2 // indirect github.com/moby/moby/client v0.4.1 // indirect - 
github.com/nats-io/nkeys v0.4.15 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/secure-systems-lab/go-securesystemslib v0.9.1 // indirect diff --git a/go.sum b/go.sum index 24e50c70c..2cae3ec88 100644 --- a/go.sum +++ b/go.sum @@ -1016,6 +1016,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI= +github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc= github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno= github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= diff --git a/pkg/natsauth/config.go b/pkg/natsauth/config.go new file mode 100644 index 000000000..e9de53470 --- /dev/null +++ b/pkg/natsauth/config.go @@ -0,0 +1,66 @@ +package natsauth + +import ( + "fmt" + "time" + + "github.com/mudler/xlog" +) + +// DefaultWorkerJWTTTL is how long a worker may use a minted NATS user JWT before re-registering. +const DefaultWorkerJWTTTL = 24 * time.Hour + +// Config holds NATS JWT authentication settings for distributed mode. +type Config struct { + // AccountSeed is the NATS account signing seed (SA...). Used to mint per-node worker JWTs. + AccountSeed string + // ServiceUserJWT is a pre-generated user JWT for frontends and agent workers (broad publish). + ServiceUserJWT string + // ServiceUserSeed is the signing seed (SU...) paired with ServiceUserJWT. + ServiceUserSeed string + // WorkerJWTTTL sets expiry on minted worker JWTs.
Zero uses DefaultWorkerJWTTTL. + WorkerJWTTTL time.Duration + // RequireAuth rejects anonymous NATS when true (both ServiceUserJWT and AccountSeed expected). + RequireAuth bool +} + +// Enabled reports whether any NATS credential material is configured. +func (c Config) Enabled() bool { + return c.AccountSeed != "" || c.ServiceUserJWT != "" +} + +// CanMintWorkers reports whether per-node JWTs can be issued at registration. +func (c Config) CanMintWorkers() bool { + return c.AccountSeed != "" +} + +// WorkerTTL returns the configured worker JWT lifetime. +func (c Config) WorkerTTL() time.Duration { + if c.WorkerJWTTTL > 0 { + return c.WorkerJWTTTL + } + return DefaultWorkerJWTTTL +} + +// Validate checks consistency when distributed NATS auth is required. +func (c Config) Validate() error { + if !c.RequireAuth { + return nil + } + if c.ServiceUserJWT == "" || c.ServiceUserSeed == "" { + return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED") + } + if c.AccountSeed == "" { + return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH is set but LOCALAI_NATS_ACCOUNT_SEED is empty") + } + return nil +} + +// WarnIfInsecure logs when distributed NATS is reachable without credentials. +func (c Config) WarnIfInsecure(distributed bool) { + if !distributed || c.Enabled() { + return + } + xlog.Warn("NATS is used without JWT credentials — any client on the bus can publish backend.install. " + + "Set LOCALAI_NATS_ACCOUNT_SEED + LOCALAI_NATS_SERVICE_JWT (see docs/features/distributed-mode.md).") +} diff --git a/pkg/natsauth/decode.go b/pkg/natsauth/decode.go new file mode 100644 index 000000000..1ae156f44 --- /dev/null +++ b/pkg/natsauth/decode.go @@ -0,0 +1,16 @@ +package natsauth + +import ( + "fmt" + + "github.com/nats-io/jwt/v2" +) + +// DecodeUserClaims decodes a minted worker JWT for tests and diagnostics. 
+func DecodeUserClaims(token string) (*jwt.UserClaims, error) { + uc, err := jwt.DecodeUserClaims(token) + if err != nil { + return nil, fmt.Errorf("natsauth: decode user JWT: %w", err) + } + return uc, nil +} diff --git a/pkg/natsauth/mint.go b/pkg/natsauth/mint.go new file mode 100644 index 000000000..387e8d701 --- /dev/null +++ b/pkg/natsauth/mint.go @@ -0,0 +1,59 @@ +package natsauth + +import ( + "fmt" + "time" + + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nkeys" +) + +// MintWorkerJWT creates a signed NATS user JWT and user seed scoped to nodeID and nodeType. +// The seed is returned once at registration so the worker can sign NATS connections. +func (c Config) MintWorkerJWT(nodeID, nodeType string) (userJWT, userSeed string, err error) { + if c.AccountSeed == "" { + return "", "", fmt.Errorf("natsauth: account seed not configured") + } + if nodeID == "" { + return "", "", fmt.Errorf("natsauth: node ID is required") + } + + accountKP, err := nkeys.FromSeed([]byte(c.AccountSeed)) + if err != nil { + return "", "", fmt.Errorf("natsauth: invalid account seed: %w", err) + } + + userKP, err := nkeys.CreateUser() + if err != nil { + return "", "", fmt.Errorf("natsauth: create user key: %w", err) + } + seedBytes, err := userKP.Seed() + if err != nil { + return "", "", fmt.Errorf("natsauth: user seed: %w", err) + } + + accountPub, err := accountKP.PublicKey() + if err != nil { + return "", "", fmt.Errorf("natsauth: account public key: %w", err) + } + userPub, err := userKP.PublicKey() + if err != nil { + return "", "", fmt.Errorf("natsauth: user public key: %w", err) + } + + pubAllow, subAllow := WorkerPermissions(nodeID, nodeType) + + uc := jwt.NewUserClaims(userPub) + uc.Name = fmt.Sprintf("localai-%s-%s", nodeType, workerSubjectToken(nodeID)) + uc.IssuerAccount = accountPub + uc.Expires = time.Now().Add(c.WorkerTTL()).Unix() + + uc.Permissions.Pub.Allow = pubAllow + uc.Permissions.Sub.Allow = subAllow + + token, err := uc.Encode(accountKP) + if err != nil 
{ + return "", "", fmt.Errorf("natsauth: encode user JWT: %w", err) + } + return token, string(seedBytes), nil +} diff --git a/pkg/natsauth/mint_test.go b/pkg/natsauth/mint_test.go new file mode 100644 index 000000000..4d5bb77ce --- /dev/null +++ b/pkg/natsauth/mint_test.go @@ -0,0 +1,60 @@ +package natsauth_test + +import ( + "testing" + "time" + + "github.com/mudler/LocalAI/pkg/natsauth" + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nkeys" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestNatsAuth(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "NatsAuth") +} + +var _ = Describe("MintWorkerJWT", func() { + var accountSeed string + + BeforeEach(func() { + akp, err := nkeys.CreateAccount() + Expect(err).NotTo(HaveOccurred()) + seed, err := akp.Seed() + Expect(err).NotTo(HaveOccurred()) + accountSeed = string(seed) + }) + + It("mints a JWT with backend worker permissions", func() { + cfg := natsauth.Config{AccountSeed: accountSeed, WorkerJWTTTL: time.Hour} + token, seed, err := cfg.MintWorkerJWT("550e8400-e29b-41d4-a716-446655440000", "backend") + Expect(err).NotTo(HaveOccurred()) + Expect(token).NotTo(BeEmpty()) + Expect(seed).NotTo(BeEmpty()) + + uc, err := jwt.DecodeUserClaims(token) + Expect(err).NotTo(HaveOccurred()) + Expect(uc.Permissions.Sub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.>")) + Expect(uc.Permissions.Pub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.backend.install.*.progress")) + }) + + It("mints agent permissions without backend install subscribe", func() { + cfg := natsauth.Config{AccountSeed: accountSeed} + token, _, err := cfg.MintWorkerJWT("node-1", "agent") + Expect(err).NotTo(HaveOccurred()) + + uc, err := jwt.DecodeUserClaims(token) + Expect(err).NotTo(HaveOccurred()) + Expect(uc.Permissions.Sub.Allow).To(ContainElement("agent.execute")) + for _, subj := range uc.Permissions.Sub.Allow { + Expect(subj).NotTo(ContainSubstring("backend.install")) + } 
+ }) + + It("rejects mint without account seed", func() { + _, _, err := (natsauth.Config{}).MintWorkerJWT("id", "backend") + Expect(err).To(HaveOccurred()) + }) +}) diff --git a/pkg/natsauth/permissions.go b/pkg/natsauth/permissions.go new file mode 100644 index 000000000..8fdb11ef9 --- /dev/null +++ b/pkg/natsauth/permissions.go @@ -0,0 +1,49 @@ +package natsauth + +import "strings" + +// workerSubjectToken mirrors messaging.sanitizeSubjectToken without importing unexported logic. +func workerSubjectToken(nodeID string) string { + r := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-") + return r.Replace(nodeID) +} + +// WorkerPermissions returns NATS pub/sub allow lists for a registered node. +func WorkerPermissions(nodeID, nodeType string) (pubAllow, subAllow []string) { + tok := workerSubjectToken(nodeID) + prefix := "nodes." + tok + + switch nodeType { + case "agent": + // Agent workers consume queue workloads; they must not handle backend.install. + // Keep this list in sync with the subscriptions in core/cli/agent_worker.go. + subAllow = []string{ + "agent.execute", + "jobs.*.cancel", + "jobs.*.progress", + "jobs.*.result", + "jobs.mcp-ci.new", // MCP CI jobs dispatched to agent workers + "mcp.tools.execute", + "mcp.discovery", + prefix + ".backend.stop", // stop events drive MCP session cleanup + "_INBOX.>", + } + pubAllow = []string{ + "agent.>", + "jobs.>", + "_INBOX.>", + } + default: + // Backend worker: lifecycle + file staging on this node only. 
+ subAllow = []string{ + prefix + ".>", + "_INBOX.>", + } + pubAllow = []string{ + prefix + ".backend.install.*.progress", + prefix + ".files.>", + "_INBOX.>", + } + } + return pubAllow, subAllow +} diff --git a/pkg/natsauth/permissions_coverage_test.go b/pkg/natsauth/permissions_coverage_test.go new file mode 100644 index 000000000..05d2fbf0b --- /dev/null +++ b/pkg/natsauth/permissions_coverage_test.go @@ -0,0 +1,134 @@ +package natsauth_test + +import ( + "os" + "regexp" + "strings" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/natsauth" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// subjectMatches implements NATS subject-token matching: "*" matches exactly one +// token and ">" matches one or more trailing tokens. It lets these tests assert +// that a permission allow-list (which uses wildcards) actually covers a concrete +// subject a component publishes/subscribes — the same check the NATS server makes. +func subjectMatches(pattern, subject string) bool { + p := strings.Split(pattern, ".") + s := strings.Split(subject, ".") + for i, tok := range p { + if tok == ">" { + return i < len(s) // ">" must match at least one remaining token + } + if i >= len(s) { + return false + } + if tok != "*" && tok != s[i] { + return false + } + } + return len(p) == len(s) +} + +func anyAllows(allow []string, subject string) bool { + for _, p := range allow { + if subjectMatches(p, subject) { + return true + } + } + return false +} + +var _ = Describe("WorkerPermissions subject coverage", func() { + // A node ID containing NATS-reserved characters exercises the (duplicated) + // sanitizer in pkg/natsauth against the canonical one in core/services/messaging. + // If the two ever diverge, the minted prefix stops matching the real subject + // and these assertions fail — guarding the copy noted in the review. 
+ const nodeID = "host.a 1*b" + + Context("backend worker", func() { + pub, sub := natsauth.WorkerPermissions(nodeID, "backend") + + // Every subject core/services/worker/{lifecycle,file_staging}.go subscribes to. + subscribed := []string{ + messaging.SubjectNodeBackendInstall(nodeID), + messaging.SubjectNodeBackendUpgrade(nodeID), + messaging.SubjectNodeBackendStop(nodeID), + messaging.SubjectNodeBackendDelete(nodeID), + messaging.SubjectNodeBackendList(nodeID), + messaging.SubjectNodeModelUnload(nodeID), + messaging.SubjectNodeModelDelete(nodeID), + messaging.SubjectNodeStop(nodeID), + messaging.SubjectNodeFilesEnsure(nodeID), + messaging.SubjectNodeFilesStage(nodeID), + messaging.SubjectNodeFilesTemp(nodeID), + messaging.SubjectNodeFilesListDir(nodeID), + } + for _, subject := range subscribed { + It("allows subscribing to "+subject, func() { + Expect(anyAllows(sub, subject)).To(BeTrue(), + "backend JWT sub allow-list %v does not cover %s", sub, subject) + }) + } + + It("allows publishing backend.install progress", func() { + subject := messaging.SubjectNodeBackendInstallProgress(nodeID, "op-123") + Expect(anyAllows(pub, subject)).To(BeTrue(), + "backend JWT pub allow-list %v does not cover %s", pub, subject) + }) + }) + + Context("agent worker", func() { + // node_type "agent"; subjects from core/cli/agent_worker.go. 
+ pub, sub := natsauth.WorkerPermissions(nodeID, "agent") + _ = pub + + subscribed := []string{ + messaging.SubjectAgentExecute, // dispatcher (default --agent-subject) + messaging.SubjectMCPToolExecute, // QueueSubscribeReply + messaging.SubjectMCPDiscovery, // QueueSubscribeReply + messaging.SubjectMCPCIJobsNew, // QueueSubscribe — jobs.mcp-ci.new + messaging.SubjectNodeBackendStop(nodeID), // Subscribe — MCP session cleanup + } + for _, subject := range subscribed { + It("allows subscribing to "+subject, func() { + Expect(anyAllows(sub, subject)).To(BeTrue(), + "agent JWT sub allow-list %v does not cover %s — the agent worker subscribes to it", sub, subject) + }) + } + }) +}) + +var allowPubRe = regexp.MustCompile(`--allow-pub "([^"]*)"`) + +var _ = Describe("Documented NATS service-user permissions", func() { + // scripts/nats-auth-setup.sh ships the recommended service (frontend) JWT + // permissions. They must cover every subject the frontend actually publishes, + // or prefix-cache sync (and friends) break once LOCALAI_NATS_REQUIRE_AUTH is on. + const scriptPath = "../../scripts/nats-auth-setup.sh" + + // Representative subjects the frontend publishes on the control plane. + // prefixcache.* is emitted by prefixcache.Sync in core/application/distributed.go. 
+ frontendPublishes := []string{ + messaging.SubjectPrefixCacheObserve, + messaging.SubjectPrefixCacheInvalidate, + messaging.SubjectNodeBackendInstall("node-1"), + messaging.SubjectGalleryProgress("op-1"), + } + + It("cover every subject the frontend publishes", func() { + raw, err := os.ReadFile(scriptPath) + Expect(err).ToNot(HaveOccurred(), "cannot read %s", scriptPath) + m := allowPubRe.FindStringSubmatch(string(raw)) + Expect(m).To(HaveLen(2), "no --allow-pub list found in %s", scriptPath) + allow := strings.Split(m[1], ",") + + for _, subject := range frontendPublishes { + Expect(anyAllows(allow, subject)).To(BeTrue(), + "service-user --allow-pub %v does not cover %s (frontend publishes it)", allow, subject) + } + }) +}) diff --git a/scripts/nats-auth-setup.sh b/scripts/nats-auth-setup.sh new file mode 100755 index 000000000..d279176f5 --- /dev/null +++ b/scripts/nats-auth-setup.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# Generate NATS account + service user JWTs for LocalAI distributed mode. +# +# Requires: nsc (https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc) +# +# Usage: +# ./scripts/nats-auth-setup.sh +# +# Outputs operator/account seeds and a service user JWT suitable for: +# LOCALAI_NATS_ACCOUNT_SEED +# LOCALAI_NATS_SERVICE_JWT +# +# Per-node worker JWTs are minted automatically by the frontend at registration +# when LOCALAI_NATS_ACCOUNT_SEED is set. + +set -euo pipefail + +if ! command -v nsc >/dev/null 2>&1; then + echo "nsc is required. Install from https://github.com/nats-io/nsc/releases" >&2 + exit 1 +fi + +OPERATOR="${NATS_OPERATOR_NAME:-localai-operator}" +ACCOUNT="${NATS_ACCOUNT_NAME:-localai}" +SERVICE_USER="${NATS_SERVICE_USER:-localai-frontend}" + +nsc add operator -n "$OPERATOR" --generate-signing-key +nsc add account -n "$ACCOUNT" +nsc add user -n "$SERVICE_USER" --account "$ACCOUNT" + +# Broad publish for frontend control plane (tighten with custom claims in production). 
+nsc edit user -n "$SERVICE_USER" --account "$ACCOUNT" \ + --allow-pub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,finetune.>" \ + --allow-sub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,_INBOX.>" + +KEYS_DIR="${NATS_KEYS_DIR:-./nats-keys}" +mkdir -p "$KEYS_DIR" +nsc generate creds -a "$ACCOUNT" -n "$SERVICE_USER" -o "$KEYS_DIR" + +ACCOUNT_SEED=$(nsc describe account "$ACCOUNT" -o json | jq -r '.nats.private_key') +SERVICE_JWT=$(cat "$KEYS_DIR/${ACCOUNT}/${SERVICE_USER}.jwt" 2>/dev/null || cat "$KEYS_DIR/${SERVICE_USER}.jwt") + +echo "" +echo "=== LocalAI NATS auth material ===" +echo "LOCALAI_NATS_ACCOUNT_SEED=${ACCOUNT_SEED}" +echo "LOCALAI_NATS_SERVICE_JWT=${SERVICE_JWT}" +echo "" +echo "Configure the NATS server with the generated operator/account JWTs under $KEYS_DIR" +echo "and set LOCALAI_NATS_REQUIRE_AUTH=true on frontends and workers in production." \ No newline at end of file diff --git a/tests/e2e/distributed/nats_jwt_helpers_test.go b/tests/e2e/distributed/nats_jwt_helpers_test.go new file mode 100644 index 000000000..80060ef6a --- /dev/null +++ b/tests/e2e/distributed/nats_jwt_helpers_test.go @@ -0,0 +1,156 @@ +package distributed_test + +import ( + "bytes" + "context" + "fmt" + "strings" + "time" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/natsauth" + "github.com/nats-io/jwt/v2" + "github.com/nats-io/nkeys" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/testcontainers/testcontainers-go" + tcnats "github.com/testcontainers/testcontainers-go/modules/nats" +) + +// JWTTestInfra holds a NATS server configured with JWT auth and minted worker credentials. +type JWTTestInfra struct { + *TestInfra + AccountSeed string + NodeID string + WorkerJWT string + WorkerSeed string +} + +// SetupJWTInfra starts NATS with an in-memory JWT resolver and returns worker credentials +// minted the same way as node registration (pkg/natsauth). 
+func SetupJWTInfra() *JWTTestInfra {
+	GinkgoHelper()
+
+	infra := &JWTTestInfra{TestInfra: &TestInfra{Ctx: context.Background()}}
+
+	// Operator + account JWTs feed the server's MEMORY resolver below; the
+	// account seed is kept so worker JWTs can be minted against that account.
+	operatorJWT, accountJWT, accountSeed, err := jwtResolverMaterial()
+	Expect(err).ToNot(HaveOccurred())
+	infra.AccountSeed = accountSeed
+
+	conf := fmt.Sprintf(`listen: 0.0.0.0:4222
+
+operator: %s
+
+resolver: MEMORY
+resolver_preload: {
+  %s: %s
+}
+`, operatorJWT, accountPublicKeyFromSeed(accountSeed), accountJWT)
+
+	var natsContainer *tcnats.NATSContainer
+	// Override default testcontainers -js: JetStream fails without a system account in JWT mode.
+	natsContainer, err = tcnats.Run(infra.Ctx, "nats:2-alpine",
+		tcnats.WithConfigFile(bytes.NewBufferString(conf)),
+		testcontainers.WithCmd("-c", "/etc/nats.conf"),
+	)
+	Expect(err).ToNot(HaveOccurred())
+	infra.NATSContainer = natsContainer
+
+	// Register teardown as soon as the container is running. A failed Expect
+	// further down aborts this helper, and cleanup registered only at the end
+	// would then never run, leaving the container behind. The nil guards let
+	// the same closure cover the client, which is assigned later.
+	DeferCleanup(func() {
+		if infra.NC != nil {
+			infra.NC.Close()
+		}
+		if infra.NATSContainer != nil {
+			_ = infra.NATSContainer.Terminate(context.Background())
+		}
+	})
+
+	infra.NatsURL, err = infra.NATSContainer.ConnectionString(infra.Ctx)
+	Expect(err).ToNot(HaveOccurred())
+
+	// Fixed node ID so specs can derive the node's subject prefix from it.
+	infra.NodeID = "550e8400-e29b-41d4-a716-446655440000"
+	cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
+	infra.WorkerJWT, infra.WorkerSeed, err = cfg.MintWorkerJWT(infra.NodeID, "backend")
+	Expect(err).ToNot(HaveOccurred())
+
+	// Connect as the minted backend worker; specs publish/subscribe through this client.
+	infra.NC, err = messaging.New(infra.NatsURL, messaging.WithUserJWT(infra.WorkerJWT, infra.WorkerSeed))
+	Expect(err).ToNot(HaveOccurred())
+
+	return infra
+}
+
+// jwtResolverMaterial builds operator + account JWTs for a MEMORY resolver.
+// Follows the NATS JWT tutorial: self-signed account, then operator re-sign, with the
+// account identity key listed as a signing key so MintWorkerJWT can use the account seed.
+func jwtResolverMaterial() (operatorJWT, accountJWT, accountSeed string, err error) { + okp, err := nkeys.CreateOperator() + if err != nil { + return "", "", "", err + } + opk, err := okp.PublicKey() + if err != nil { + return "", "", "", err + } + oc := jwt.NewOperatorClaims(opk) + oc.Name = "localai-test-operator" + oskp, err := nkeys.CreateOperator() + if err != nil { + return "", "", "", err + } + ospk, err := oskp.PublicKey() + if err != nil { + return "", "", "", err + } + oc.SigningKeys.Add(ospk) + operatorJWT, err = oc.Encode(okp) + if err != nil { + return "", "", "", err + } + + akp, err := nkeys.CreateAccount() + if err != nil { + return "", "", "", err + } + seed, err := akp.Seed() + if err != nil { + return "", "", "", err + } + accountSeed = string(seed) + + apk, err := akp.PublicKey() + if err != nil { + return "", "", "", err + } + ac := jwt.NewAccountClaims(apk) + ac.Name = "localai-test-account" + ac.SigningKeys.Add(apk) + accountJWT, err = ac.Encode(akp) + if err != nil { + return "", "", "", err + } + ac, err = jwt.DecodeAccountClaims(accountJWT) + if err != nil { + return "", "", "", err + } + accountJWT, err = ac.Encode(oskp) + if err != nil { + return "", "", "", err + } + return operatorJWT, accountJWT, accountSeed, nil +} + +func accountPublicKeyFromSeed(accountSeed string) string { + akp, err := nkeys.FromSeed([]byte(accountSeed)) + Expect(err).ToNot(HaveOccurred()) + pk, err := akp.PublicKey() + Expect(err).ToNot(HaveOccurred()) + return pk +} + +// nodeSubjectPrefix returns the sanitized nodes.* prefix for a node ID. +func nodeSubjectPrefix(nodeID string) string { + tok := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-").Replace(nodeID) + return "nodes." 
+ tok +} \ No newline at end of file diff --git a/tests/e2e/distributed/nats_jwt_test.go b/tests/e2e/distributed/nats_jwt_test.go new file mode 100644 index 000000000..bf947e472 --- /dev/null +++ b/tests/e2e/distributed/nats_jwt_test.go @@ -0,0 +1,99 @@ +package distributed_test + +import ( + "time" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/natsauth" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("NATS JWT Auth", Label("Distributed", "NatsJWT"), func() { + var infra *JWTTestInfra + + BeforeEach(func() { + infra = SetupJWTInfra() + }) + + It("connects with a minted backend worker JWT and publishes on allowed subjects", func() { + // Backend workers may publish under nodes.{node-id}.files.> (see pkg/natsauth permissions). + subject := nodeSubjectPrefix(infra.NodeID) + ".files.in" + Expect(infra.NC.Publish(subject, map[string]string{"path": "/tmp/model"})).To(Succeed()) + Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed()) + Expect(infra.NC.Conn().IsConnected()).To(BeTrue()) + }) + + It("allows backend subscribe on the node prefix", func() { + wild := nodeSubjectPrefix(infra.NodeID) + ".>" + sub, err := infra.NC.Subscribe(wild, func(_ []byte) {}) + Expect(err).ToNot(HaveOccurred()) + defer func() { _ = sub.Unsubscribe() }() + Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed()) + Expect(infra.NC.Conn().IsConnected()).To(BeTrue()) + }) + + It("rejects anonymous publish on the JWT-enabled server", func() { + anon, err := messaging.New(infra.NatsURL) + Expect(err).ToNot(HaveOccurred()) + defer anon.Close() + + err = anon.Publish("nodes.any.files.x", map[string]string{"x": "1"}) + Expect(err).ToNot(HaveOccurred()) + Expect(anon.Conn().FlushTimeout(2 * time.Second)).To(HaveOccurred()) + }) + + It("denies backend publish to another node's subjects", func() { + other := nodeSubjectPrefix("other-node-id") + ".files.stage" + Expect(infra.NC.Publish(other,
map[string]string{"stage": "nope"})).To(Succeed()) + Eventually(func() error { + _ = infra.NC.Conn().FlushTimeout(500 * time.Millisecond) + return infra.NC.Conn().LastError() + }, "3s", "50ms").Should(HaveOccurred()) + }) + + It("mints agent JWT without backend.install in claims", func() { + cfg := natsauth.Config{AccountSeed: infra.AccountSeed} + token, _, err := cfg.MintWorkerJWT("agent-node-1", "agent") + Expect(err).ToNot(HaveOccurred()) + + claims, err := natsauth.DecodeUserClaims(token) + Expect(err).ToNot(HaveOccurred()) + Expect(claims.Permissions.Sub.Allow).To(ContainElement("agent.execute")) + for _, subj := range claims.Permissions.Sub.Allow { + Expect(subj).NotTo(ContainSubstring("backend.install")) + } + }) + + // Regression guard for the silent permission gaps: decoding the JWT claims + // (above) only proves the agent JWT is *restrictive*, not that it is + // *sufficient*. Stand a real agent connection up against the enforcing + // server and exercise every subscription core/cli/agent_worker.go actually + // makes — a denied SUB now surfaces synchronously via confirmSubscription, + // so a missing allow rule fails this test instead of silently dropping + // backend.stop / MCP-CI deliveries at runtime. + It("lets an agent-minted JWT establish all the subscriptions the agent worker uses", func() { + const nodeID = "agent-node-subs" + cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour} + token, seed, err := cfg.MintWorkerJWT(nodeID, "agent") + Expect(err).ToNot(HaveOccurred()) + + nc, err := messaging.New(infra.NatsURL, messaging.WithUserJWT(token, seed)) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(nc.Close) + + // Mirror core/cli/agent_worker.go exactly. 
+ _, err = nc.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {}) + Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPToolExecute) + + _, err = nc.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {}) + Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPDiscovery) + + _, err = nc.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func([]byte) {}) + Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP CI jobs)", messaging.SubjectMCPCIJobsNew) + + _, err = nc.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func([]byte) {}) + Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP session cleanup)", messaging.SubjectNodeBackendStop(nodeID)) + }) +})