mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 15:26:14 -04:00
feat(distributed): Add NATS JWT authentication and TLS/mTLS options (#10159)
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage

Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED is set, and connect workers with scoped credentials from the register response. Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs. Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode.

Assisted-by: Grok:grok grok-build
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(distributed): correct NATS JWT scoping and harden client auth

The JWT-auth path added in 46467cc7 had several gaps that fail silently under LOCALAI_NATS_REQUIRE_AUTH:

- Agent-worker minted JWTs did not allow the subjects the agent worker actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop), so MCP-CI jobs and backend-stop session cleanup were silently dropped. Scope the agent permission set to those subjects.
- NATS subscription permission violations were swallowed (Subscribe returned a live-but-dead subscription). Confirm subscriptions with a server round-trip so a denial surfaces synchronously, and log async permission errors.
- The backend worker connected anonymously when given a JWT without its paired seed; reject the unpaired credential instead.
- The documented service-user permissions in nats-auth-setup.sh omitted prefixcache.>, which the frontend publishes and subscribes; add it.

Also: add a credential-provider hook to the messaging client (consumed by the follow-up credential-lifecycle change), drop the always-nil error from NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct), and gofmt the feature's files.

Tests: an agent-JWT e2e spec that connects to the enforcing NATS server and exercises every subscription the agent worker makes, plus permission allow-list coverage unit tests.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(distributed): acquire and auto-refresh worker NATS credentials

Workers fetched NATS credentials once at startup, which broke two cases under JWT auth: a worker that registered while still pending admin approval never received a minted JWT (it connected unauthenticated and gave up), and a long-running worker's 24h JWT expired with no way to renew it.

Introduce workerregistry.NATSCredentialManager, built on idempotent re-registration (the frontend preserves the node row and mints a fresh JWT each call):

- Acquire re-registers through admin approval until the node is approved and credentials are minted (or returns the first success when auth is not required, preserving anonymous-NATS behavior).
- RefreshLoop re-registers before the JWT expires (~75% of its lifetime), updating the credentials served to the connection.
- Both are bounded (default 100 attempts / consecutive failures) and return an error on exhaustion, so an unapprovable or unrenewable worker exits non-zero and surfaces the problem instead of hanging or drifting toward an expired credential.

The messaging client gains WithUserJWTProvider, fetching credentials on each (re)connect so the connection transparently adopts a refreshed JWT when the server expires the old one. RegisterFull exposes the approval status and full response; Register delegates to it.

Both the backend worker and the agent worker are wired to this: explicit env credentials are used as-is, minted credentials are acquired-with-wait and refreshed, and a permanent refresh failure shuts the worker down so it restarts and re-acquires.

Tests cover Acquire (wait-through-pending, bounded give-up, context cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry exit) and jwtExpiry decoding. Docs updated in distributed-mode.md.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
committed by GitHub
parent 9d10418593
commit 3a932a9803

Makefile (9 lines changed)
@@ -309,13 +309,20 @@ run-e2e-aio: protogen-go
 	@echo 'Running e2e AIO tests'
 	$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio
 
+# Distributed architecture e2e (PostgreSQL + NATS via testcontainers).
+# Includes NatsJWT specs (JWT-enabled NATS). Requires Docker.
+# VLLMMultinode is excluded here; use test-e2e-vllm-multinode for that.
+test-e2e-distributed: protogen-go
+	@echo 'Running distributed e2e tests (label Distributed, incl. NatsJWT)'
+	$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter='Distributed && !VLLMMultinode' --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e/distributed
+
 # vLLM multi-node DP smoke (CPU). Builds local-ai:tests and the
 # cpu-vllm backend from the current working tree, then drives a
 # head + headless follower via testcontainers-go and asserts a chat
 # completion. BuildKit caches both images, so re-runs only rebuild
 # what changed. The test lives under tests/e2e/distributed and is
 # selected by the VLLMMultinode label so it doesn't run alongside
-# the other distributed-suite tests by default.
+# test-e2e-distributed.
 test-e2e-vllm-multinode: docker-build-e2e extract-backend-vllm protogen-go
 	@echo 'Running e2e vLLM multi-node DP test'
 	LOCALAI_IMAGE=local-ai \
@@ -102,7 +102,12 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
 	xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)
 
 	// Connect to NATS
-	natsClient, err := messaging.New(cfg.Distributed.NatsURL)
+	natsAuth := cfg.Distributed.NatsAuthConfig()
+	if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
+		return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
+	}
+	natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
+	natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
 	if err != nil {
 		return nil, fmt.Errorf("connecting to NATS: %w", err)
 	}
@@ -52,6 +52,15 @@ type AgentWorkerCMD struct {
 	Subject string `env:"LOCALAI_AGENT_SUBJECT" default:"agent.execute" help:"NATS subject for agent execution" group:"distributed"`
 	Queue string `env:"LOCALAI_AGENT_QUEUE" default:"agent-workers" help:"NATS queue group name" group:"distributed"`
 
+	NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (defaults to nats_jwt from registration)" group:"distributed"`
+	NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user seed override (defaults to nats_user_seed from registration)" group:"distributed"`
+	NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"Fallback NATS service JWT when registration does not mint agent JWT" group:"distributed"`
+	NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"Fallback NATS service seed paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
+	NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed to connect" group:"distributed"`
+	NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
+	NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
+	NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
+
 	// Timeouts
 	MCPCIJobTimeout string `env:"LOCALAI_MCP_CI_JOB_TIMEOUT" default:"10m" help:"Timeout for MCP CI job execution" group:"distributed"`
 }
@@ -81,15 +90,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
 		registrationBody["token"] = cmd.RegistrationToken
 	}
 
-	nodeID, apiToken, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
+	// Context cancelled on shutdown — used by registration waits, heartbeat, and
+	// other background goroutines.
+	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
+	defer shutdownCancel()
+
+	// Acquire credentials via (re)registration. When the bus requires auth and no
+	// static fallback is configured, wait through admin approval until the
+	// frontend mints credentials rather than starting unauthenticated.
+	credMgr := workerregistry.NewNATSCredentialManager(
+		func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
+			return regClient.RegisterFull(ctx, registrationBody)
+		},
+		cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
+	)
+	res, err := credMgr.Acquire(shutdownCtx)
 	if err != nil {
 		return fmt.Errorf("registration failed: %w", err)
 	}
+	nodeID := res.ID
 	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
 
 	// Use provisioned API token if none was set
 	if cmd.APIToken == "" {
-		cmd.APIToken = apiToken
+		cmd.APIToken = res.APIToken
 	}
 
 	// Start heartbeat
@@ -98,14 +122,40 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
 		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
 	}
 	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
-	// Context cancelled on shutdown — used by heartbeat and other background goroutines
-	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
-	defer shutdownCancel()
-
 	go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
 
 	// Connect to NATS
-	natsClient, err := messaging.New(cmd.NatsURL)
+	// Resolve NATS credentials with precedence: explicit env override, then
+	// frontend-minted (auto-refreshed before expiry), then service fallback.
+	// Each static source must supply JWT and seed together.
+	natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
+	var natsOpts []messaging.Option
+	switch {
+	case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
+		if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
+			return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
+		}
+		natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
+	case credMgr.HasCredentials():
+		natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
+		go func() {
+			if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
+				xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
+				shutdownCancel()
+			}
+		}()
+	case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
+		if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
+			return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
+		}
+		natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
+	case cmd.NatsRequireAuth:
+		return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
+	}
+	if natsTLS.Enabled() {
+		natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
+	}
+	natsClient, err := messaging.New(cmd.NatsURL, natsOpts...)
 	if err != nil {
 		return fmt.Errorf("connecting to NATS: %w", err)
 	}
@@ -183,17 +233,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
 
 	xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
 
-	// Wait for shutdown
+	// Wait for an OS signal or an internal fatal condition (e.g. NATS
+	// credentials became unrenewable), so the worker restarts and re-acquires
+	// rather than lingering unable to serve.
 	sigCh := make(chan os.Signal, 1)
 	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
-	<-sigCh
+	var runErr error
+	select {
+	case <-sigCh:
+	case <-shutdownCtx.Done():
+		runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
+		xlog.Error("Internal shutdown requested", "error", runErr)
+	}
 
 	xlog.Info("Shutting down agent worker")
 	shutdownCancel() // stop heartbeat loop immediately
 	dispatcher.Stop()
 	mcpTools.CloseAllMCPSessions()
 	regClient.GracefulDeregister(nodeID)
-	return nil
+	return runErr
 }
 
 // handleMCPToolRequest handles a NATS request-reply for MCP tool execution.
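
The credential precedence in the agent worker hunk above (explicit override, then frontend-minted, then service fallback, else fail when auth is required) can be sketched in isolation. This is a simplified illustration under stated assumptions: `creds` and `pickSource` are hypothetical names, and the real code builds `messaging.Option` values rather than returning a label.

```go
package main

import (
	"errors"
	"fmt"
)

// creds is a hypothetical JWT+seed pair used only for this illustration.
type creds struct{ JWT, Seed string }

func (c creds) empty() bool    { return c.JWT == "" && c.Seed == "" }
func (c creds) partial() bool  { return (c.JWT == "") != (c.Seed == "") }
func (c creds) complete() bool { return c.JWT != "" && c.Seed != "" }

// pickSource follows the switch order shown in the diff and reports which
// credential source would be used. A static source with only one half of the
// pair is rejected instead of falling through to the next source.
func pickSource(override, minted, service creds, requireAuth bool) (string, error) {
	switch {
	case !override.empty():
		if override.partial() {
			return "", errors.New("override JWT and seed must be set together")
		}
		return "override", nil
	case minted.complete():
		return "minted", nil
	case !service.empty():
		if service.partial() {
			return "", errors.New("service JWT and seed must be set together")
		}
		return "service", nil
	case requireAuth:
		return "", errors.New("NATS JWT+seed required")
	}
	return "anonymous", nil
}

func main() {
	src, err := pickSource(creds{}, creds{JWT: "j", Seed: "s"}, creds{}, true)
	fmt.Println(src, err)

	_, err = pickSource(creds{JWT: "j"}, creds{}, creds{}, false)
	fmt.Println(err)
}
```

Rejecting a half-configured pair early matters because a JWT without its seed cannot sign the server nonce, so the connection would otherwise fail later or proceed anonymously.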
@@ -159,6 +159,14 @@ type RunCMD struct {
 	DistributedPrefixCacheTTL string `env:"LOCALAI_DISTRIBUTED_PREFIX_CACHE_TTL" help:"Idle-timeout for prefix-cache index entries; also drives the background eviction cadence (every TTL/2). Default 5m." group:"distributed"`
 	BackendInstallTimeout string `env:"LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT" help:"NATS round-trip timeout for backend.install requests sent to worker nodes (default 15m). Increase for slow links pulling multi-GB images." group:"distributed"`
 	BackendUpgradeTimeout string `env:"LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT" help:"NATS round-trip timeout for backend.upgrade requests (default 15m)." group:"distributed"`
+	NatsAccountSeed string `env:"LOCALAI_NATS_ACCOUNT_SEED" help:"NATS account signing seed (SU...) used to mint per-node worker JWTs at registration" group:"distributed"`
+	NatsServiceJWT string `env:"LOCALAI_NATS_SERVICE_JWT" help:"NATS user JWT for the frontend (and agent workers) to publish control-plane messages" group:"distributed"`
+	NatsServiceSeed string `env:"LOCALAI_NATS_SERVICE_SEED" help:"NATS user signing seed (SU...) paired with LOCALAI_NATS_SERVICE_JWT" group:"distributed"`
+	NatsWorkerJWTTTL string `env:"LOCALAI_NATS_WORKER_JWT_TTL" help:"Lifetime of minted per-node NATS JWTs (e.g. 24h, default 24h)" group:"distributed"`
+	NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT credentials (service JWT + account seed) when distributed mode is enabled" group:"distributed"`
+	NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI); use with tls:// in --nats-url" group:"distributed"`
+	NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
+	NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
 	ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
 
 	Version bool
@@ -283,6 +291,34 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 	if r.RegistrationToken != "" {
 		opts = append(opts, config.WithRegistrationToken(r.RegistrationToken))
 	}
+	if r.NatsAccountSeed != "" {
+		opts = append(opts, config.WithNatsAccountSeed(r.NatsAccountSeed))
+	}
+	if r.NatsServiceJWT != "" {
+		opts = append(opts, config.WithNatsServiceJWT(r.NatsServiceJWT))
+	}
+	if r.NatsServiceSeed != "" {
+		opts = append(opts, config.WithNatsServiceSeed(r.NatsServiceSeed))
+	}
+	if r.NatsWorkerJWTTTL != "" {
+		d, err := time.ParseDuration(r.NatsWorkerJWTTTL)
+		if err != nil {
+			return fmt.Errorf("invalid LOCALAI_NATS_WORKER_JWT_TTL %q: %w", r.NatsWorkerJWTTTL, err)
+		}
+		opts = append(opts, config.WithNatsWorkerJWTTTL(d))
+	}
+	if r.NatsRequireAuth {
+		opts = append(opts, config.EnableNatsRequireAuth)
+	}
+	if r.NatsTLSCA != "" {
+		opts = append(opts, config.WithNatsTLSCA(r.NatsTLSCA))
+	}
+	if r.NatsTLSCert != "" {
+		opts = append(opts, config.WithNatsTLSCert(r.NatsTLSCert))
+	}
+	if r.NatsTLSKey != "" {
+		opts = append(opts, config.WithNatsTLSKey(r.NatsTLSKey))
+	}
 	if r.AutoApproveNodes {
 		opts = append(opts, config.EnableAutoApproveNodes)
 	}
@@ -96,7 +96,7 @@ func (r *VLLMDistributed) Run(ctx *cliContext.Context) error {
 		FrontendURL: r.RegisterTo,
 		RegistrationToken: r.RegistrationToken,
 	}
-	nodeID, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
+	nodeID, _, _, _, regErr := regClient.RegisterWithRetry(context.Background(), r.registrationBody(), 10)
 	if regErr != nil {
 		return fmt.Errorf("registering with frontend: %w", regErr)
 	}
@@ -58,65 +58,77 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
 
 // RegisterResponse is the JSON body returned by /api/node/register.
 type RegisterResponse struct {
-	ID       string `json:"id"`
-	APIToken string `json:"api_token,omitempty"`
+	ID           string `json:"id"`
+	Status       string `json:"status,omitempty"` // "pending" until an admin approves the node
+	APIToken     string `json:"api_token,omitempty"`
+	NatsJWT      string `json:"nats_jwt,omitempty"`
+	NatsUserSeed string `json:"nats_user_seed,omitempty"`
 }
 
-// Register sends a single registration request and returns the node ID and
-// (optionally) an auto-provisioned API token.
-func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (string, string, error) {
+// RegisterFull sends a single registration request and returns the full
+// response (node ID, approval status, and optional API token / NATS creds).
+// Re-registration is idempotent: the frontend preserves the node row and mints
+// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
+func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
 	jsonBody, _ := json.Marshal(body)
 	url := c.baseURL() + "/api/node/register"
 
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
 	if err != nil {
-		return "", "", fmt.Errorf("creating request: %w", err)
+		return nil, fmt.Errorf("creating request: %w", err)
 	}
 	req.Header.Set("Content-Type", "application/json")
 	c.setAuth(req)
 
 	resp, err := c.httpClient().Do(req)
 	if err != nil {
-		return "", "", fmt.Errorf("posting to %s: %w", url, err)
+		return nil, fmt.Errorf("posting to %s: %w", url, err)
 	}
 	defer resp.Body.Close()
 
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
-		return "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
+		return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
 	}
 
 	var result RegisterResponse
 	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
-		return "", "", fmt.Errorf("decoding response: %w", err)
+		return nil, fmt.Errorf("decoding response: %w", err)
 	}
-	return result.ID, result.APIToken, nil
+	return &result, nil
 }
 
+// Register sends a single registration request and returns the node ID and
+// optional credentials (API token for agent workers, NATS JWT when configured).
+func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
+	res, err := c.RegisterFull(ctx, body)
+	if err != nil {
+		return "", "", "", "", err
+	}
+	return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
+}
+
 // RegisterWithRetry retries registration with exponential backoff.
-func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (string, string, error) {
+func (c *RegistrationClient) RegisterWithRetry(ctx context.Context, body map[string]any, maxRetries int) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
 	backoff := 2 * time.Second
 	maxBackoff := 30 * time.Second
 
-	var nodeID, apiToken string
-	var err error
-
 	for attempt := 1; attempt <= maxRetries; attempt++ {
-		nodeID, apiToken, err = c.Register(ctx, body)
+		nodeID, apiToken, natsJWT, natsSeed, err = c.Register(ctx, body)
 		if err == nil {
-			return nodeID, apiToken, nil
+			return nodeID, apiToken, natsJWT, natsSeed, nil
 		}
 		if attempt == maxRetries {
-			return "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
+			return "", "", "", "", fmt.Errorf("failed after %d attempts: %w", maxRetries, err)
 		}
 		xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
 		select {
 		case <-ctx.Done():
-			return "", "", ctx.Err()
+			return "", "", "", "", ctx.Err()
 		case <-time.After(backoff):
 		}
 		backoff = min(backoff*2, maxBackoff)
 	}
-	return nodeID, apiToken, err
+	return nodeID, apiToken, natsJWT, natsSeed, err
 }
 
 // Heartbeat sends a single heartbeat POST with the given body.
core/cli/workerregistry/credentials.go (new file, 200 lines)
@@ -0,0 +1,200 @@
+package workerregistry
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/mudler/LocalAI/pkg/natsauth"
+	"github.com/mudler/xlog"
+)
+
+// statusPending mirrors nodes.StatusPending. It is duplicated rather than
+// imported so the lightweight registration client does not pull in the nodes
+// package (and its gorm/DB dependencies).
+const statusPending = "pending"
+
+// defaultMaxAttempts bounds how many times Acquire registers (and how many
+// consecutive times RefreshLoop may fail) before giving up. It is high enough
+// to ride out a slow admin approval or a transient frontend outage, but finite
+// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
+// non-zero exit and the resulting restart) rather than waiting forever.
+const defaultMaxAttempts = 100
+
+// RegisterFunc performs one idempotent registration round-trip.
+type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
+
+// NATSCredentialManager acquires NATS credentials at startup — waiting through
+// admin approval when required — and refreshes them before the minted JWT
+// expires, by re-registering (which mints a fresh JWT). The live NATS
+// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
+// for concurrent use.
+//
+// It addresses two failure modes: a worker that needs credentials but registers
+// while still pending approval (it would otherwise give up and never connect),
+// and a long-running worker whose 24h JWT expires with no way to renew it.
+type NATSCredentialManager struct {
+	register     RegisterFunc
+	requireCreds bool // block until credentials are present (frontend minting in use)
+
+	// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
+	initialBackoff time.Duration
+	maxBackoff     time.Duration
+	maxAttempts    int     // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
+	refreshLead    float64 // refresh once this fraction of the JWT lifetime has elapsed
+	refreshRetry   time.Duration
+	expiryOf       func(jwt string) (time.Time, bool)
+
+	mu     sync.RWMutex
+	jwt    string
+	seed   string
+	nodeID string
+}
+
+// NewNATSCredentialManager builds a manager over register. When requireCreds is
+// true, Acquire blocks until the node is approved and credentials are minted.
+func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
+	return &NATSCredentialManager{
+		register:       register,
+		requireCreds:   requireCreds,
+		initialBackoff: 2 * time.Second,
+		maxBackoff:     30 * time.Second,
+		maxAttempts:    defaultMaxAttempts,
+		refreshLead:    0.75,
+		refreshRetry:   30 * time.Second,
+		expiryOf:       jwtExpiry,
+	}
+}
+
+// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
+// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
+func jwtExpiry(token string) (time.Time, bool) {
+	if token == "" {
+		return time.Time{}, false
+	}
+	uc, err := natsauth.DecodeUserClaims(token)
+	if err != nil || uc.Expires == 0 {
+		return time.Time{}, false
+	}
+	return time.Unix(uc.Expires, 0), true
+}
+
+func (m *NATSCredentialManager) store(res *RegisterResponse) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.nodeID = res.ID
+	if res.NatsJWT != "" && res.NatsUserSeed != "" {
+		m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
+	}
+}
+
+// Current returns the latest NATS credentials (both empty until acquired).
+func (m *NATSCredentialManager) Current() (jwt, seed string) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.jwt, m.seed
+}
+
+// NodeID returns the node ID from the most recent registration.
+func (m *NATSCredentialManager) NodeID() string {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.nodeID
+}
+
+// Provider returns a callback compatible with messaging.WithUserJWTProvider,
+// supplying the current credentials on each (re)connect.
+func (m *NATSCredentialManager) Provider() func() (string, string) {
+	return m.Current
+}
+
+// HasCredentials reports whether complete NATS credentials have been obtained.
+func (m *NATSCredentialManager) HasCredentials() bool {
+	jwt, seed := m.Current()
+	return jwt != "" && seed != ""
+}
+
+// Acquire registers and, when requireCreds is set, keeps re-registering with
+// exponential backoff until the node is approved (status != pending) and
+// credentials are minted. Without requireCreds it returns the first successful
+// response (the historical one-shot behavior, preserved for anonymous NATS).
+func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
+	backoff := m.initialBackoff
+	var lastReason error
+	for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
+		res, err := m.register(ctx)
+		switch {
+		case err != nil:
+			lastReason = err
+			xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
+		case !m.requireCreds:
+			m.store(res)
+			return res, nil
+		case res.Status == statusPending:
+			lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
+			xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
+		case res.NatsJWT == "" || res.NatsUserSeed == "":
+			lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
+			xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
+		default:
+			m.store(res)
+			return res, nil
+		}
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-time.After(backoff):
+		}
+		backoff = min(backoff*2, m.maxBackoff)
+	}
+	return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
+}
+
+// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
+// updating the credentials returned by Current/Provider so the NATS connection
+// adopts them on its next reconnect. It returns nil when ctx is cancelled or
+// when the current credential has no expiry (nothing to refresh), and a non-nil
+// error after maxAttempts consecutive refresh failures — letting the caller
+// exit the worker so it restarts and re-acquires (or surfaces the outage)
+// rather than silently drifting toward an expired, unrenewable JWT.
+func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
+	failures := 0
+	for {
+		jwt, _ := m.Current()
+		exp, ok := m.expiryOf(jwt)
+		if !ok {
+			xlog.Debug("NATS credential has no expiry; refresh loop exiting")
+			return nil
+		}
+		wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-time.After(wait):
+		}
+
+		res, err := m.register(ctx)
+		if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
+			m.store(res)
+			failures = 0
+			xlog.Info("Refreshed NATS credentials", "node", res.ID)
+			continue
+		}
+		failures++
+		if err != nil {
+			xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
+		} else {
+			xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
+		}
+		if m.maxAttempts > 0 && failures >= m.maxAttempts {
+			return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
+		}
+		// Back off before retrying so a persistent failure near expiry does not spin.
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-time.After(m.refreshRetry):
+		}
+	}
+}
core/cli/workerregistry/credentials_test.go (new file, 198 lines)
@@ -0,0 +1,198 @@
+package workerregistry
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/mudler/LocalAI/pkg/natsauth"
+	"github.com/nats-io/nkeys"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestWorkerRegistry(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "WorkerRegistry")
+}
+
+// fakeRegister returns a sequence of canned responses/errors, one per call, and
+// records how many times it was invoked. The last entry repeats once exhausted.
+type fakeRegister struct {
+	mu    sync.Mutex
+	steps []step
+	calls int
+}
+
+type step struct {
+	res *RegisterResponse
+	err error
+}
+
+func (f *fakeRegister) fn() RegisterFunc {
+	return func(context.Context) (*RegisterResponse, error) {
+		f.mu.Lock()
+		defer f.mu.Unlock()
+		i := f.calls
+		f.calls++
+		if i >= len(f.steps) {
+			i = len(f.steps) - 1
+		}
+		return f.steps[i].res, f.steps[i].err
+	}
+}
+
+func (f *fakeRegister) count() int {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	return f.calls
+}
+
+var _ = Describe("NATSCredentialManager", func() {
+	approved := func(jwt, seed string) *RegisterResponse {
+		return &RegisterResponse{ID: "node-1", Status: "healthy", NatsJWT: jwt, NatsUserSeed: seed}
+	}
+	pending := &RegisterResponse{ID: "node-1", Status: "pending"}
+
+	Describe("Acquire (#4 — wait through admin approval)", func() {
+		It("keeps re-registering until the node is approved and credentials are minted", func() {
+			f := &fakeRegister{steps: []step{
+				{res: pending},                     // not approved yet
+				{res: approved("", "")},            // approved but JWT not minted yet
+				{res: approved("jwt-1", "seed-1")}, // finally minted
+			}}
+			m := NewNATSCredentialManager(f.fn(), true /* requireCreds */)
+			m.initialBackoff = time.Millisecond
+			m.maxBackoff = time.Millisecond
+
+			res, err := m.Acquire(context.Background())
+			Expect(err).ToNot(HaveOccurred())
+			Expect(res.ID).To(Equal("node-1"))
+			Expect(f.count()).To(Equal(3))
+
+			jwt, seed := m.Current()
+			Expect(jwt).To(Equal("jwt-1"))
+			Expect(seed).To(Equal("seed-1"))
+			Expect(m.HasCredentials()).To(BeTrue())
+			Expect(m.NodeID()).To(Equal("node-1"))
+		})
+
+		It("returns immediately on the first success when credentials are not required (anonymous NATS)", func() {
+			f := &fakeRegister{steps: []step{{res: pending}}}
+			m := NewNATSCredentialManager(f.fn(), false /* requireCreds */)
+
+			res, err := m.Acquire(context.Background())
+			Expect(err).ToNot(HaveOccurred())
+			Expect(res.Status).To(Equal("pending"))
+			Expect(f.count()).To(Equal(1))
|
||||
Expect(m.HasCredentials()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("aborts when the context is cancelled while waiting for approval", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = 10 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := m.Acquire(ctx)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
})
|
||||
|
||||
It("gives up after a bounded number of attempts so the worker exits and alerts", func() {
|
||||
f := &fakeRegister{steps: []step{{res: pending}}} // never approved
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.initialBackoff = time.Millisecond
|
||||
m.maxBackoff = time.Millisecond
|
||||
m.maxAttempts = 5
|
||||
|
||||
_, err := m.Acquire(context.Background())
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("after 5 attempts"))
|
||||
Expect(err.Error()).To(ContainSubstring("pending admin approval"))
|
||||
Expect(f.count()).To(Equal(5))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RefreshLoop (#5 — renew before the JWT expires)", func() {
|
||||
It("re-registers before expiry and updates the credentials served to new connections", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("jwt-2", "seed-2")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
// jwt-1 expires soon; jwt-2 is long-lived so the loop then idles.
|
||||
m.expiryOf = func(jwt string) (time.Time, bool) {
|
||||
switch jwt {
|
||||
case "jwt-1":
|
||||
return time.Now().Add(40 * time.Millisecond), true
|
||||
case "jwt-2":
|
||||
return time.Now().Add(time.Hour), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() { _ = m.RefreshLoop(ctx) }()
|
||||
|
||||
Eventually(func() string {
|
||||
jwt, _ := m.Current()
|
||||
return jwt
|
||||
}, "2s", "10ms").Should(Equal("jwt-2"))
|
||||
})
|
||||
|
||||
It("returns an error after the bounded number of consecutive failures so the caller can exit", func() {
|
||||
f := &fakeRegister{steps: []step{{err: context.DeadlineExceeded}}} // refresh always fails
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.refreshLead = 0.5
|
||||
m.refreshRetry = time.Millisecond
|
||||
m.maxAttempts = 3
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Now().Add(time.Millisecond), true }
|
||||
m.store(approved("jwt-1", "seed-1"))
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- m.RefreshLoop(context.Background()) }()
|
||||
Eventually(errCh, "2s").Should(Receive(MatchError(ContainSubstring("3 times in a row"))))
|
||||
})
|
||||
|
||||
It("exits promptly when the current credential has no expiry (nothing to refresh)", func() {
|
||||
f := &fakeRegister{steps: []step{{res: approved("x", "y")}}}
|
||||
m := NewNATSCredentialManager(f.fn(), true)
|
||||
m.expiryOf = func(string) (time.Time, bool) { return time.Time{}, false }
|
||||
m.store(approved("static", "seed"))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { _ = m.RefreshLoop(context.Background()); close(done) }()
|
||||
Eventually(done, "1s").Should(BeClosed())
|
||||
Expect(f.count()).To(Equal(0)) // never tried to re-register
|
||||
})
|
||||
})
|
||||
|
||||
Describe("jwtExpiry default", func() {
|
||||
It("decodes the expiry of a real minted worker JWT", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
cfg := natsauth.Config{AccountSeed: string(seed), WorkerJWTTTL: time.Hour}
|
||||
token, _, err := cfg.MintWorkerJWT("node-1", "backend")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
exp, ok := jwtExpiry(token)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(exp).To(BeTemporally("~", time.Now().Add(time.Hour), 2*time.Minute))
|
||||
})
|
||||
|
||||
It("reports no expiry for an empty or undecodable token", func() {
|
||||
_, ok := jwtExpiry("")
|
||||
Expect(ok).To(BeFalse())
|
||||
_, ok = jwtExpiry("not-a-jwt")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -18,6 +20,16 @@ type DistributedConfig struct {
|
||||
RegistrationToken string // --registration-token / LOCALAI_REGISTRATION_TOKEN (required token for node registration)
|
||||
AutoApproveNodes bool // --auto-approve-nodes / LOCALAI_AUTO_APPROVE_NODES (skip admin approval for new workers)
|
||||
|
||||
// NATS JWT auth (optional; see pkg/natsauth and docs/features/distributed-mode.md)
|
||||
NatsAccountSeed string // LOCALAI_NATS_ACCOUNT_SEED — account signing seed to mint per-node worker JWTs
|
||||
NatsServiceJWT string // LOCALAI_NATS_SERVICE_JWT — user JWT for frontends / agent workers
|
||||
NatsServiceSeed string // LOCALAI_NATS_SERVICE_SEED — signing seed paired with service JWT
|
||||
NatsWorkerJWTTTL time.Duration // LOCALAI_NATS_WORKER_JWT_TTL — minted worker JWT lifetime (default 24h)
|
||||
NatsRequireAuth bool // LOCALAI_NATS_REQUIRE_AUTH — fail startup if NATS credentials are missing
|
||||
NatsTLSCA string // LOCALAI_NATS_TLS_CA — PEM file for private CA (server verify)
|
||||
NatsTLSCert string // LOCALAI_NATS_TLS_CERT — client cert for NATS mTLS
|
||||
NatsTLSKey string // LOCALAI_NATS_TLS_KEY — client key paired with NatsTLSCert
|
||||
|
||||
// S3 configuration (used when StorageURL is set)
|
||||
StorageBucket string // --storage-bucket / LOCALAI_STORAGE_BUCKET
|
||||
StorageRegion string // --storage-region / LOCALAI_STORAGE_REGION
|
||||
@@ -80,6 +92,13 @@ func (c DistributedConfig) Validate() error {
|
||||
if c.RegistrationToken == "" {
|
||||
xlog.Warn("distributed mode running without registration token — node endpoints are unprotected")
|
||||
}
|
||||
if err := c.NatsAuthConfig().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.NatsTLSFiles().Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.NatsAuthConfig().WarnIfInsecure(true)
|
||||
// Check for negative durations
|
||||
for name, d := range map[string]time.Duration{
|
||||
FlagMCPToolTimeout: c.MCPToolTimeout,
|
||||
@@ -123,6 +142,52 @@ func WithRegistrationToken(token string) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsAccountSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsAccountSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceJWT(jwt string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceJWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsServiceSeed(seed string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsServiceSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsWorkerJWTTTL(d time.Duration) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsWorkerJWTTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
var EnableNatsRequireAuth = func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsRequireAuth = true
|
||||
}
|
||||
|
||||
func WithNatsTLSCA(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCA = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSCert(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSCert = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithNatsTLSKey(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.NatsTLSKey = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStorageURL(url string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.StorageURL = url
|
||||
@@ -217,6 +282,44 @@ const (
|
||||
// DefaultMaxUploadSize is the default maximum upload body size (50 GB).
|
||||
const DefaultMaxUploadSize int64 = 50 << 30
|
||||
|
||||
// NatsTLSFiles returns NATS TLS/mTLS PEM paths for the messaging client.
|
||||
func (c DistributedConfig) NatsTLSFiles() messaging.TLSFiles {
|
||||
return messaging.TLSFiles{
|
||||
CA: c.NatsTLSCA,
|
||||
Cert: c.NatsTLSCert,
|
||||
Key: c.NatsTLSKey,
|
||||
}
|
||||
}
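The diff calls `Enabled()` and `Validate()` on `messaging.TLSFiles` but does not show their bodies. As a rough illustration only, a path bundle like this is commonly checked as follows. The `tlsFiles` type and both methods below are hypothetical stand-ins, and the pairing rule (client cert and key must be set together) is an assumption, not confirmed behaviour of the real type.

```go
package main

import "errors"

// tlsFiles is a hypothetical stand-in for messaging.TLSFiles.
type tlsFiles struct {
	CA   string // PEM file for a private CA (server verification)
	Cert string // client certificate for mTLS
	Key  string // client key paired with Cert
}

// enabled reports whether any TLS path was configured.
func (t tlsFiles) enabled() bool {
	return t.CA != "" || t.Cert != "" || t.Key != ""
}

// validate rejects a client certificate without its key, or the reverse
// (assumed rule). A CA on its own is accepted: that is server verification
// without mTLS.
func (t tlsFiles) validate() error {
	if (t.Cert == "") != (t.Key == "") {
		return errors.New("NATS TLS client cert and key must be set together")
	}
	return nil
}
```

Under this assumption, a private-CA deployment sets only `LOCALAI_NATS_TLS_CA`, and an mTLS deployment sets all three variables.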
|
||||
|
||||
// NatsMessagingOptions builds messaging client options (JWT + TLS) for distributed components.
|
||||
// Pass explicit userJWT/userSeed when set (e.g. worker overrides); empty uses service JWT from config.
|
||||
func (c DistributedConfig) NatsMessagingOptions(userJWT, userSeed string) []messaging.Option {
|
||||
var opts []messaging.Option
|
||||
jwt, seed := userJWT, userSeed
|
||||
if jwt == "" && seed == "" {
|
||||
auth := c.NatsAuthConfig()
|
||||
jwt, seed = auth.ServiceUserJWT, auth.ServiceUserSeed
|
||||
}
|
||||
if jwt != "" && seed != "" {
|
||||
opts = append(opts, messaging.WithUserJWT(jwt, seed))
|
||||
}
|
||||
if tls := c.NatsTLSFiles(); tls.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(tls))
|
||||
}
|
||||
return opts
|
||||
}
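The credential selection in `NatsMessagingOptions` is easier to follow when reduced to plain strings. The sketch below reproduces only that rule. `pickCredentials` is a hypothetical name, and the real function returns `messaging.Option` values rather than a boolean.

```go
package main

// pickCredentials is a hypothetical reduction of the selection rule above:
// an explicit pair wins; if both explicit values are empty the service pair
// is used; credentials are applied only when both halves are present.
func pickCredentials(userJWT, userSeed, serviceJWT, serviceSeed string) (jwt, seed string, ok bool) {
	jwt, seed = userJWT, userSeed
	if jwt == "" && seed == "" {
		jwt, seed = serviceJWT, serviceSeed
	}
	if jwt == "" || seed == "" {
		// A half-supplied pair never falls back to the service user.
		return "", "", false
	}
	return jwt, seed, true
}
```

The fallback triggers only when both explicit values are empty. A JWT passed without its seed therefore produces no auth option at all, rather than silently switching to the service user.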
|
||||
|
||||
// NatsAuthConfig builds pkg/natsauth settings from distributed configuration.
|
||||
func (c DistributedConfig) NatsAuthConfig() natsauth.Config {
|
||||
return natsauth.Config{
|
||||
AccountSeed: c.NatsAccountSeed,
|
||||
ServiceUserJWT: c.NatsServiceJWT,
|
||||
ServiceUserSeed: c.NatsServiceSeed,
|
||||
WorkerJWTTTL: c.NatsWorkerJWTTTL,
|
||||
RequireAuth: c.NatsRequireAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// BackendInstallTimeoutOrDefault returns the configured timeout or the default.
|
||||
func (c DistributedConfig) BackendInstallTimeoutOrDefault() time.Duration {
|
||||
return cmp.Or(c.BackendInstallTimeout, DefaultBackendInstallTimeout)
|
||||
|
||||
@@ -420,8 +420,9 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
remoteUnloader = d.Router.Unloader()
|
||||
}
|
||||
}
|
||||
-routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret)
-routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken)
+natsCfg := distCfg.NatsAuthConfig()
+routes.RegisterNodeSelfServiceRoutes(e, registry, distCfg.RegistrationToken, distCfg.AutoApproveNodes, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, natsCfg)
+routes.RegisterNodeAdminRoutes(e, registry, remoteUnloader, application.GalleryService(), opcache, application.ApplicationConfig(), adminMiddleware, application.AuthDB(), application.ApplicationConfig().Auth.APIKeyHMACSecret, application.ApplicationConfig().Distributed.RegistrationToken, natsCfg)
|
||||
|
||||
// Distributed SSE routes (job progress + agent events via NATS)
|
||||
if d := application.Distributed(); d != nil {
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -89,7 +90,7 @@ type RegisterNodeRequest struct {
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
// expectedToken is the registration token configured on the frontend (may be empty to disable auth).
|
||||
// autoApprove controls whether new nodes go directly to "healthy" or require admin approval.
|
||||
-func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
+func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req RegisterNodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -217,13 +218,15 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ApproveNodeEndpoint approves a pending node, setting its status to healthy.
|
||||
// For agent workers, it also provisions an API key so they can call the inference API.
|
||||
-func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string) echo.HandlerFunc {
+func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
id := c.Param("id")
|
||||
@@ -253,10 +256,26 @@ func ApproveNodeEndpoint(registry *nodes.NodeRegistry, authDB *gorm.DB, hmacSecr
|
||||
}
|
||||
}
|
||||
|
||||
attachNatsJWT(response, node, natsCfg)
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// attachNatsJWT adds a per-node NATS user JWT to a register/approve response when minting is enabled.
|
||||
func attachNatsJWT(response map[string]any, node *nodes.BackendNode, natsCfg natsauth.Config) {
|
||||
if !natsCfg.CanMintWorkers() || node == nil || node.Status == nodes.StatusPending {
|
||||
return
|
||||
}
|
||||
jwt, seed, err := natsCfg.MintWorkerJWT(node.ID, node.NodeType)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to mint NATS JWT for node", "node", node.Name, "id", node.ID, "error", err)
|
||||
return
|
||||
}
|
||||
response["nats_jwt"] = jwt
|
||||
response["nats_user_seed"] = seed
|
||||
}
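On the worker side, the two keys written by `attachNatsJWT` (`nats_jwt` and `nats_user_seed`) have to be read back from the register or approve response. The following is a minimal sketch of that decoding step. The `registerResponse` struct and `parseRegisterResponse` are hypothetical, and the `id` key is an assumption: only `status`, `nats_jwt` and `nats_user_seed` appear in the diff.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// registerResponse is a hypothetical client-side view of the response body.
// The "id" tag is assumed; the other keys appear in the handler and its tests.
type registerResponse struct {
	ID           string `json:"id"`
	Status       string `json:"status"`
	NatsJWT      string `json:"nats_jwt"`
	NatsUserSeed string `json:"nats_user_seed"`
}

// parseRegisterResponse decodes the body and reports whether a complete
// credential pair was attached. A pending node, or a frontend without an
// account seed, yields hasCreds == false with no error.
func parseRegisterResponse(body []byte) (resp registerResponse, hasCreds bool, err error) {
	if err := json.Unmarshal(body, &resp); err != nil {
		return resp, false, fmt.Errorf("decoding register response: %w", err)
	}
	return resp, resp.NatsJWT != "" && resp.NatsUserSeed != "", nil
}
```

Treating "no credentials" as a normal outcome rather than an error matches the handler: it leaves both keys out while the node is pending or when minting is disabled, so a client has to be prepared to register again later.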
|
||||
|
||||
// provisionAgentWorkerKey creates a dedicated user and API key for an agent worker node.
|
||||
// Returns the plaintext API key on success.
|
||||
func provisionAgentWorkerKey(ctx context.Context, authDB *gorm.DB, registry *nodes.NodeRegistry, node *nodes.BackendNode, hmacSecret string) (string, error) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -63,7 +65,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -74,6 +76,29 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
Expect(resp["status"]).To(Equal(nodes.StatusHealthy))
|
||||
})
|
||||
|
||||
It("returns nats_jwt when account seed is configured", func() {
|
||||
akp, err := nkeys.CreateAccount()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
seed, err := akp.Seed()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
e := echo.New()
|
||||
body := `{"name":"worker-nats","address":"10.0.0.2:50051"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
natsCfg := natsauth.Config{AccountSeed: string(seed)}
|
||||
handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsCfg)
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp["nats_jwt"]).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns 400 when name is missing", func() {
|
||||
e := echo.New()
|
||||
body := `{"address":"10.0.0.1:50051"}`
|
||||
@@ -82,7 +107,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -102,7 +127,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -121,7 +146,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -140,7 +165,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
|
||||
@@ -159,7 +184,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "correct-token", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
@@ -172,7 +197,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
-handler := RegisterNodeEndpoint(registry, "", false, nil, "")
+handler := RegisterNodeEndpoint(registry, "", false, nil, "", natsauth.Config{})
|
||||
Expect(handler(c)).To(Succeed())
|
||||
Expect(rec.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
@@ -195,7 +220,7 @@ var _ = Describe("Node HTTP handlers", func() {
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body1))
|
||||
req1.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
rec1 := httptest.NewRecorder()
|
||||
-handler := RegisterNodeEndpoint(registry, "", true, nil, "")
+handler := RegisterNodeEndpoint(registry, "", true, nil, "", natsauth.Config{})
|
||||
Expect(handler(e.NewContext(req1, rec1))).To(Succeed())
|
||||
Expect(rec1.Code).To(Equal(http.StatusCreated))
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/pkg/natsauth"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ func nodeReadyMiddleware(registry *nodes.NodeRegistry) echo.MiddlewareFunc {
|
||||
// token but do not verify per-node identity. A compromised worker can heartbeat/drain/
|
||||
// deregister other nodes. Future: issue per-node JWT at registration, validate node
|
||||
// identity on subsequent requests (compare :id param with token subject).
|
||||
-func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string) {
+func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, registrationToken string, autoApprove bool, authDB *gorm.DB, hmacSecret string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -44,7 +45,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
tokenAuthMw := nodeTokenAuth(registrationToken)
|
||||
|
||||
node := e.Group("/api/node", readyMw, tokenAuthMw)
|
||||
-node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret))
+node.POST("/register", localai.RegisterNodeEndpoint(registry, registrationToken, autoApprove, authDB, hmacSecret, natsCfg))
|
||||
node.POST("/:id/heartbeat", localai.HeartbeatEndpoint(registry))
|
||||
node.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
node.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
@@ -60,7 +61,7 @@ func RegisterNodeSelfServiceRoutes(e *echo.Echo, registry *nodes.NodeRegistry, r
|
||||
// backend install path (POST /:id/backends/install). That handler enqueues a
|
||||
// ManagementOp on the gallery channel rather than blocking on a NATS reply, so
|
||||
// the browser gets HTTP 202 + jobID immediately instead of waiting up to 3 minutes.
|
||||
-func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string) {
+func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloader nodes.NodeCommandSender, galleryService *galleryop.GalleryService, opcache *galleryop.OpCache, appConfig *config.ApplicationConfig, adminMw echo.MiddlewareFunc, authDB *gorm.DB, hmacSecret string, registrationToken string, natsCfg natsauth.Config) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
|
||||
admin.POST("/:id/drain", localai.DrainNodeEndpoint(registry))
|
||||
admin.POST("/:id/resume", localai.ResumeNodeEndpoint(registry))
|
||||
-admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret))
+admin.POST("/:id/approve", localai.ApproveNodeEndpoint(registry, authDB, hmacSecret, natsCfg))
|
||||
|
||||
// Backend management on workers
|
||||
admin.GET("/:id/backends", localai.ListBackendsOnNodeEndpoint(unloader))
|
||||
|
||||
@@ -2,15 +2,22 @@ package messaging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/sanitize"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// subscribeConfirmTimeout bounds the server round-trip used to detect whether a
|
||||
// subscription was rejected (e.g. by JWT permissions) before returning to the caller.
|
||||
const subscribeConfirmTimeout = 5 * time.Second
|
||||
|
||||
// Client wraps a NATS connection and provides helpers for pub/sub and queue subscriptions.
|
||||
type Client struct {
|
||||
conn *nats.Conn
|
||||
@@ -18,8 +25,13 @@ type Client struct {
|
||||
}
|
||||
|
||||
// New creates a new NATS client with auto-reconnect.
|
||||
-func New(url string) (*Client, error) {
-nc, err := nats.Connect(url,
+func New(url string, opts ...Option) (*Client, error) {
+var cfg connectConfig
+for _, o := range opts {
+o(&cfg)
+}
+
+natsOpts := []nats.Option{
|
||||
nats.RetryOnFailedConnect(true),
|
||||
nats.MaxReconnects(-1),
|
||||
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||
@@ -33,7 +45,60 @@ func New(url string) (*Client, error) {
|
||||
nats.ClosedHandler(func(_ *nats.Conn) {
|
||||
xlog.Info("NATS connection closed")
|
||||
}),
|
||||
)
|
||||
// Surface async errors (notably permission violations) that NATS would
|
||||
// otherwise deliver silently. A subscription the server rejects for a
|
||||
// JWT permission means the worker never receives those messages, so make
|
||||
// it loud rather than letting the feature fail invisibly.
|
||||
nats.ErrorHandler(func(_ *nats.Conn, sub *nats.Subscription, err error) {
|
||||
subject := ""
|
||||
if sub != nil {
|
||||
subject = sub.Subject
|
||||
}
|
||||
if errors.Is(err, nats.ErrPermissionViolation) {
|
||||
xlog.Error("NATS permission violation — check JWT pub/sub allow lists", "subject", subject, "error", err)
|
||||
return
|
||||
}
|
||||
xlog.Warn("NATS async error", "subject", subject, "error", err)
|
||||
}),
|
||||
}
|
||||
switch {
|
||||
case cfg.jwtProvider != nil:
|
||||
// Fetch creds on every (re)connect so a refresh loop can rotate the JWT
|
||||
// before expiry; the server expiring the old JWT triggers a reconnect
|
||||
// that transparently picks up the new one.
|
||||
natsOpts = append(natsOpts, nats.UserJWT(
|
||||
func() (string, error) {
|
||||
jwt, _ := cfg.jwtProvider()
|
||||
if jwt == "" {
|
||||
return "", fmt.Errorf("no NATS user JWT available")
|
||||
}
|
||||
return jwt, nil
|
||||
},
|
||||
func(nonce []byte) ([]byte, error) {
|
||||
_, seed := cfg.jwtProvider()
|
||||
kp, err := nkeys.FromSeed([]byte(seed))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading NATS user seed: %w", err)
|
||||
}
|
||||
defer kp.Wipe()
|
||||
return kp.Sign(nonce)
|
||||
},
|
||||
))
|
||||
case cfg.userJWT != "" && cfg.userSeed != "":
|
||||
natsOpts = append(natsOpts, nats.UserJWTAndSeed(cfg.userJWT, cfg.userSeed))
|
||||
}
|
||||
if cfg.tls.Enabled() {
|
||||
if err := cfg.tls.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsOpts, err := cfg.tls.natsOptions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
natsOpts = append(natsOpts, tlsOpts...)
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(url, natsOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to NATS at %s: %w", sanitize.URL(url), err)
|
||||
}
|
||||
@@ -54,23 +119,67 @@ func (c *Client) Publish(subject string, data any) error {
|
||||
|
||||
// Subscribe creates a subscription on the given subject. All subscribers receive every message.
|
||||
func (c *Client) Subscribe(subject string, handler func([]byte)) (Subscription, error) {
|
||||
-c.mu.RLock()
-defer c.mu.RUnlock()
-return c.conn.Subscribe(subject, func(msg *nats.Msg) {
-handler(msg.Data)
+return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
+return conn.Subscribe(subject, func(msg *nats.Msg) {
+handler(msg.Data)
+})
|
||||
})
|
||||
}
|
||||
|
||||
// QueueSubscribe creates a queue subscription. Within the same queue group,
|
||||
// only one subscriber receives each message (load-balanced).
|
||||
func (c *Client) QueueSubscribe(subject, queue string, handler func([]byte)) (Subscription, error) {
|
||||
-c.mu.RLock()
-defer c.mu.RUnlock()
-return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
-handler(msg.Data)
+return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
+return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
+handler(msg.Data)
+})
|
||||
})
|
||||
}
|
||||
|
||||
// confirmSubscription creates a subscription via mk and forces a server
|
||||
// round-trip so that a permissions violation (which NATS otherwise reports
// only asynchronously) is returned to the caller synchronously. The server
// emits the "-ERR Permissions Violation" for a rejected SUB before the PONG
// that satisfies the flush, so by the time FlushTimeout returns the violation
// is recorded as the connection's last error. Without this, a worker whose JWT
// lacks a subject gets a non-nil subscription that never receives a message,
// turning a permission misconfiguration into a silent failure.
func (c *Client) confirmSubscription(subject string, mk func(*nats.Conn) (*nats.Subscription, error)) (Subscription, error) {
	c.mu.RLock()
	conn := c.conn
	c.mu.RUnlock()
	if conn == nil {
		return nil, fmt.Errorf("subscribe to %s: nil NATS connection", subject)
	}

	sub, err := mk(conn)
	if err != nil {
		return nil, err
	}

	// A failed flush here means we could not round-trip to the server (not yet
	// connected, reconnecting, slow link). RetryOnFailedConnect intentionally
	// buffers subscriptions across that gap, so do NOT fail; keep the
	// subscription and let it replay on (re)connect. A later permission
	// violation is still logged by the async error handler in New.
	if err := conn.FlushTimeout(subscribeConfirmTimeout); err != nil {
		xlog.Debug("Could not confirm NATS subscription (will replay on connect)", "subject", subject, "error", err)
		return sub, nil
	}
	// Flush succeeded, so any permission violation for this SUB has already been
	// recorded as the connection's last error (the server emits it before the
	// PONG). LastError is per-connection; match the exact quoted subject the
	// server echoes ("Subscription to \"<subject>\"") so a stale violation for
	// another subject can't be mis-attributed here.
	if lerr := conn.LastError(); lerr != nil &&
		errors.Is(lerr, nats.ErrPermissionViolation) &&
		strings.Contains(lerr.Error(), `Subscription to "`+subject+`"`) {
		_ = sub.Unsubscribe()
		return nil, fmt.Errorf("subscription to %s denied by NATS server (check JWT sub allow list): %w", subject, lerr)
	}
	return sub, nil
}
|
||||
|
||||
// Request sends a request and waits for a reply (request-reply pattern).
|
||||
// Returns the raw reply data.
|
||||
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
@@ -86,15 +195,15 @@ func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]
|
||||
// SubscribeReply creates a subscription that supports replying to requests.
|
||||
// The handler receives the raw request data and the reply subject.
|
||||
func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -102,15 +211,15 @@ func (c *Client) SubscribeReply(subject string, handler func(data []byte, reply
|
||||
// QueueSubscribeReply creates a queue subscription that supports replying to requests.
|
||||
// Load-balanced across subscribers in the same queue group, with request-reply support.
|
||||
func (c *Client) QueueSubscribeReply(subject, queue string, handler func(data []byte, reply func([]byte))) (Subscription, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
return c.confirmSubscription(subject, func(conn *nats.Conn) (*nats.Subscription, error) {
|
||||
return conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Data, func(replyData []byte) {
|
||||
if msg.Reply != "" {
|
||||
if err := msg.Respond(replyData); err != nil {
|
||||
xlog.Warn("Failed to send NATS reply", "subject", subject, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
34
core/services/messaging/options.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package messaging
|
||||
|
||||
// Option configures NATS client connection behavior.
|
||||
type Option func(*connectConfig)
|
||||
|
||||
// CredentialProvider returns the NATS user JWT and signing seed to use for the
|
||||
// next (re)connect. It is consulted on every connection attempt, so a refresh
|
||||
// loop can rotate credentials before they expire and the connection picks them
|
||||
// up automatically when the server expires the old JWT and triggers a reconnect.
|
||||
type CredentialProvider func() (jwt, seed string)
|
||||
|
||||
type connectConfig struct {
|
||||
userJWT string
|
||||
userSeed string
|
||||
jwtProvider CredentialProvider
|
||||
tls TLSFiles
|
||||
}
|
||||
|
||||
// WithUserJWT connects using a static NATS user JWT and signing seed (UserJWTAndSeed).
|
||||
func WithUserJWT(jwt, seed string) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.userJWT = jwt
|
||||
c.userSeed = seed
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserJWTProvider connects using credentials fetched from provider on each
|
||||
// (re)connect, enabling JWT rotation without dropping the client. Takes
|
||||
// precedence over WithUserJWT when both are set.
|
||||
func WithUserJWTProvider(provider CredentialProvider) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.jwtProvider = provider
|
||||
}
|
||||
}
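The precedence rule stated for `WithUserJWTProvider` can be illustrated with a minimal, self-contained sketch of the functional-options pattern. The lower-case names mirror the ones in this file but are redefined locally, and `credentials` and `build` are hypothetical helpers; how `messaging.New` actually consumes `connectConfig` is not shown in this diff.

```go
package main

import "fmt"

// credentialProvider mirrors the CredentialProvider shape above.
type credentialProvider func() (jwt, seed string)

type connectConfig struct {
	userJWT, userSeed string
	jwtProvider       credentialProvider
}

type option func(*connectConfig)

func withUserJWT(jwt, seed string) option {
	return func(c *connectConfig) { c.userJWT, c.userSeed = jwt, seed }
}

func withUserJWTProvider(p credentialProvider) option {
	return func(c *connectConfig) { c.jwtProvider = p }
}

// credentials is a hypothetical helper returning what a (re)connect attempt
// would use: the provider when set, otherwise the static pair.
func (c *connectConfig) credentials() (jwt, seed string) {
	if c.jwtProvider != nil {
		return c.jwtProvider()
	}
	return c.userJWT, c.userSeed
}

// build applies the options in order to a zero-value config.
func build(opts ...option) *connectConfig {
	c := &connectConfig{}
	for _, o := range opts {
		o(c)
	}
	return c
}

func main() {
	current := "jwt-v1"
	cfg := build(
		withUserJWT("static-jwt", "static-seed"),
		withUserJWTProvider(func() (string, string) { return current, "seed" }),
	)

	jwt, _ := cfg.credentials()
	fmt.Println(jwt) // the provider wins over the static pair

	current = "jwt-v2" // a refresh loop rotating the credential
	jwt, _ = cfg.credentials()
	fmt.Println(jwt) // the next attempt sees the rotated value
}
```

Calling the provider on every attempt, instead of capturing its result once, is what lets a rotated JWT take effect on the next reconnect without rebuilding the client.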
|
||||
68
core/services/messaging/tls.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package messaging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
// TLSFiles holds PEM paths for NATS TLS / mTLS. Cert and key must be set together.
|
||||
// Use tls:// in LOCALAI_NATS_URL; CA and client cert paths are optional extras.
|
||||
type TLSFiles struct {
|
||||
CA string // LOCALAI_NATS_TLS_CA — private CA for server verification
|
||||
Cert string // LOCALAI_NATS_TLS_CERT — client certificate (mTLS)
|
||||
Key string // LOCALAI_NATS_TLS_KEY — client private key
|
||||
}
|
||||
|
||||
// Enabled reports whether any TLS file path is configured.
|
||||
func (f TLSFiles) Enabled() bool {
|
||||
return f.CA != "" || f.Cert != "" || f.Key != ""
|
||||
}
|
||||
|
||||
// Validate checks path pairing and that files exist.
|
||||
func (f TLSFiles) Validate() error {
|
||||
if f.Cert != "" && f.Key == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_TLS_KEY is required when LOCALAI_NATS_TLS_CERT is set")
|
||||
}
|
||||
if f.Key != "" && f.Cert == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_TLS_CERT is required when LOCALAI_NATS_TLS_KEY is set")
|
||||
}
|
||||
for _, path := range []struct {
|
||||
name, path string
|
||||
}{
|
||||
{"LOCALAI_NATS_TLS_CA", f.CA},
|
||||
{"LOCALAI_NATS_TLS_CERT", f.Cert},
|
||||
{"LOCALAI_NATS_TLS_KEY", f.Key},
|
||||
} {
|
||||
if path.path == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(path.path); err != nil {
|
||||
return fmt.Errorf("%s: %w", path.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// natsOptions builds nats-go TLS options. Call Validate first.
|
||||
func (f TLSFiles) natsOptions() ([]nats.Option, error) {
|
||||
if !f.Enabled() {
|
||||
return nil, nil
|
||||
}
|
||||
opts := []nats.Option{nats.Secure()}
|
||||
if f.CA != "" {
|
||||
opts = append(opts, nats.RootCAs(f.CA))
|
||||
}
|
||||
if f.Cert != "" {
|
||||
opts = append(opts, nats.ClientCert(f.Cert, f.Key))
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// WithTLS configures CA and/or client certificate paths for the NATS connection.
|
||||
func WithTLS(files TLSFiles) Option {
|
||||
return func(c *connectConfig) {
|
||||
c.tls = files
|
||||
}
|
||||
}
|
||||
25
core/services/messaging/tls_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TLSFiles", func() {
|
||||
It("requires cert and key together", func() {
|
||||
Expect((messaging.TLSFiles{Cert: "/tmp/c.pem"}).Validate()).To(HaveOccurred())
|
||||
Expect((messaging.TLSFiles{Key: "/tmp/k.pem"}).Validate()).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("validates files exist", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
ca := filepath.Join(dir, "ca.pem")
|
||||
Expect(os.WriteFile(ca, []byte("x"), 0600)).To(Succeed())
|
||||
Expect((messaging.TLSFiles{CA: ca}).Validate()).To(Succeed())
|
||||
})
|
||||
})
|
||||
@@ -60,7 +60,13 @@ type Config struct {
|
||||
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
|
||||
|
||||
// NATS (required)
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
NatsJWT string `env:"LOCALAI_NATS_JWT" help:"NATS user JWT override (normally from registration nats_jwt)" group:"distributed"`
|
||||
NatsUserSeed string `env:"LOCALAI_NATS_USER_SEED" help:"NATS user signing seed override (normally from registration nats_user_seed)" group:"distributed"`
|
||||
NatsRequireAuth bool `env:"LOCALAI_NATS_REQUIRE_AUTH" default:"false" help:"Require NATS JWT+seed from registration or env" group:"distributed"`
|
||||
NatsTLSCA string `env:"LOCALAI_NATS_TLS_CA" type:"existingfile" help:"PEM file for NATS server CA (private PKI)" group:"distributed"`
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
|
||||
// S3 storage for distributed file transfer
|
||||
StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"`
|
||||
|
||||
33
core/services/worker/nats_connect.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
)
|
||||
|
||||
// connectNATS opens a NATS client using JWT+seed from env or registration (env wins).
func connectNATS(url, envJWT, envSeed, registerJWT, registerSeed string, requireAuth bool, tls messaging.TLSFiles) (*messaging.Client, error) {
	// Env credentials take precedence, but only fall back to registration when
	// the env supplied neither half; otherwise a JWT set without its seed (or
	// vice-versa) would be silently completed from a different source.
	jwt, seed := envJWT, envSeed
	if jwt == "" && seed == "" {
		jwt, seed = registerJWT, registerSeed
	}
	// A JWT without its paired seed (or vice-versa) is a misconfiguration: refuse
	// rather than silently connecting anonymously, which would look authenticated.
	if (jwt == "") != (seed == "") {
		return nil, fmt.Errorf("NATS JWT and seed must be provided together (got JWT set=%t, seed set=%t)", jwt != "", seed != "")
	}
	var opts []messaging.Option
	if jwt != "" && seed != "" {
		opts = append(opts, messaging.WithUserJWT(jwt, seed))
	} else if requireAuth {
		return nil, fmt.Errorf("NATS JWT+seed required: set LOCALAI_NATS_JWT/LOCALAI_NATS_USER_SEED or enable frontend minting")
	}
	if tls.Enabled() {
		opts = append(opts, messaging.WithTLS(tls))
	}
	return messaging.New(url, opts...)
}
|
||||
29
core/services/worker/nats_connect_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("connectNATS", func() {
|
||||
It("requires JWT when requireAuth is set and no credentials are provided", func() {
|
||||
_, err := connectNATS("nats://127.0.0.1:4222", "", "", "", "", true, messaging.TLSFiles{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("NATS JWT+seed required"))
|
||||
})
|
||||
|
||||
// A JWT supplied without its paired seed (or vice-versa) is an operator
// misconfiguration. Before this fix, connectNATS silently dropped the unpaired
// credential and connected anonymously, so the operator believed the link was
// authenticated when it was not. It must refuse instead.
|
||||
It("rejects a JWT supplied without a seed instead of connecting anonymously", func() {
|
||||
client, err := connectNATS("nats://127.0.0.1:4222", "jwt-without-seed", "", "", "", false, messaging.TLSFiles{})
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
Expect(err).To(HaveOccurred(),
|
||||
"connectNATS should reject an unpaired JWT rather than silently connecting anonymously")
|
||||
})
|
||||
})
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/cli/workerregistry"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
@@ -67,10 +68,63 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
RegistrationToken: cfg.RegistrationToken,
|
||||
}
|
||||
|
||||
// Context cancelled on shutdown — used by registration waits, heartbeat, and
|
||||
// other background goroutines.
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
registrationBody := cfg.registrationBody()
|
||||
nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", err)
|
||||
natsTLS := messaging.TLSFiles{CA: cfg.NatsTLSCA, Cert: cfg.NatsTLSCert, Key: cfg.NatsTLSKey}
|
||||
|
||||
// Resolve how to connect to NATS. Static env credentials cannot be re-minted,
|
||||
// so register once and use them directly. Otherwise the credential manager
|
||||
// (re)registers to obtain credentials — waiting through admin approval — and
|
||||
// refreshes them before the minted JWT expires, so the connection survives
|
||||
// expiry via a transparent reconnect.
|
||||
var (
|
||||
nodeID string
|
||||
connectNats func() (*messaging.Client, error)
|
||||
)
|
||||
if cfg.NatsJWT != "" || cfg.NatsUserSeed != "" {
|
||||
nid, _, _, _, regErr := regClient.RegisterWithRetry(shutdownCtx, registrationBody, 10)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", regErr)
|
||||
}
|
||||
nodeID = nid
|
||||
connectNats = func() (*messaging.Client, error) {
|
||||
return connectNATS(cfg.NatsURL, cfg.NatsJWT, cfg.NatsUserSeed, "", "", cfg.NatsRequireAuth, natsTLS)
|
||||
}
|
||||
} else {
|
||||
credMgr := workerregistry.NewNATSCredentialManager(
|
||||
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
|
||||
return regClient.RegisterFull(ctx, registrationBody)
|
||||
},
|
||||
cfg.NatsRequireAuth,
|
||||
)
|
||||
res, regErr := credMgr.Acquire(shutdownCtx)
|
||||
if regErr != nil {
|
||||
return fmt.Errorf("failed to register with frontend: %w", regErr)
|
||||
}
|
||||
nodeID = res.ID
|
||||
connectNats = func() (*messaging.Client, error) {
|
||||
var opts []messaging.Option
|
||||
if credMgr.HasCredentials() {
|
||||
opts = append(opts, messaging.WithUserJWTProvider(credMgr.Provider()))
|
||||
}
|
||||
if natsTLS.Enabled() {
|
||||
opts = append(opts, messaging.WithTLS(natsTLS))
|
||||
}
|
||||
client, cerr := messaging.New(cfg.NatsURL, opts...)
|
||||
if cerr == nil && credMgr.HasCredentials() {
|
||||
go func() {
|
||||
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
|
||||
xlog.Error("NATS credential refresh permanently failed; shutting down worker", "error", err)
|
||||
shutdownCancel()
|
||||
}
|
||||
}()
|
||||
}
|
||||
return client, cerr
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cfg.RegisterTo)
|
||||
@@ -79,9 +133,6 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cfg.HeartbeatInterval, "error", err)
|
||||
}
|
||||
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
|
||||
// Context cancelled on shutdown — used by heartbeat and other background goroutines
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
defer shutdownCancel()
|
||||
|
||||
// Start HTTP file transfer server
|
||||
httpAddr := cfg.resolveHTTPAddr()
|
||||
@@ -94,7 +145,7 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
|
||||
// Connect to NATS
|
||||
xlog.Info("Connecting to NATS", "url", sanitize.URL(cfg.NatsURL))
|
||||
natsClient, err := messaging.New(cfg.NatsURL)
|
||||
natsClient, err := connectNats()
|
||||
if err != nil {
|
||||
nodes.ShutdownFileTransferServer(httpServer)
|
||||
return fmt.Errorf("connecting to NATS: %w", err)
|
||||
@@ -154,12 +205,21 @@ func Run(ctx *cliContext.Context, cfg *Config) error {
|
||||
}
|
||||
|
||||
xlog.Info("Worker ready, waiting for backend.install events")
|
||||
<-sigCh
|
||||
// Exit on an OS signal or on an internal fatal condition (e.g. NATS
|
||||
// credentials became unrenewable), so the worker restarts and re-acquires
|
||||
// rather than lingering unable to serve.
|
||||
var runErr error
|
||||
select {
|
||||
case <-sigCh:
|
||||
case <-shutdownCtx.Done():
|
||||
runErr = fmt.Errorf("worker shutting down: NATS credentials unavailable")
|
||||
xlog.Error("Internal shutdown requested", "error", runErr)
|
||||
}
|
||||
|
||||
xlog.Info("Shutting down worker")
|
||||
shutdownCancel() // stop heartbeat loop immediately
|
||||
regClient.GracefulDeregister(nodeID)
|
||||
supervisor.stopAllBackends()
|
||||
nodes.ShutdownFileTransferServer(httpServer)
|
||||
return nil
|
||||
return runErr
|
||||
}
|
||||
|
||||
@@ -71,6 +71,50 @@ The frontend is a standard LocalAI instance with distributed mode enabled. These
|
||||
| `--backend-upgrade-timeout` | `LOCALAI_NATS_BACKEND_UPGRADE_TIMEOUT` | `15m` | Same as the install timeout, applied to backend upgrades (force-reinstall). |
|
||||
| `--expose-node-header` | `LOCALAI_EXPOSE_NODE_HEADER` | `false` | When enabled, inference responses carry an `X-LocalAI-Node` header with the ID of the worker node that served the request. Coverage spans the OpenAI-compatible endpoints (chat completions, completions, embeddings, audio transcriptions, audio speech / TTS, image generations, image inpainting), the Jina rerank endpoint (`/v1/rerank`), the VAD endpoints (`/v1/vad`, `/vad`), and the Anthropic Messages (`/v1/messages`) and Ollama (`/api/chat`, `/api/generate`, `/api/embed`) shims. Useful for debugging, observability and load-balancer attribution. Off by default: the node ID reveals internal cluster topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency for the same model across multiple replicas, the header may reflect a recent routing decision rather than this exact request's. Acceptable for observability and debugging. |
|
||||
|
||||
### NATS JWT authentication (recommended for production)
|
||||
|
||||
By default, NATS connections are anonymous: any client that can reach port `4222` may publish control-plane subjects such as `nodes.<id>.backend.install`. Enable JWT auth to scope workers to their own node subjects and give the frontend a dedicated service credential.
|
||||
|
||||
| Flag | Env Var | Description |
|------|---------|-------------|
| `--nats-account-seed` | `LOCALAI_NATS_ACCOUNT_SEED` | Account signing seed (`SA...`). The frontend mints a per-node user JWT at registration (`nats_jwt` in the register response). |
| `--nats-service-jwt` | `LOCALAI_NATS_SERVICE_JWT` | User JWT for the frontend (and optional fallback for agent workers) to publish install/upgrade and related subjects. |
| `--nats-service-seed` | `LOCALAI_NATS_SERVICE_SEED` | User signing seed (`SU...`) paired with the service JWT. |
| `--nats-worker-jwt-ttl` | `LOCALAI_NATS_WORKER_JWT_TTL` | Lifetime of minted worker JWTs (default `24h`). |
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | Fail startup if JWT credentials are missing when distributed mode is enabled. |
|
||||
|
||||
### NATS TLS / mTLS (optional)
|
||||
|
||||
Use `tls://` in `--nats-url` / `LOCALAI_NATS_URL` for encrypted transport. When the server uses a private CA or requires client certificates, set:
|
||||
|
||||
| Flag | Env Var | Description |
|------|---------|-------------|
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | PEM file to verify the NATS server (private CA) |
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | Client certificate for NATS mTLS |
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | Client private key (required with `--nats-tls-cert`) |
|
||||
|
||||
The same env vars apply to backend workers and `local-ai agent-worker`. If the server cert is already trusted by the OS, `tls://` alone is enough.
|
||||
|
||||
**Worker register response** (when minting is enabled and the node is approved):
|
||||
|
||||
```json
{
  "id": "…",
  "nats_jwt": "eyJ…",
  "nats_user_seed": "SU…"
}
```
|
||||
|
||||
Workers connect with that JWT and seed automatically. The seed is shown only once, in the register response, so store it securely if you persist it. Override with `LOCALAI_NATS_JWT` / `LOCALAI_NATS_USER_SEED` if needed. Set `LOCALAI_NATS_REQUIRE_AUTH=true` on workers when the bus requires credentials.
|
||||
|
||||
When `LOCALAI_NATS_REQUIRE_AUTH=true` and no static credentials are provided, a worker that registers while still **pending admin approval** keeps re-registering (with backoff) until an admin approves it and the frontend mints its JWT — it does not start unauthenticated. This retry is **bounded**: if the node is never approved (or no credentials are minted) after a large number of attempts, the worker exits non-zero so the failure is visible (a crash-looping or failed worker) rather than hanging silently. Minted worker JWTs are also **refreshed automatically** before they expire (the worker re-registers at ~75% of the JWT lifetime), so long-running workers survive past `LOCALAI_NATS_WORKER_JWT_TTL`; the NATS connection picks up the new JWT on its next reconnect. If refresh fails persistently, the worker exits (to restart and re-acquire) rather than drifting toward an expired, unrenewable JWT. Statically configured (`LOCALAI_NATS_JWT`) and service (`LOCALAI_NATS_SERVICE_JWT`) credentials are used as-is and not refreshed.
|
||||
|
||||
Generate operator/account material with [`scripts/nats-auth-setup.sh`](https://github.com/mudler/LocalAI/blob/master/scripts/nats-auth-setup.sh) (requires [nsc](https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)). Configure the NATS server with account resolver JWTs before enabling `LOCALAI_NATS_REQUIRE_AUTH`.
|
||||
|
||||
{{% notice note %}}
`LOCALAI_AUTH` (HTTP users/sessions) and NATS JWTs are separate: end-user API keys do not connect to NATS. HTTP registration still uses `LOCALAI_REGISTRATION_TOKEN`.
{{% /notice %}}
|
||||
|
||||
### Optional: S3 Object Storage
|
||||
|
||||
For multi-host deployments where workers don't share a filesystem, S3-compatible storage enables distributed file transfer (model files, configs):
|
||||
@@ -134,6 +178,12 @@ local-ai worker \
|
||||
| `--registration-token` | `LOCALAI_REGISTRATION_TOKEN` | *(empty)* | Token to authenticate with the frontend |
| `--heartbeat-interval` | `LOCALAI_HEARTBEAT_INTERVAL` | `10s` | Interval between heartbeat pings |
| `--nats-url` | `LOCALAI_NATS_URL` | *(required)* | NATS URL for backend installation and file staging |
| `--nats-jwt` | `LOCALAI_NATS_JWT` | *(empty)* | Optional override for the `nats_jwt` returned at registration |
| `--nats-user-seed` | `LOCALAI_NATS_USER_SEED` | *(empty)* | Optional override for `nats_user_seed` from registration |
| `--nats-require-auth` | `LOCALAI_NATS_REQUIRE_AUTH` | `false` | Require NATS JWT+seed (from registration or env) |
| `--nats-tls-ca` | `LOCALAI_NATS_TLS_CA` | *(empty)* | PEM file for NATS server CA |
| `--nats-tls-cert` | `LOCALAI_NATS_TLS_CERT` | *(empty)* | Client certificate for NATS mTLS |
| `--nats-tls-key` | `LOCALAI_NATS_TLS_KEY` | *(empty)* | Client private key for NATS mTLS |
| `--backends-path` | `LOCALAI_BACKENDS_PATH` | `./backends` | Path to backend binaries |
| `--models-path` | `LOCALAI_MODELS_PATH` | `./models` | Path to model files |
|
||||
|
||||
|
||||
3
go.mod
@@ -41,7 +41,9 @@ require (
|
||||
github.com/mudler/go-processmanager v0.1.1
|
||||
github.com/mudler/memory v0.0.0-20260406210934-424c1ecf2cf8
|
||||
github.com/mudler/xlog v0.0.6
|
||||
github.com/nats-io/jwt/v2 v2.7.4
|
||||
github.com/nats-io/nats.go v1.52.0
|
||||
github.com/nats-io/nkeys v0.4.15
|
||||
github.com/ollama/ollama v0.20.4
|
||||
github.com/onsi/ginkgo/v2 v2.29.0
|
||||
github.com/onsi/gomega v1.41.0
|
||||
@@ -134,7 +136,6 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/moby/moby/api v1.54.2 // indirect
|
||||
github.com/moby/moby/client v0.4.1 // indirect
|
||||
github.com/nats-io/nkeys v0.4.15 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/secure-systems-lab/go-securesystemslib v0.9.1 // indirect
|
||||
|
||||
2
go.sum
@@ -1016,6 +1016,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A=
|
||||
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
|
||||
github.com/nats-io/jwt/v2 v2.7.4 h1:jXFuDDxs/GQjGDZGhNgH4tXzSUK6WQi2rsj4xmsNOtI=
|
||||
github.com/nats-io/jwt/v2 v2.7.4/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA=
|
||||
github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc=
|
||||
github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
|
||||
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
|
||||
|
||||
66
pkg/natsauth/config.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// DefaultWorkerJWTTTL is how long a worker may use a minted NATS user JWT before re-registering.
|
||||
const DefaultWorkerJWTTTL = 24 * time.Hour
|
||||
|
||||
// Config holds NATS JWT authentication settings for distributed mode.
|
||||
type Config struct {
|
||||
// AccountSeed is the NATS account signing seed (SA...). Used to mint per-node worker JWTs.
|
||||
AccountSeed string
|
||||
// ServiceUserJWT is a pre-generated user JWT for frontends and agent workers (broad publish).
|
||||
ServiceUserJWT string
|
||||
// ServiceUserSeed is the signing seed (SU...) paired with ServiceUserJWT.
|
||||
ServiceUserSeed string
|
||||
// WorkerJWTTTL sets expiry on minted worker JWTs. Zero uses DefaultWorkerJWTTTL.
|
||||
WorkerJWTTTL time.Duration
|
||||
// RequireAuth rejects anonymous NATS when true (Validate then expects ServiceUserJWT, ServiceUserSeed, and AccountSeed).
|
||||
RequireAuth bool
|
||||
}
|
||||
|
||||
// Enabled reports whether any NATS credential material is configured.
|
||||
func (c Config) Enabled() bool {
|
||||
return c.AccountSeed != "" || c.ServiceUserJWT != ""
|
||||
}
|
||||
|
||||
// CanMintWorkers reports whether per-node JWTs can be issued at registration.
|
||||
func (c Config) CanMintWorkers() bool {
|
||||
return c.AccountSeed != ""
|
||||
}
|
||||
|
||||
// WorkerTTL returns the configured worker JWT lifetime.
|
||||
func (c Config) WorkerTTL() time.Duration {
|
||||
if c.WorkerJWTTTL > 0 {
|
||||
return c.WorkerJWTTTL
|
||||
}
|
||||
return DefaultWorkerJWTTTL
|
||||
}
|
||||
|
||||
// Validate checks consistency when distributed NATS auth is required.
|
||||
func (c Config) Validate() error {
|
||||
if !c.RequireAuth {
|
||||
return nil
|
||||
}
|
||||
if c.ServiceUserJWT == "" || c.ServiceUserSeed == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
|
||||
}
|
||||
if c.AccountSeed == "" {
|
||||
return fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH is set but LOCALAI_NATS_ACCOUNT_SEED is empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WarnIfInsecure logs when distributed NATS is reachable without credentials.
|
||||
func (c Config) WarnIfInsecure(distributed bool) {
|
||||
if !distributed || c.Enabled() {
|
||||
return
|
||||
}
|
||||
xlog.Warn("NATS is used without JWT credentials — any client on the bus can publish backend.install. " +
|
||||
"Set LOCALAI_NATS_ACCOUNT_SEED + LOCALAI_NATS_SERVICE_JWT (see docs/features/distributed-mode.md).")
|
||||
}
|
||||
16
pkg/natsauth/decode.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
)
|
||||
|
||||
// DecodeUserClaims decodes a minted worker JWT for tests and diagnostics.
|
||||
func DecodeUserClaims(token string) (*jwt.UserClaims, error) {
|
||||
uc, err := jwt.DecodeUserClaims(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("natsauth: decode user JWT: %w", err)
|
||||
}
|
||||
return uc, nil
|
||||
}
|
||||
59
pkg/natsauth/mint.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package natsauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// MintWorkerJWT creates a signed NATS user JWT and user seed scoped to nodeID and nodeType.
|
||||
// The seed is returned once at registration so the worker can sign NATS connections.
|
||||
func (c Config) MintWorkerJWT(nodeID, nodeType string) (userJWT, userSeed string, err error) {
|
||||
if c.AccountSeed == "" {
|
||||
return "", "", fmt.Errorf("natsauth: account seed not configured")
|
||||
}
|
||||
if nodeID == "" {
|
||||
return "", "", fmt.Errorf("natsauth: node ID is required")
|
||||
}
|
||||
|
||||
accountKP, err := nkeys.FromSeed([]byte(c.AccountSeed))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: invalid account seed: %w", err)
|
||||
}
|
||||
|
||||
userKP, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: create user key: %w", err)
|
||||
}
|
||||
seedBytes, err := userKP.Seed()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: user seed: %w", err)
|
||||
}
|
||||
|
||||
accountPub, err := accountKP.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: account public key: %w", err)
|
||||
}
|
||||
userPub, err := userKP.PublicKey()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: user public key: %w", err)
|
||||
}
|
||||
|
||||
pubAllow, subAllow := WorkerPermissions(nodeID, nodeType)
|
||||
|
||||
uc := jwt.NewUserClaims(userPub)
|
||||
uc.Name = fmt.Sprintf("localai-%s-%s", nodeType, workerSubjectToken(nodeID))
|
||||
uc.IssuerAccount = accountPub
|
||||
uc.Expires = time.Now().Add(c.WorkerTTL()).Unix()
|
||||
|
||||
uc.Permissions.Pub.Allow = pubAllow
|
||||
uc.Permissions.Sub.Allow = subAllow
|
||||
|
||||
token, err := uc.Encode(accountKP)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("natsauth: encode user JWT: %w", err)
|
||||
}
|
||||
return token, string(seedBytes), nil
|
||||
}
|
||||
60
pkg/natsauth/mint_test.go
Normal file
@@ -0,0 +1,60 @@
package natsauth_test

import (
	"testing"
	"time"

	"github.com/mudler/LocalAI/pkg/natsauth"
	"github.com/nats-io/jwt/v2"
	"github.com/nats-io/nkeys"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestNatsAuth(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "NatsAuth")
}

var _ = Describe("MintWorkerJWT", func() {
	var accountSeed string

	BeforeEach(func() {
		akp, err := nkeys.CreateAccount()
		Expect(err).NotTo(HaveOccurred())
		seed, err := akp.Seed()
		Expect(err).NotTo(HaveOccurred())
		accountSeed = string(seed)
	})

	It("mints a JWT with backend worker permissions", func() {
		cfg := natsauth.Config{AccountSeed: accountSeed, WorkerJWTTTL: time.Hour}
		token, seed, err := cfg.MintWorkerJWT("550e8400-e29b-41d4-a716-446655440000", "backend")
		Expect(err).NotTo(HaveOccurred())
		Expect(token).NotTo(BeEmpty())
		Expect(seed).NotTo(BeEmpty())

		uc, err := jwt.DecodeUserClaims(token)
		Expect(err).NotTo(HaveOccurred())
		Expect(uc.Permissions.Sub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.>"))
		Expect(uc.Permissions.Pub.Allow).To(ContainElement("nodes.550e8400-e29b-41d4-a716-446655440000.backend.install.*.progress"))
	})

	It("mints agent permissions without backend install subscribe", func() {
		cfg := natsauth.Config{AccountSeed: accountSeed}
		token, _, err := cfg.MintWorkerJWT("node-1", "agent")
		Expect(err).NotTo(HaveOccurred())

		uc, err := jwt.DecodeUserClaims(token)
		Expect(err).NotTo(HaveOccurred())
		Expect(uc.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
		for _, subj := range uc.Permissions.Sub.Allow {
			Expect(subj).NotTo(ContainSubstring("backend.install"))
		}
	})

	It("rejects mint without account seed", func() {
		_, _, err := (natsauth.Config{}).MintWorkerJWT("id", "backend")
		Expect(err).To(HaveOccurred())
	})
})
49
pkg/natsauth/permissions.go
Normal file
@@ -0,0 +1,49 @@
package natsauth

import "strings"

// workerSubjectToken mirrors messaging.sanitizeSubjectToken without importing unexported logic.
func workerSubjectToken(nodeID string) string {
	r := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-")
	return r.Replace(nodeID)
}

// WorkerPermissions returns NATS pub/sub allow lists for a registered node.
func WorkerPermissions(nodeID, nodeType string) (pubAllow, subAllow []string) {
	tok := workerSubjectToken(nodeID)
	prefix := "nodes." + tok

	switch nodeType {
	case "agent":
		// Agent workers consume queue workloads; they must not handle backend.install.
		// Keep this list in sync with the subscriptions in core/cli/agent_worker.go.
		subAllow = []string{
			"agent.execute",
			"jobs.*.cancel",
			"jobs.*.progress",
			"jobs.*.result",
			"jobs.mcp-ci.new", // MCP CI jobs dispatched to agent workers
			"mcp.tools.execute",
			"mcp.discovery",
			prefix + ".backend.stop", // stop events drive MCP session cleanup
			"_INBOX.>",
		}
		pubAllow = []string{
			"agent.>",
			"jobs.>",
			"_INBOX.>",
		}
	default:
		// Backend worker: lifecycle + file staging on this node only.
		subAllow = []string{
			prefix + ".>",
			"_INBOX.>",
		}
		pubAllow = []string{
			prefix + ".backend.install.*.progress",
			prefix + ".files.>",
			"_INBOX.>",
		}
	}
	return pubAllow, subAllow
}
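To make the scoping concrete, here is a minimal standalone sketch of how a sanitized node ID turns into a subject prefix and how a wildcard allow list confines a backend worker to its own node. `sanitizeToken` copies the replacer from `workerSubjectToken`; `subjectMatches` and `allowed` are illustrative helpers that approximate NATS wildcard rules (`*` for exactly one token, `>` for one or more trailing tokens), not the server's actual implementation.

```go
package main

import (
	"fmt"
	"strings"
)

// sanitizeToken replaces NATS-reserved characters and whitespace with "-",
// using the same replacement pairs as workerSubjectToken.
func sanitizeToken(nodeID string) string {
	r := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-")
	return r.Replace(nodeID)
}

// subjectMatches reports whether a wildcard pattern covers a concrete subject.
func subjectMatches(pattern, subject string) bool {
	p := strings.Split(pattern, ".")
	s := strings.Split(subject, ".")
	for i, tok := range p {
		if tok == ">" {
			return i < len(s) // ">" needs at least one remaining token
		}
		if i >= len(s) || (tok != "*" && tok != s[i]) {
			return false
		}
	}
	return len(p) == len(s)
}

// allowed reports whether any pattern in the allow list covers the subject.
func allowed(allow []string, subject string) bool {
	for _, p := range allow {
		if subjectMatches(p, subject) {
			return true
		}
	}
	return false
}

func main() {
	prefix := "nodes." + sanitizeToken("host.a 1*b")
	// Same shape as the backend pubAllow list above.
	pubAllow := []string{
		prefix + ".backend.install.*.progress",
		prefix + ".files.>",
		"_INBOX.>",
	}

	fmt.Println(prefix)                                            // nodes.host-a-1-b
	fmt.Println(allowed(pubAllow, prefix+".files.stage"))          // true
	fmt.Println(allowed(pubAllow, "nodes.other-node.files.stage")) // false
}
```

Because the node ID is sanitized before it is embedded in the prefix, a node ID containing `.` or `>` cannot widen the permission to other nodes' subjects.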
134
pkg/natsauth/permissions_coverage_test.go
Normal file
@@ -0,0 +1,134 @@
package natsauth_test

import (
	"os"
	"regexp"
	"strings"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/pkg/natsauth"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

// subjectMatches implements NATS subject-token matching: "*" matches exactly one
// token and ">" matches one or more trailing tokens. It lets these tests assert
// that a permission allow-list (which uses wildcards) actually covers a concrete
// subject a component publishes/subscribes — the same check the NATS server makes.
func subjectMatches(pattern, subject string) bool {
	p := strings.Split(pattern, ".")
	s := strings.Split(subject, ".")
	for i, tok := range p {
		if tok == ">" {
			return i < len(s) // ">" must match at least one remaining token
		}
		if i >= len(s) {
			return false
		}
		if tok != "*" && tok != s[i] {
			return false
		}
	}
	return len(p) == len(s)
}

func anyAllows(allow []string, subject string) bool {
	for _, p := range allow {
		if subjectMatches(p, subject) {
			return true
		}
	}
	return false
}

var _ = Describe("WorkerPermissions subject coverage", func() {
	// A node ID containing NATS-reserved characters exercises the (duplicated)
	// sanitizer in pkg/natsauth against the canonical one in core/services/messaging.
	// If the two ever diverge, the minted prefix stops matching the real subject
	// and these assertions fail — guarding the copy noted in the review.
	const nodeID = "host.a 1*b"

	Context("backend worker", func() {
		pub, sub := natsauth.WorkerPermissions(nodeID, "backend")

		// Every subject core/services/worker/{lifecycle,file_staging}.go subscribes to.
		subscribed := []string{
			messaging.SubjectNodeBackendInstall(nodeID),
			messaging.SubjectNodeBackendUpgrade(nodeID),
			messaging.SubjectNodeBackendStop(nodeID),
			messaging.SubjectNodeBackendDelete(nodeID),
			messaging.SubjectNodeBackendList(nodeID),
			messaging.SubjectNodeModelUnload(nodeID),
			messaging.SubjectNodeModelDelete(nodeID),
			messaging.SubjectNodeStop(nodeID),
			messaging.SubjectNodeFilesEnsure(nodeID),
			messaging.SubjectNodeFilesStage(nodeID),
			messaging.SubjectNodeFilesTemp(nodeID),
			messaging.SubjectNodeFilesListDir(nodeID),
		}
		for _, subject := range subscribed {
			It("allows subscribing to "+subject, func() {
				Expect(anyAllows(sub, subject)).To(BeTrue(),
					"backend JWT sub allow-list %v does not cover %s", sub, subject)
			})
		}

		It("allows publishing backend.install progress", func() {
			subject := messaging.SubjectNodeBackendInstallProgress(nodeID, "op-123")
			Expect(anyAllows(pub, subject)).To(BeTrue(),
				"backend JWT pub allow-list %v does not cover %s", pub, subject)
		})
	})

	Context("agent worker", func() {
		// node_type "agent"; subjects from core/cli/agent_worker.go.
		pub, sub := natsauth.WorkerPermissions(nodeID, "agent")
		_ = pub

		subscribed := []string{
			messaging.SubjectAgentExecute,            // dispatcher (default --agent-subject)
			messaging.SubjectMCPToolExecute,          // QueueSubscribeReply
			messaging.SubjectMCPDiscovery,            // QueueSubscribeReply
			messaging.SubjectMCPCIJobsNew,            // QueueSubscribe — jobs.mcp-ci.new
			messaging.SubjectNodeBackendStop(nodeID), // Subscribe — MCP session cleanup
		}
		for _, subject := range subscribed {
			It("allows subscribing to "+subject, func() {
				Expect(anyAllows(sub, subject)).To(BeTrue(),
					"agent JWT sub allow-list %v does not cover %s — the agent worker subscribes to it", sub, subject)
			})
		}
	})
})

var allowPubRe = regexp.MustCompile(`--allow-pub "([^"]*)"`)

var _ = Describe("Documented NATS service-user permissions", func() {
	// scripts/nats-auth-setup.sh ships the recommended service (frontend) JWT
	// permissions. They must cover every subject the frontend actually publishes,
	// or prefix-cache sync (and friends) break once LOCALAI_NATS_REQUIRE_AUTH is on.
	const scriptPath = "../../scripts/nats-auth-setup.sh"

	// Representative subjects the frontend publishes on the control plane.
	// prefixcache.* is emitted by prefixcache.Sync in core/application/distributed.go.
	frontendPublishes := []string{
		messaging.SubjectPrefixCacheObserve,
		messaging.SubjectPrefixCacheInvalidate,
		messaging.SubjectNodeBackendInstall("node-1"),
		messaging.SubjectGalleryProgress("op-1"),
	}

	It("cover every subject the frontend publishes", func() {
		raw, err := os.ReadFile(scriptPath)
		Expect(err).ToNot(HaveOccurred(), "cannot read %s", scriptPath)
		m := allowPubRe.FindStringSubmatch(string(raw))
		Expect(m).To(HaveLen(2), "no --allow-pub list found in %s", scriptPath)
		allow := strings.Split(m[1], ",")

		for _, subject := range frontendPublishes {
			Expect(anyAllows(allow, subject)).To(BeTrue(),
				"service-user --allow-pub %v does not cover %s (frontend publishes it)", allow, subject)
		}
	})
})
49
scripts/nats-auth-setup.sh
Executable file
@@ -0,0 +1,49 @@
#!/usr/bin/env bash
# Generate NATS account + service user JWTs for LocalAI distributed mode.
#
# Requires: jq and nsc (https://docs.nats.io/running-a-nats-service/configuration/securing_nats/auth_intro/nsc)
#
# Usage:
#   ./scripts/nats-auth-setup.sh
#
# Outputs operator/account seeds and a service user JWT suitable for:
#   LOCALAI_NATS_ACCOUNT_SEED
#   LOCALAI_NATS_SERVICE_JWT
#
# Per-node worker JWTs are minted automatically by the frontend at registration
# when LOCALAI_NATS_ACCOUNT_SEED is set.

set -euo pipefail

if ! command -v nsc >/dev/null 2>&1; then
  echo "nsc is required. Install from https://github.com/nats-io/nsc/releases" >&2
  exit 1
fi

OPERATOR="${NATS_OPERATOR_NAME:-localai-operator}"
ACCOUNT="${NATS_ACCOUNT_NAME:-localai}"
SERVICE_USER="${NATS_SERVICE_USER:-localai-frontend}"

nsc add operator -n "$OPERATOR" --generate-signing-key
nsc add account -n "$ACCOUNT"
nsc add user -n "$SERVICE_USER" --account "$ACCOUNT"

# Broad publish for frontend control plane (tighten with custom claims in production).
nsc edit user -n "$SERVICE_USER" --account "$ACCOUNT" \
  --allow-pub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,finetune.>" \
  --allow-sub "nodes.>,gallery.>,agent.>,jobs.>,mcp.>,cache.>,prefixcache.>,_INBOX.>"

KEYS_DIR="${NATS_KEYS_DIR:-./nats-keys}"
mkdir -p "$KEYS_DIR"
nsc generate creds -a "$ACCOUNT" -n "$SERVICE_USER" -o "$KEYS_DIR"

ACCOUNT_SEED=$(nsc describe account "$ACCOUNT" -o json | jq -r '.nats.private_key')
SERVICE_JWT=$(cat "$KEYS_DIR/${ACCOUNT}/${SERVICE_USER}.jwt" 2>/dev/null || cat "$KEYS_DIR/${SERVICE_USER}.jwt")

echo ""
echo "=== LocalAI NATS auth material ==="
echo "LOCALAI_NATS_ACCOUNT_SEED=${ACCOUNT_SEED}"
echo "LOCALAI_NATS_SERVICE_JWT=${SERVICE_JWT}"
echo ""
echo "Configure the NATS server with the generated operator/account JWTs under $KEYS_DIR"
echo "and set LOCALAI_NATS_REQUIRE_AUTH=true on frontends and workers in production."
156
tests/e2e/distributed/nats_jwt_helpers_test.go
Normal file
@@ -0,0 +1,156 @@
package distributed_test

import (
	"bytes"
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/pkg/natsauth"
	"github.com/nats-io/jwt/v2"
	"github.com/nats-io/nkeys"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/testcontainers/testcontainers-go"
	tcnats "github.com/testcontainers/testcontainers-go/modules/nats"
)

// JWTTestInfra holds a NATS server configured with JWT auth and minted worker credentials.
type JWTTestInfra struct {
	*TestInfra
	AccountSeed string
	NodeID      string
	WorkerJWT   string
	WorkerSeed  string
}

// SetupJWTInfra starts NATS with an in-memory JWT resolver and returns worker credentials
// minted the same way as node registration (pkg/natsauth).
func SetupJWTInfra() *JWTTestInfra {
	GinkgoHelper()

	infra := &JWTTestInfra{TestInfra: &TestInfra{Ctx: context.Background()}}

	operatorJWT, accountJWT, accountSeed, err := jwtResolverMaterial()
	Expect(err).ToNot(HaveOccurred())
	infra.AccountSeed = accountSeed

	conf := fmt.Sprintf(`listen: 0.0.0.0:4222

operator: %s

resolver: MEMORY
resolver_preload: {
  %s: %s
}
`, operatorJWT, accountPublicKeyFromSeed(accountSeed), accountJWT)

	var natsContainer *tcnats.NATSContainer
	// Override default testcontainers -js: JetStream fails without a system account in JWT mode.
	natsContainer, err = tcnats.Run(infra.Ctx, "nats:2-alpine",
		tcnats.WithConfigFile(bytes.NewBufferString(conf)),
		testcontainers.WithCmd("-c", "/etc/nats.conf"),
	)
	Expect(err).ToNot(HaveOccurred())
	infra.NATSContainer = natsContainer

	infra.NatsURL, err = infra.NATSContainer.ConnectionString(infra.Ctx)
	Expect(err).ToNot(HaveOccurred())

	infra.NodeID = "550e8400-e29b-41d4-a716-446655440000"
	cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
	infra.WorkerJWT, infra.WorkerSeed, err = cfg.MintWorkerJWT(infra.NodeID, "backend")
	Expect(err).ToNot(HaveOccurred())

	infra.NC, err = messaging.New(infra.NatsURL, messaging.WithUserJWT(infra.WorkerJWT, infra.WorkerSeed))
	Expect(err).ToNot(HaveOccurred())

	DeferCleanup(func() {
		if infra.NC != nil {
			infra.NC.Close()
		}
		if infra.NATSContainer != nil {
			_ = infra.NATSContainer.Terminate(context.Background())
		}
	})

	return infra
}

// jwtResolverMaterial builds operator + account JWTs for a MEMORY resolver.
// Follows the NATS JWT tutorial: self-signed account, then operator re-sign, with the
// account identity key listed as a signing key so MintWorkerJWT can use the account seed.
func jwtResolverMaterial() (operatorJWT, accountJWT, accountSeed string, err error) {
	okp, err := nkeys.CreateOperator()
	if err != nil {
		return "", "", "", err
	}
	opk, err := okp.PublicKey()
	if err != nil {
		return "", "", "", err
	}
	oc := jwt.NewOperatorClaims(opk)
	oc.Name = "localai-test-operator"
	oskp, err := nkeys.CreateOperator()
	if err != nil {
		return "", "", "", err
	}
	ospk, err := oskp.PublicKey()
	if err != nil {
		return "", "", "", err
	}
	oc.SigningKeys.Add(ospk)
	operatorJWT, err = oc.Encode(okp)
	if err != nil {
		return "", "", "", err
	}

	akp, err := nkeys.CreateAccount()
	if err != nil {
		return "", "", "", err
	}
	seed, err := akp.Seed()
	if err != nil {
		return "", "", "", err
	}
	accountSeed = string(seed)

	apk, err := akp.PublicKey()
	if err != nil {
		return "", "", "", err
	}
	ac := jwt.NewAccountClaims(apk)
	ac.Name = "localai-test-account"
	ac.SigningKeys.Add(apk)
	accountJWT, err = ac.Encode(akp)
	if err != nil {
		return "", "", "", err
	}
	ac, err = jwt.DecodeAccountClaims(accountJWT)
	if err != nil {
		return "", "", "", err
	}
	accountJWT, err = ac.Encode(oskp)
	if err != nil {
		return "", "", "", err
	}
	return operatorJWT, accountJWT, accountSeed, nil
}

func accountPublicKeyFromSeed(accountSeed string) string {
	akp, err := nkeys.FromSeed([]byte(accountSeed))
	Expect(err).ToNot(HaveOccurred())
	pk, err := akp.PublicKey()
	Expect(err).ToNot(HaveOccurred())
	return pk
}

// nodeSubjectPrefix returns the sanitized nodes.* prefix for a node ID.
func nodeSubjectPrefix(nodeID string) string {
	tok := strings.NewReplacer(".", "-", "*", "-", ">", "-", " ", "-", "\t", "-", "\n", "-").Replace(nodeID)
	return "nodes." + tok
}
99
tests/e2e/distributed/nats_jwt_test.go
Normal file
@@ -0,0 +1,99 @@
package distributed_test

import (
	"time"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/pkg/natsauth"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("NATS JWT Auth", Label("Distributed", "NatsJWT"), func() {
	var infra *JWTTestInfra

	BeforeEach(func() {
		infra = SetupJWTInfra()
	})

	It("connects with a minted backend worker JWT and publishes on allowed subjects", func() {
		// Backend workers may publish under nodes.<id>.files.> (see pkg/natsauth permissions).
		subject := nodeSubjectPrefix(infra.NodeID) + ".files.in"
		Expect(infra.NC.Publish(subject, map[string]string{"path": "/tmp/model"})).To(Succeed())
		Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
		Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
	})

	It("allows backend subscribe on the node prefix", func() {
		wild := nodeSubjectPrefix(infra.NodeID) + ".>"
		sub, err := infra.NC.Subscribe(wild, func(_ []byte) {})
		Expect(err).ToNot(HaveOccurred())
		defer func() { _ = sub.Unsubscribe() }()
		Expect(infra.NC.Conn().FlushTimeout(2 * time.Second)).To(Succeed())
		Expect(infra.NC.Conn().IsConnected()).To(BeTrue())
	})

	It("rejects anonymous publish on the JWT-enabled server", func() {
		anon, err := messaging.New(infra.NatsURL)
		Expect(err).ToNot(HaveOccurred())
		defer anon.Close()

		err = anon.Publish("nodes.any.files.x", map[string]string{"x": "1"})
		Expect(err).ToNot(HaveOccurred())
		Expect(anon.Conn().FlushTimeout(2 * time.Second)).To(HaveOccurred())
	})

	It("denies backend publish to another node's subjects", func() {
		other := nodeSubjectPrefix("other-node-id") + ".files.stage"
		Expect(infra.NC.Publish(other, map[string]string{"stage": "nope"})).To(Succeed())
		Eventually(func() error {
			_ = infra.NC.Conn().FlushTimeout(500 * time.Millisecond)
			return infra.NC.Conn().LastError()
		}, "3s", "50ms").Should(HaveOccurred())
	})

	It("mints agent JWT without backend.install in claims", func() {
		cfg := natsauth.Config{AccountSeed: infra.AccountSeed}
		token, _, err := cfg.MintWorkerJWT("agent-node-1", "agent")
		Expect(err).ToNot(HaveOccurred())

		claims, err := natsauth.DecodeUserClaims(token)
		Expect(err).ToNot(HaveOccurred())
		Expect(claims.Permissions.Sub.Allow).To(ContainElement("agent.execute"))
		for _, subj := range claims.Permissions.Sub.Allow {
			Expect(subj).NotTo(ContainSubstring("backend.install"))
		}
	})

	// Regression guard for the silent permission gaps: decoding the JWT claims
	// (above) only proves the agent JWT is *restrictive*, not that it is
	// *sufficient*. Stand a real agent connection up against the enforcing
	// server and exercise every subscription core/cli/agent_worker.go actually
	// makes — a denied SUB now surfaces synchronously via confirmSubscription,
	// so a missing allow rule fails this test instead of silently dropping
	// backend.stop / MCP-CI deliveries at runtime.
	It("lets an agent-minted JWT establish all the subscriptions the agent worker uses", func() {
		const nodeID = "agent-node-subs"
		cfg := natsauth.Config{AccountSeed: infra.AccountSeed, WorkerJWTTTL: time.Hour}
		token, seed, err := cfg.MintWorkerJWT(nodeID, "agent")
		Expect(err).ToNot(HaveOccurred())

		nc, err := messaging.New(infra.NatsURL, messaging.WithUserJWT(token, seed))
		Expect(err).ToNot(HaveOccurred())
		DeferCleanup(nc.Close)

		// Mirror core/cli/agent_worker.go exactly.
		_, err = nc.QueueSubscribeReply(messaging.SubjectMCPToolExecute, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
		Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPToolExecute)

		_, err = nc.QueueSubscribeReply(messaging.SubjectMCPDiscovery, messaging.QueueAgentWorkers, func([]byte, func([]byte)) {})
		Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s", messaging.SubjectMCPDiscovery)

		_, err = nc.QueueSubscribe(messaging.SubjectMCPCIJobsNew, messaging.QueueWorkers, func([]byte) {})
		Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP CI jobs)", messaging.SubjectMCPCIJobsNew)

		_, err = nc.Subscribe(messaging.SubjectNodeBackendStop(nodeID), func([]byte) {})
		Expect(err).ToNot(HaveOccurred(), "agent JWT must allow %s (MCP session cleanup)", messaging.SubjectNodeBackendStop(nodeID))
	})
})