LocalAI/core/application/distributed.go
Richard Palethorpe 3a932a9803 feat(distributed): Add NATS JWT authentication and TLS/mTLS options (#10159)
* feat(distributed): NATS JWT auth, TLS/mTLS options, and e2e coverage

Mint per-node NATS user JWTs at registration when LOCALAI_NATS_ACCOUNT_SEED
is set, and connect workers with scoped credentials from the register response.
Add optional LOCALAI_NATS_TLS_CA/CERT/KEY for private CA and mTLS alongside
tls:// URLs, plus test-e2e-distributed and NatsJWT container e2e specs.

Document JWT setup (nats-auth-setup.sh) and TLS env vars in distributed-mode.

Assisted-by: Grok:grok grok-build
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(distributed): correct NATS JWT scoping and harden client auth

The JWT-auth path added in 46467cc7 had several gaps that fail silently
under LOCALAI_NATS_REQUIRE_AUTH:

- Agent-worker minted JWTs did not allow the subjects the agent worker
  actually subscribes to (jobs.mcp-ci.new and nodes.<id>.backend.stop),
  so MCP-CI jobs and backend-stop session cleanup were silently dropped.
  Scope the agent permission set to those subjects.
- NATS subscription permission violations were swallowed (Subscribe
  returned a live-but-dead subscription). Confirm subscriptions with a
  server round-trip so a denial surfaces synchronously, and log async
  permission errors.
- The backend worker connected anonymously when given a JWT without its
  paired seed; reject the unpaired credential instead.
- The documented service-user permissions in nats-auth-setup.sh omitted
  prefixcache.>, which the frontend publishes and subscribes; add it.

Also: add a credential-provider hook to the messaging client (consumed by
the follow-up credential-lifecycle change), drop the always-nil error from
NatsMessagingOptions, run go mod tidy (jwt/v2 and nkeys are now direct),
and gofmt the feature's files.

Tests: an agent-JWT e2e spec that connects to the enforcing NATS server
and exercises every subscription the agent worker makes, plus permission
allow-list coverage unit tests.
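
An allow-list coverage check of this kind could look roughly like the sketch below. It is an illustration, not the repository's test: `subjectMatches` and `uncovered` are hypothetical helpers, the pattern `nodes.*.backend.stop` and the subject `agents.events` are made up for the example, and only literal subjects are handled (a subscription subject that itself contains wildcards would need extra rules).

```go
package main

import (
	"fmt"
	"strings"
)

// subjectMatches reports whether a literal subject is covered by a NATS-style
// pattern: "*" matches exactly one token and ">" matches one or more trailing
// tokens.
func subjectMatches(pattern, subject string) bool {
	p := strings.Split(pattern, ".")
	s := strings.Split(subject, ".")
	for i, tok := range p {
		if tok == ">" {
			// ">" must be the last pattern token and needs at least one
			// remaining subject token to match.
			return i == len(p)-1 && len(s) > i
		}
		if i >= len(s) {
			return false
		}
		if tok != "*" && tok != s[i] {
			return false
		}
	}
	return len(p) == len(s)
}

// uncovered returns the subjects that no pattern in allow covers.
func uncovered(allow, subjects []string) []string {
	var missing []string
	for _, subj := range subjects {
		covered := false
		for _, pat := range allow {
			if subjectMatches(pat, subj) {
				covered = true
				break
			}
		}
		if !covered {
			missing = append(missing, subj)
		}
	}
	return missing
}

func main() {
	allow := []string{"jobs.mcp-ci.new", "nodes.*.backend.stop", "prefixcache.>"}
	subs := []string{"jobs.mcp-ci.new", "nodes.node-1.backend.stop", "agents.events"}
	fmt.Println(uncovered(allow, subs)) // prints [agents.events]
}
```

A unit test built this way fails as soon as a worker starts subscribing to a subject that its minted permissions do not cover.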

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(distributed): acquire and auto-refresh worker NATS credentials

Workers fetched NATS credentials once at startup, which broke two cases
under JWT auth: a worker that registered while still pending admin
approval never received a minted JWT (it connected unauthenticated and
gave up), and a long-running worker's 24h JWT expired with no way to renew
it.

Introduce workerregistry.NATSCredentialManager, built on idempotent
re-registration (the frontend preserves the node row and mints a fresh JWT
each call):

- Acquire re-registers through admin approval until the node is approved
  and credentials are minted (or returns the first success when auth is
  not required, preserving anonymous-NATS behavior).
- RefreshLoop re-registers before the JWT expires (~75% of its lifetime),
  updating the credentials served to the connection.
- Both are bounded (default 100 attempts / consecutive failures) and
  return an error on exhaustion, so an unapprovable or unrenewable worker
  exits non-zero and surfaces the problem instead of hanging or drifting
  toward an expired credential.

The messaging client gains WithUserJWTProvider, fetching credentials on
each (re)connect so the connection transparently adopts a refreshed JWT
when the server expires the old one. RegisterFull exposes the approval
status and full response; Register delegates to it.
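
The idea of fetching credentials at (re)connect time can be shown with a small sketch. It does not reproduce the signature of WithUserJWTProvider, which is not shown here: `credentials`, `credentialStore`, and `connect` are hypothetical, and `connect` merely stands in for a client that asks its provider for credentials on every connection attempt.

```go
package main

import (
	"fmt"
	"sync"
)

// credentials is a hypothetical JWT/seed pair.
type credentials struct {
	JWT  string
	Seed string
}

// credentialStore holds the current credentials; a refresh loop would call Set
// while the connection reads them through Get.
type credentialStore struct {
	mu  sync.RWMutex
	cur credentials
}

// Set replaces the stored credentials.
func (s *credentialStore) Set(c credentials) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.cur = c
}

// Get returns a copy of the stored credentials.
func (s *credentialStore) Get() credentials {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.cur
}

// connect stands in for a client (re)connect: it queries the provider at call
// time instead of capturing credentials once at startup.
func connect(provider func() credentials) string {
	return provider().JWT
}

func main() {
	store := &credentialStore{}
	store.Set(credentials{JWT: "jwt-1", Seed: "seed-1"})
	fmt.Println(connect(store.Get)) // prints jwt-1

	// A refresh swaps the credentials; the next reconnect picks them up.
	store.Set(credentials{JWT: "jwt-2", Seed: "seed-2"})
	fmt.Println(connect(store.Get)) // prints jwt-2
}
```

Because the provider is consulted per connection attempt, a refreshed JWT is adopted without rebuilding the client.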

Both the backend worker and the agent worker are wired to this: explicit
env credentials are used as-is, minted credentials are acquired-with-wait
and refreshed, and a permanent refresh failure shuts the worker down so it
restarts and re-acquires.

Tests cover Acquire (wait-through-pending, bounded give-up, context
cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry
exit) and jwtExpiry decoding. Docs updated in distributed-mode.md.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-06-03 19:43:56 +02:00


package application

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services/agents"
	"github.com/mudler/LocalAI/core/services/distributed"
	"github.com/mudler/LocalAI/core/services/jobs"
	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/core/services/nodes"
	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
	"github.com/mudler/LocalAI/core/services/storage"
	"github.com/mudler/LocalAI/pkg/distributedhdr"
	"github.com/mudler/LocalAI/pkg/sanitize"
	"github.com/mudler/xlog"
	"gorm.io/gorm"
)

// DistributedServices holds all services initialized for distributed mode.
type DistributedServices struct {
	Nats         *messaging.Client
	Store        storage.ObjectStore
	Registry     *nodes.NodeRegistry
	Router       *nodes.SmartRouter
	Health       *nodes.HealthMonitor
	Reconciler   *nodes.ReplicaReconciler
	JobStore     *jobs.JobStore
	Dispatcher   *jobs.Dispatcher
	AgentStore   *agents.AgentStore
	AgentBridge  *agents.EventBridge
	DistStores   *distributed.Stores
	FileMgr      *storage.FileManager
	FileStager   nodes.FileStager
	ModelAdapter *nodes.ModelRouterAdapter
	Unloader     *nodes.RemoteUnloaderAdapter
	shutdownOnce sync.Once
}

// Shutdown stops all distributed services in reverse initialization order.
// It is safe to call on a nil receiver and is idempotent (uses sync.Once).
func (ds *DistributedServices) Shutdown() {
	if ds == nil {
		return
	}
	ds.shutdownOnce.Do(func() {
		if ds.Health != nil {
			ds.Health.Stop()
		}
		if ds.Dispatcher != nil {
			ds.Dispatcher.Stop()
		}
		// Only stores that implement io.Closer need closing; report a failed
		// close instead of discarding the error.
		if closer, ok := ds.Store.(io.Closer); ok {
			if err := closer.Close(); err != nil {
				xlog.Warn("Failed to close object store", "error", err)
			}
		}
		// AgentBridge has no Close method: its NATS subscriptions are cleaned
		// up when the NATS client is closed below.
		if ds.Nats != nil {
			ds.Nats.Close()
		}
		xlog.Info("Distributed services shut down")
	})
}

// initDistributed validates distributed mode prerequisites and initializes
// NATS, object storage, node registry, and instance identity.
// Returns nil if distributed mode is not enabled.
// configLoader is used by the SmartRouter to compute concurrency-group
// anti-affinity at placement time (#9659); it may be nil in tests.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) {
	if !cfg.Distributed.Enabled {
		return nil, nil
	}

	xlog.Info("Distributed mode enabled — validating prerequisites")

	// Validate distributed config (NATS URL, S3 credential pairing, durations, etc.)
	if err := cfg.Distributed.Validate(); err != nil {
		return nil, err
	}

	// Validate that authentication is enabled and backed by PostgreSQL
	// (the auth DB must be PostgreSQL in distributed mode).
	if !cfg.Auth.Enabled {
		return nil, fmt.Errorf("distributed mode requires authentication to be enabled (--auth / LOCALAI_AUTH=true)")
	}
	if !isPostgresURL(cfg.Auth.DatabaseURL) {
		return nil, fmt.Errorf("distributed mode requires PostgreSQL for auth database (got %q)", sanitize.URL(cfg.Auth.DatabaseURL))
	}

	// Generate instance ID if not set
	if cfg.Distributed.InstanceID == "" {
		cfg.Distributed.InstanceID = uuid.New().String()
	}
	xlog.Info("Distributed instance", "id", cfg.Distributed.InstanceID)

	// Connect to NATS. When auth is required, the service JWT and its seed
	// must both be present.
	natsAuth := cfg.Distributed.NatsAuthConfig()
	if natsAuth.RequireAuth && (natsAuth.ServiceUserJWT == "" || natsAuth.ServiceUserSeed == "") {
		return nil, fmt.Errorf("LOCALAI_NATS_REQUIRE_AUTH requires LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED")
	}
	natsOpts := cfg.Distributed.NatsMessagingOptions("", "")
	natsClient, err := messaging.New(cfg.Distributed.NatsURL, natsOpts...)
	if err != nil {
		return nil, fmt.Errorf("connecting to NATS: %w", err)
	}
	xlog.Info("Connected to NATS", "url", sanitize.URL(cfg.Distributed.NatsURL))

	// Ensure NATS is closed if any subsequent initialization step fails.
	success := false
	defer func() {
		if !success {
			natsClient.Close()
		}
	}()

	// Initialize object storage
	var store storage.ObjectStore
	if cfg.Distributed.StorageURL != "" {
		if cfg.Distributed.StorageBucket == "" {
			return nil, fmt.Errorf("distributed storage bucket must be set when storage URL is configured")
		}
		s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{
			Endpoint:        cfg.Distributed.StorageURL,
			Region:          cfg.Distributed.StorageRegion,
			Bucket:          cfg.Distributed.StorageBucket,
			AccessKeyID:     cfg.Distributed.StorageAccessKey,
			SecretAccessKey: cfg.Distributed.StorageSecretKey,
			ForcePathStyle:  true, // required for MinIO
		})
		if err != nil {
			return nil, fmt.Errorf("initializing S3 storage: %w", err)
		}
		xlog.Info("Object storage initialized (S3)", "endpoint", cfg.Distributed.StorageURL, "bucket", cfg.Distributed.StorageBucket)
		store = s3Store
	} else {
		// Fallback to filesystem storage in distributed mode (useful for single-node testing)
		fsStore, err := storage.NewFilesystemStore(cfg.DataPath + "/objectstore")
		if err != nil {
			return nil, fmt.Errorf("initializing filesystem storage: %w", err)
		}
		xlog.Info("Object storage initialized (filesystem fallback)", "path", cfg.DataPath+"/objectstore")
		store = fsStore
	}

	// Initialize node registry (requires the auth DB which is PostgreSQL)
	if authDB == nil {
		return nil, fmt.Errorf("distributed mode requires auth database to be initialized first")
	}
	registry, err := nodes.NewNodeRegistry(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing node registry: %w", err)
	}
	xlog.Info("Node registry initialized")

	// Collect SmartRouter option values; the router itself is created after all
	// dependencies (including FileStager and Unloader) are ready.
	// An unset registration token leaves the auth token empty.
	routerAuthToken := cfg.Distributed.RegistrationToken

	// On a marshal failure the galleries JSON stays empty; log the error
	// rather than dropping it silently.
	var routerGalleriesJSON string
	if galleriesJSON, err := json.Marshal(cfg.BackendGalleries); err == nil {
		routerGalleriesJSON = string(galleriesJSON)
	} else {
		xlog.Warn("Failed to marshal backend galleries for the router", "error", err)
	}

	healthMon := nodes.NewHealthMonitor(registry, authDB,
		cfg.Distributed.HealthCheckIntervalOrDefault(),
		cfg.Distributed.StaleNodeThresholdOrDefault(),
		routerAuthToken,
		!cfg.Distributed.DisablePerModelHealthCheck,
	)

	// Initialize job store
	jobStore, err := jobs.NewJobStore(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing job store: %w", err)
	}
	xlog.Info("Distributed job store initialized")

	// Initialize job dispatcher
	dispatcher := jobs.NewDispatcher(jobStore, natsClient, authDB, cfg.Distributed.InstanceID, cfg.Distributed.JobWorkerConcurrency)

	// Initialize agent store
	agentStore, err := agents.NewAgentStore(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing agent store: %w", err)
	}
	xlog.Info("Distributed agent store initialized")

	// Initialize agent event bridge
	agentBridge := agents.NewEventBridge(natsClient, agentStore, cfg.Distributed.InstanceID)

	// Start observable persister: captures observable_update events from workers
	// (which have no DB access) and persists them to PostgreSQL.
	if err := agentBridge.StartObservablePersister(); err != nil {
		xlog.Warn("Failed to start observable persister", "error", err)
	} else {
		xlog.Info("Observable persister started")
	}

	// Initialize Phase 4 stores (MCP, Gallery, FineTune, Skills)
	distStores, err := distributed.InitStores(authDB)
	if err != nil {
		return nil, fmt.Errorf("initializing distributed stores: %w", err)
	}

	// Initialize file manager with local cache
	cacheDir := cfg.DataPath + "/cache"
	fileMgr, err := storage.NewFileManager(store, cacheDir)
	if err != nil {
		return nil, fmt.Errorf("initializing file manager: %w", err)
	}
	xlog.Info("File manager initialized", "cacheDir", cacheDir)

	// Create FileStager for distributed file transfer
	var fileStager nodes.FileStager
	if cfg.Distributed.StorageURL != "" {
		fileStager = nodes.NewS3NATSFileStager(fileMgr, natsClient)
		xlog.Info("File stager initialized (S3+NATS)")
	} else {
		fileStager = nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			node, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			if node.HTTPAddress == "" {
				return "", fmt.Errorf("node %s has no HTTP address for file transfer", nodeID)
			}
			return node.HTTPAddress, nil
		}, cfg.Distributed.RegistrationToken)
		xlog.Info("File stager initialized (HTTP direct transfer)")
	}

	// Create RemoteUnloaderAdapter, needed by SmartRouter and startup.go
	remoteUnloader := nodes.NewRemoteUnloaderAdapter(
		registry,
		natsClient,
		cfg.Distributed.BackendInstallTimeoutOrDefault(),
		cfg.Distributed.BackendUpgradeTimeoutOrDefault(),
	)

	// Prefix-cache-aware routing. Enabled by default; an operator can opt out
	// with --distributed-prefix-cache=false, which leaves prefixProvider and
	// pressure nil so the SmartRouter and reconciler behave exactly as the
	// round-robin floor (true no-op). When enabled we build the local index,
	// wrap it in a NATS-backed Sync (publishes our observations, applies peers'
	// via the subscriptions below), install the extraction hook used by
	// core/backend/llm.go, and run a background eviction ticker on the app ctx.
	var prefixProvider prefixcache.Provider
	var pressure *prefixcache.Pressure
	var prefixCfg prefixcache.Config
	if !cfg.Distributed.PrefixCacheDisabled {
		prefixCfg = prefixcache.DefaultConfig()
		if cfg.Distributed.PrefixCacheTTL > 0 {
			prefixCfg.TTL = cfg.Distributed.PrefixCacheTTL
		}
		if err := prefixCfg.Validate(); err != nil {
			return nil, fmt.Errorf("invalid prefix-cache configuration: %w", err)
		}
		idx := prefixcache.NewIndex(prefixCfg)
		prefixSync := prefixcache.NewSync(idx, natsClient)
		pressure = prefixcache.NewPressure(prefixCfg.PressureWindow)
		prefixProvider = prefixSync

		// Invalidate the prefix-cache index whenever a replica row is removed.
		// SetReplicaRemovedHook fires from the single chokepoint all removal paths
		// funnel through (RemoveNodeModel / RemoveAllNodeModelReplicas), so this
		// one hook covers every path: reconciler scale-down, probe reaper,
		// health-monitor reap, RemoteUnloaderAdapter, and the router. Registering
		// it only inside this enabled block keeps the disabled path a true no-op
		// (the registry stays hook-less).
		registry.SetReplicaRemovedHook(func(model, node string, replica int) {
			if replica < 0 {
				prefixSync.InvalidateNode(model, node)
			} else {
				prefixSync.Invalidate(model, prefixcache.ReplicaKey{NodeID: node, Replica: replica})
			}
		})
		distributedhdr.PrefixChainHook = func(model, prompt string) []uint64 {
			return prefixcache.ExtractChain(model, prompt, prefixCfg)
		}

		// Apply peers' observations/invalidations to the same Sync. ApplyObserve
		// and ApplyInvalidate update only the local index and do not re-publish,
		// so there is no broadcast loop.
		if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheObserve, func(ev messaging.PrefixCacheObserveEvent) {
			prefixSync.ApplyObserve(ev, time.Now())
		}); err != nil {
			return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheObserve, err)
		}
		if _, err := messaging.SubscribeJSON(natsClient, messaging.SubjectPrefixCacheInvalidate, func(ev messaging.PrefixCacheInvalidateEvent) {
			prefixSync.ApplyInvalidate(ev)
		}); err != nil {
			return nil, fmt.Errorf("subscribing to %s: %w", messaging.SubjectPrefixCacheInvalidate, err)
		}

		// Background eviction: sweep idle entries on the app context. Stopped
		// when the app context is cancelled (mirrors the reconciler loop which
		// also runs on options.Context). TTL/2 keeps stale entries from
		// outliving their idle window by more than half a TTL.
		evictInterval := prefixCfg.TTL / 2
		go func() {
			ticker := time.NewTicker(evictInterval)
			defer ticker.Stop()
			for {
				select {
				case <-cfg.Context.Done():
					return
				case <-ticker.C:
					prefixSync.Evict(time.Now())
				}
			}
		}()
		xlog.Info("Prefix-cache-aware routing enabled", "ttl", prefixCfg.TTL, "evictInterval", evictInterval)
	} else {
		xlog.Info("Prefix-cache-aware routing disabled: using round-robin routing")
	}

	// All dependencies ready: build SmartRouter with all options at once.
	// The resolver is assigned only when configLoader is non-nil so the
	// interface value stays nil instead of wrapping a typed nil pointer.
	var conflictResolver nodes.ConcurrencyConflictResolver
	if configLoader != nil {
		conflictResolver = configLoader
	}
	router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
		Unloader:         remoteUnloader,
		FileStager:       fileStager,
		GalleriesJSON:    routerGalleriesJSON,
		AuthToken:        routerAuthToken,
		DB:               authDB,
		ConflictResolver: conflictResolver,
		PrefixProvider:   prefixProvider,
		PrefixConfig:     prefixCfg,
		Pressure:         pressure,
	})

	// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
	// RegistrationToken feed the state-reconciliation passes: pending op
	// drain uses the adapter, and model health probes use the token to auth
	// against workers' gRPC HealthCheck.
	reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{
		Registry:          registry,
		Scheduler:         router,
		Unloader:          remoteUnloader,
		Adapter:           remoteUnloader,
		RegistrationToken: cfg.Distributed.RegistrationToken,
		DB:                authDB,
		Interval:          30 * time.Second,
		ScaleDownDelay:    5 * time.Minute,
		ProbeStaleAfter:   2 * time.Minute,
		Pressure:          pressure,
		PressureThreshold: prefixCfg.PressureScaleThreshold,
	})

	// Create ModelRouterAdapter to wire into ModelLoader
	modelAdapter := nodes.NewModelRouterAdapter(router)

	success = true
	return &DistributedServices{
		Nats:         natsClient,
		Store:        store,
		Registry:     registry,
		Router:       router,
		Health:       healthMon,
		Reconciler:   reconciler,
		JobStore:     jobStore,
		Dispatcher:   dispatcher,
		AgentStore:   agentStore,
		AgentBridge:  agentBridge,
		DistStores:   distStores,
		FileMgr:      fileMgr,
		FileStager:   fileStager,
		ModelAdapter: modelAdapter,
		Unloader:     remoteUnloader,
	}, nil
}

// isPostgresURL reports whether url uses the postgres:// or postgresql:// scheme.
func isPostgresURL(url string) bool {
	return strings.HasPrefix(url, "postgres://") || strings.HasPrefix(url, "postgresql://")
}