diff --git a/core/cli/worker.go b/core/cli/worker.go index 2d6ffc605..e78529211 100644 --- a/core/cli/worker.go +++ b/core/cli/worker.go @@ -1,1096 +1,19 @@ package cli import ( - "cmp" - "context" - "encoding/json" - "fmt" - "maps" - "net" - "os" - "os/signal" - "path/filepath" - "slices" - "strconv" - "strings" - "sync" - "syscall" - "time" - cliContext "github.com/mudler/LocalAI/core/cli/context" - "github.com/mudler/LocalAI/core/cli/workerregistry" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/core/services/galleryop" - "github.com/mudler/LocalAI/core/services/messaging" - "github.com/mudler/LocalAI/core/services/nodes" - "github.com/mudler/LocalAI/core/services/storage" - grpc "github.com/mudler/LocalAI/pkg/grpc" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/pkg/sanitize" - "github.com/mudler/LocalAI/pkg/system" - "github.com/mudler/LocalAI/pkg/xsysinfo" - process "github.com/mudler/go-processmanager" - "github.com/mudler/xlog" + "github.com/mudler/LocalAI/core/services/worker" ) -// isPathAllowed checks if path is within one of the allowed directories. -func isPathAllowed(path string, allowedDirs []string) bool { - absPath, err := filepath.Abs(path) - if err != nil { - return false - } - resolved, err := filepath.EvalSymlinks(absPath) - if err != nil { - // Path may not exist yet; use the absolute path - resolved = absPath - } - for _, dir := range allowedDirs { - absDir, err := filepath.Abs(dir) - if err != nil { - continue - } - if strings.HasPrefix(resolved, absDir+string(filepath.Separator)) || resolved == absDir { - return true - } - } - return false -} - -// WorkerCMD starts a generic worker process for distributed mode. -// Workers are backend-agnostic — they wait for backend.install NATS events -// from the SmartRouter to install and start the required backend. -// -// NATS is required. 
The worker acts as a process supervisor: -// - Receives backend.install → installs backend from gallery, starts gRPC process, replies success -// - Receives backend.stop → stops the gRPC process -// - Receives stop → full shutdown (deregister + exit) -// -// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that. +// WorkerCMD is the kong-parsed CLI surface for `local-ai worker`. +// All business logic lives in core/services/worker — this struct just +// embeds the worker.Config (so kong sees the flag tags) and delegates Run. type WorkerCMD struct { - // Primary address — the reachable address of this worker. - // Host is used for advertise, port is the base for gRPC backends. - // HTTP file transfer runs on port-1. - Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." group:"server"` - ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""` - - BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"` - BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"` - BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"` - ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"` - - // HTTP file transfer - HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server" hidden:""` - AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" 
hidden:""` - - // Registration (required) - AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""` - RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"` - NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"` - RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"` - HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"` - NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"` - // MaxReplicasPerModel caps how many replicas of any one model can run on - // this worker concurrently. Default 1 = historical single-replica - // behavior. Set higher when a node has enough VRAM to host multiple - // copies of the same model (e.g. a fat 128 GiB box running 4× of a - // 24 GiB model for throughput). The auto-label `node.replica-slots=N` - // is published so model schedulers can target high-capacity nodes via - // the existing label selector. - MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." 
group:"registration"` - - // NATS (required) - NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` - - // S3 storage for distributed file transfer - StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"` - StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"` - StorageRegion string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"` - StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"` - StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"` + worker.Config `embed:""` } +// Run starts the distributed worker. Delegates to worker.Run after kong has +// populated the embedded Config. func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error { - xlog.Info("Starting worker", "advertise", cmd.advertiseAddr(), "basePort", cmd.effectiveBasePort()) - - systemState, err := system.GetSystemState( - system.WithModelPath(cmd.ModelsPath), - system.WithBackendPath(cmd.BackendsPath), - system.WithBackendSystemPath(cmd.BackendsSystemPath), - ) - if err != nil { - return fmt.Errorf("getting system state: %w", err) - } - - ml := model.NewModelLoader(systemState) - ml.SetBackendLoggingEnabled(true) - - // Register already-installed backends - gallery.RegisterBackends(systemState, ml) - - // Parse galleries config - var galleries []config.Gallery - if err := json.Unmarshal([]byte(cmd.BackendGalleries), &galleries); err != nil { - xlog.Warn("Failed to parse backend galleries", "error", err) - } - - // Self-registration with frontend (with retry) - regClient := &workerregistry.RegistrationClient{ - FrontendURL: cmd.RegisterTo, - RegistrationToken: cmd.RegistrationToken, - } - - registrationBody := cmd.registrationBody() - nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10) - if err != nil { - return fmt.Errorf("failed to register 
with frontend: %w", err) - } - - xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo) - heartbeatInterval, err := time.ParseDuration(cmd.HeartbeatInterval) - if err != nil && cmd.HeartbeatInterval != "" { - xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err) - } - heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second) - // Context cancelled on shutdown — used by heartbeat and other background goroutines - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) - defer shutdownCancel() - - // Start HTTP file transfer server - httpAddr := cmd.resolveHTTPAddr() - stagingDir := filepath.Join(cmd.ModelsPath, "..", "staging") - dataDir := filepath.Join(cmd.ModelsPath, "..", "data") - httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cmd.ModelsPath, dataDir, cmd.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs()) - if err != nil { - return fmt.Errorf("starting HTTP file transfer server: %w", err) - } - - // Connect to NATS - xlog.Info("Connecting to NATS", "url", sanitize.URL(cmd.NatsURL)) - natsClient, err := messaging.New(cmd.NatsURL) - if err != nil { - nodes.ShutdownFileTransferServer(httpServer) - return fmt.Errorf("connecting to NATS: %w", err) - } - defer natsClient.Close() - - // Start heartbeat goroutine (after NATS is connected so IsConnected check works) - go func() { - ticker := time.NewTicker(heartbeatInterval) - defer ticker.Stop() - for { - select { - case <-shutdownCtx.Done(): - return - case <-ticker.C: - if !natsClient.IsConnected() { - xlog.Warn("Skipping heartbeat: NATS disconnected") - continue - } - body := cmd.heartbeatBody() - if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil { - xlog.Warn("Heartbeat failed", "error", err) - } - } - } - }() - - // Process supervisor — manages multiple backend gRPC processes on different ports - basePort := cmd.effectiveBasePort() - // Buffered so NATS stop 
handler can send without blocking - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - // Set the registration token once before any backends are started - if cmd.RegistrationToken != "" { - os.Setenv(grpc.AuthTokenEnvVar, cmd.RegistrationToken) - } - - supervisor := &backendSupervisor{ - cmd: cmd, - ml: ml, - systemState: systemState, - galleries: galleries, - nodeID: nodeID, - nats: natsClient, - sigCh: sigCh, - processes: make(map[string]*backendProcess), - nextPort: basePort, - } - supervisor.subscribeLifecycleEvents() - - // Subscribe to file staging NATS subjects if S3 is configured - if cmd.StorageURL != "" { - if err := cmd.subscribeFileStaging(natsClient, nodeID); err != nil { - xlog.Error("Failed to subscribe to file staging subjects", "error", err) - } - } - - xlog.Info("Worker ready, waiting for backend.install events") - <-sigCh - - xlog.Info("Shutting down worker") - shutdownCancel() // stop heartbeat loop immediately - regClient.GracefulDeregister(nodeID) - supervisor.stopAllBackends() - nodes.ShutdownFileTransferServer(httpServer) - return nil -} - -// subscribeFileStaging subscribes to NATS file staging subjects for this node. 
-func (cmd *WorkerCMD) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error { - // Create FileManager with same S3 config as the frontend - // TODO: propagate a caller-provided context once WorkerCMD carries one - s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ - Endpoint: cmd.StorageURL, - Region: cmd.StorageRegion, - Bucket: cmd.StorageBucket, - AccessKeyID: cmd.StorageAccessKey, - SecretAccessKey: cmd.StorageSecretKey, - ForcePathStyle: true, - }) - if err != nil { - return fmt.Errorf("initializing S3 store: %w", err) - } - - cacheDir := filepath.Join(cmd.ModelsPath, "..", "cache") - fm, err := storage.NewFileManager(s3Store, cacheDir) - if err != nil { - return fmt.Errorf("initializing file manager: %w", err) - } - - // Subscribe: files.ensure — download S3 key to local, reply with local path - natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) { - var req struct { - Key string `json:"key"` - } - if err := json.Unmarshal(data, &req); err != nil { - replyJSON(reply, map[string]string{"error": "invalid request"}) - return - } - - localPath, err := fm.Download(context.Background(), req.Key) - if err != nil { - xlog.Error("File ensure failed", "key", req.Key, "error", err) - replyJSON(reply, map[string]string{"error": err.Error()}) - return - } - - xlog.Debug("File ensured locally", "key", req.Key, "path", localPath) - replyJSON(reply, map[string]string{"local_path": localPath}) - }) - - // Subscribe: files.stage — upload local path to S3, reply with key - natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) { - var req struct { - LocalPath string `json:"local_path"` - Key string `json:"key"` - } - if err := json.Unmarshal(data, &req); err != nil { - replyJSON(reply, map[string]string{"error": "invalid request"}) - return - } - - allowedDirs := []string{cacheDir} - if cmd.ModelsPath != "" { - allowedDirs = 
append(allowedDirs, cmd.ModelsPath) - } - if !isPathAllowed(req.LocalPath, allowedDirs) { - replyJSON(reply, map[string]string{"error": "path outside allowed directories"}) - return - } - - if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil { - xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err) - replyJSON(reply, map[string]string{"error": err.Error()}) - return - } - - xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key) - replyJSON(reply, map[string]string{"key": req.Key}) - }) - - // Subscribe: files.temp — allocate temp file, reply with local path - natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) { - tmpDir := filepath.Join(cacheDir, "staging-tmp") - if err := os.MkdirAll(tmpDir, 0750); err != nil { - replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)}) - return - } - - f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp") - if err != nil { - replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)}) - return - } - localPath := f.Name() - f.Close() - - xlog.Debug("Allocated temp file", "path", localPath) - replyJSON(reply, map[string]string{"local_path": localPath}) - }) - - // Subscribe: files.listdir — list files in a local directory, reply with relative paths - natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) { - var req struct { - KeyPrefix string `json:"key_prefix"` - } - if err := json.Unmarshal(data, &req); err != nil { - replyJSON(reply, map[string]any{"error": "invalid request"}) - return - } - - // Resolve key prefix to local directory - dirPath := filepath.Join(cacheDir, req.KeyPrefix) - if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cmd.ModelsPath != "" { - dirPath = filepath.Join(cmd.ModelsPath, rel) - } else if rel, ok := strings.CutPrefix(req.KeyPrefix, 
storage.DataKeyPrefix); ok { - dirPath = filepath.Join(cacheDir, "..", "data", rel) - } - - // Sanitize to prevent directory traversal via crafted key_prefix - dirPath = filepath.Clean(dirPath) - cleanCache := filepath.Clean(cacheDir) - cleanModels := filepath.Clean(cmd.ModelsPath) - cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data")) - if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) || - dirPath == cleanCache || - (cleanModels != "." && strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) || - dirPath == cleanModels || - strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) || - dirPath == cleanData) { - replyJSON(reply, map[string]any{"error": "invalid key prefix"}) - return - } - - var files []string - filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if !d.IsDir() { - rel, err := filepath.Rel(dirPath, path) - if err == nil { - files = append(files, rel) - } - } - return nil - }) - - xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files)) - replyJSON(reply, map[string]any{"files": files}) - }) - - xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID) - return nil -} - -// replyJSON marshals v to JSON and calls the reply function. -func replyJSON(reply func([]byte), v any) { - data, err := json.Marshal(v) - if err != nil { - xlog.Error("Failed to marshal NATS reply", "error", err) - data = []byte(`{"error":"internal marshal error"}`) - } - reply(data) -} - -// backendProcess represents a single gRPC backend process. -type backendProcess struct { - proc *process.Process - backend string - addr string // gRPC address (host:port) -} - -// backendSupervisor manages multiple backend gRPC processes on different ports. -// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port. 
-type backendSupervisor struct { - cmd *WorkerCMD - ml *model.ModelLoader - systemState *system.SystemState - galleries []config.Gallery - nodeID string - nats messaging.MessagingClient - sigCh chan<- os.Signal // send shutdown signal instead of os.Exit - - mu sync.Mutex - processes map[string]*backendProcess // key: backend name - nextPort int // next available port for new backends - freePorts []int // ports freed by stopBackend, reused before nextPort -} - -// startBackend starts a gRPC backend process on a dynamically allocated port. -// Returns the gRPC address. -func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) { - s.mu.Lock() - - // Already running? - if bp, ok := s.processes[backend]; ok { - if bp.proc != nil && bp.proc.IsAlive() { - s.mu.Unlock() - return bp.addr, nil - } - // Process died — clean up and restart - xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend) - delete(s.processes, backend) - } - - // Allocate port — recycle freed ports first, then grow upward from basePort - var port int - if len(s.freePorts) > 0 { - port = s.freePorts[len(s.freePorts)-1] - s.freePorts = s.freePorts[:len(s.freePorts)-1] - } else { - port = s.nextPort - s.nextPort++ - } - bindAddr := fmt.Sprintf("0.0.0.0:%d", port) - clientAddr := fmt.Sprintf("127.0.0.1:%d", port) - - proc, err := s.ml.StartProcess(backendPath, backend, bindAddr) - if err != nil { - s.mu.Unlock() - return "", fmt.Errorf("starting backend process: %w", err) - } - - s.processes[backend] = &backendProcess{ - proc: proc, - backend: backend, - addr: clientAddr, - } - xlog.Info("Backend process started", "backend", backend, "addr", clientAddr) - - // Capture reference before unlocking for race-safe health check. - // Another goroutine could stopBackend and recycle the port while we poll. - bp := s.processes[backend] - s.mu.Unlock() - - // Wait for the gRPC server to be ready before reporting success. 
- // Slow nodes (Jetson Orin doing first-boot CUDA init, large CGO libs) - // can take 10-15s before the gRPC port accepts connections; the previous - // 4s window made the worker reply Success on a not-yet-listening port, - // which manifested upstream as "connect: connection refused" on the - // frontend's first LoadModel dial. - client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cmd.RegistrationToken) - const ( - readinessPollInterval = 200 * time.Millisecond - readinessTimeout = 30 * time.Second - ) - deadline := time.Now().Add(readinessTimeout) - for time.Now().Before(deadline) { - time.Sleep(readinessPollInterval) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - if ok, _ := client.HealthCheck(ctx); ok { - cancel() - // Verify the process wasn't stopped/replaced while health-checking - s.mu.Lock() - currentBP, exists := s.processes[backend] - s.mu.Unlock() - if !exists || currentBP != bp { - return "", fmt.Errorf("backend %s was stopped during startup", backend) - } - xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr) - return clientAddr, nil - } - cancel() - - // Check if the process died (e.g. OOM, CUDA error, missing libs) - if !proc.IsAlive() { - stderrTail := readLastLinesFromFile(proc.StderrPath(), 20) - xlog.Warn("Backend process died during startup", "backend", backend, "stderr", stderrTail) - s.mu.Lock() - delete(s.processes, backend) - s.freePorts = append(s.freePorts, port) - s.mu.Unlock() - return "", fmt.Errorf("backend process %s died during startup. Last stderr:\n%s", backend, stderrTail) - } - } - - // Readiness deadline exceeded. Returning success here would leave the - // frontend with an unbound address (it dials, gets ECONNREFUSED, and - // the operator sees a misleading "connection refused" instead of the - // real cause). Stop the half-started process, recycle the port, and - // surface the failure to the caller with the backend's stderr tail. 
- stderrTail := readLastLinesFromFile(proc.StderrPath(), 20) - xlog.Error("Backend gRPC server not ready before deadline; aborting install", "backend", backend, "addr", clientAddr, "timeout", readinessTimeout, "stderr", stderrTail) - if killErr := proc.Stop(); killErr != nil { - xlog.Warn("Failed to stop unready backend process", "backend", backend, "error", killErr) - } - s.mu.Lock() - if cur, ok := s.processes[backend]; ok && cur == bp { - delete(s.processes, backend) - s.freePorts = append(s.freePorts, port) - } - s.mu.Unlock() - return "", fmt.Errorf("backend %s did not become ready within %s. Last stderr:\n%s", backend, readinessTimeout, stderrTail) -} - -// resolveProcessKeys turns a caller-supplied identifier into the set of -// process map keys it refers to. PR #9583 changed s.processes to be keyed by -// `modelID#replicaIndex`, but external NATS handlers still pass the bare -// model ID — without this resolver, those lookups silently no-op'd, so -// admin "Unload model" / "Delete backend" left the worker process alive. -// -// - Exact match wins. Callers that already know the full processKey -// (stopAllBackends iterating its own map) get exactly that entry. -// - Else, an identifier without `#` is treated as a model prefix and -// every `id#N` replica is returned. -// - An identifier that contains `#` but doesn't match anything returns -// nothing — no spurious prefix fallback when the caller was explicit. -func (s *backendSupervisor) resolveProcessKeys(id string) []string { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.processes[id]; ok { - return []string{id} - } - if strings.Contains(id, "#") { - return nil - } - prefix := id + "#" - var keys []string - for k := range s.processes { - if strings.HasPrefix(k, prefix) { - keys = append(keys, k) - } - } - return keys -} - -// stopBackend stops the backend process(es) matching the given identifier. -// Accepts a bare modelID (stops every replica) or a full processKey -// (stops just that replica). 
-func (s *backendSupervisor) stopBackend(id string) { - for _, key := range s.resolveProcessKeys(id) { - s.stopBackendExact(key) - } -} - -// stopBackendExact stops the process under exactly this key. Locking and -// network I/O are split: the map mutation runs under the lock, the gRPC -// Free() and proc.Stop() calls run after release so they don't block -// other supervisor operations. -func (s *backendSupervisor) stopBackendExact(key string) { - s.mu.Lock() - bp, ok := s.processes[key] - if !ok || bp.proc == nil { - s.mu.Unlock() - return - } - delete(s.processes, key) - if _, portStr, err := net.SplitHostPort(bp.addr); err == nil { - if p, err := strconv.Atoi(portStr); err == nil { - s.freePorts = append(s.freePorts, p) - } - } - s.mu.Unlock() - - client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cmd.RegistrationToken) - xlog.Debug("Calling Free() before stopping backend", "backend", key) - if err := client.Free(context.Background()); err != nil { - xlog.Warn("Free() failed (best-effort)", "backend", key, "error", err) - } - - xlog.Info("Stopping backend process", "backend", key, "addr", bp.addr) - if err := bp.proc.Stop(); err != nil { - xlog.Error("Error stopping backend process", "backend", key, "error", err) - } -} - -// stopAllBackends stops all running backend processes. -func (s *backendSupervisor) stopAllBackends() { - s.mu.Lock() - backends := slices.Collect(maps.Keys(s.processes)) - s.mu.Unlock() - - for _, b := range backends { - s.stopBackend(b) - } -} - -// readLastLinesFromFile reads the last n lines from a file. -// Returns an empty string if the file cannot be read. 
-func readLastLinesFromFile(path string, n int) string { - data, err := os.ReadFile(path) - if err != nil { - return "" - } - lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n") - if len(lines) > n { - lines = lines[len(lines)-n:] - } - return strings.Join(lines, "\n") -} - -// isRunning returns whether at least one backend process matching the given -// identifier is currently running. Accepts a bare modelID (matches any -// replica) or a full processKey (exact match). Callers like the -// backend.delete pre-check rely on the bare-name path. -func (s *backendSupervisor) isRunning(id string) bool { - keys := s.resolveProcessKeys(id) - if len(keys) == 0 { - // Same lock-free zero-process check the caller would have done. - return false - } - s.mu.Lock() - defer s.mu.Unlock() - for _, key := range keys { - if bp, ok := s.processes[key]; ok && bp.proc != nil && bp.proc.IsAlive() { - return true - } - } - return false -} - -// getAddr returns the gRPC address for a running backend, or empty string. -func (s *backendSupervisor) getAddr(backend string) string { - s.mu.Lock() - defer s.mu.Unlock() - if bp, ok := s.processes[backend]; ok { - return bp.addr - } - return "" -} - -// buildProcessKey is the supervisor's stable identifier for a backend gRPC -// process. It includes the replica index so the same model can run multiple -// processes on a worker simultaneously without colliding on the same map slot -// or port. The "#N" suffix is purely internal — the controller never reads it. -func buildProcessKey(modelID, backend string, replicaIndex int) string { - base := modelID - if base == "" { - base = backend - } - return fmt.Sprintf("%s#%d", base, replicaIndex) -} - -// installBackend handles the backend.install flow: -// 1. If already running for this (model, replica) slot AND req.Force is false, -// return existing address (the fast path used by routine load events that -// just want to know which port a backend already serves on). -// 2. 
If req.Force is true, stop any process(es) currently using this backend -// so the gallery install can replace the on-disk artifact and the freshly -// started process picks up the new binary. This is the upgrade path — -// without it, every backend.install we receive after the first hits the -// fast path and silently no-ops, leaving the cluster on a stale build. -// 3. Install backend from gallery (force=req.Force so existing artifacts get -// overwritten on upgrade). -// 4. Find backend binary -// 5. Start gRPC process on a new port -// -// Returns the gRPC address of the backend process. -// -// ProcessKey includes the replica index so a worker with MaxReplicasPerModel>1 -// can host multiple processes for the same model on distinct ports. Old -// controllers (no replica_index in the request) implicitly target replica 0, -// which preserves single-replica behavior. -func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) { - processKey := buildProcessKey(req.ModelID, req.Backend, int(req.ReplicaIndex)) - - if !req.Force { - // Fast path: already running for this model+replica → return existing - // address. Verify liveness before trusting the cached entry: a process - // that died without the supervisor noticing leaves a stale (key, addr) - // pair, and getAddr would otherwise hand the controller an address - // that immediately ECONNREFUSEDs. The reconciler then marks the - // replica failed, retries the install, the supervisor says "already - // running" again, and the cluster loops on a dead replica forever. 
- if addr := s.getAddr(processKey); addr != "" { - if s.isRunning(processKey) { - xlog.Info("Backend already running for model replica", "backend", req.Backend, "model", req.ModelID, "replica", req.ReplicaIndex, "addr", addr) - return addr, nil - } - xlog.Warn("Stale process entry for backend (dead process); cleaning up before reinstall", - "backend", req.Backend, "model", req.ModelID, "replica", req.ReplicaIndex, "addr", addr) - s.stopBackendExact(processKey) - } - } else { - // Upgrade path: stop every live process that shares this backend so the - // gallery install can overwrite the on-disk artifact and the restarted - // process picks up the new binary. resolveProcessKeys catches peer - // replicas of the same backend (whisper#0, whisper#1, ...) on workers - // configured with MaxReplicasPerModel>1. We also stop the exact - // processKey from the request tuple — keys created with an explicit - // modelID don't share the bare-name prefix the resolver matches, but - // they're still using the old binary and need to come down. Both calls - // are no-ops on missing keys. - toStop := s.resolveProcessKeys(req.Backend) - toStop = append(toStop, processKey) - for _, key := range toStop { - xlog.Info("Force install: stopping running backend before reinstall", - "backend", req.Backend, "processKey", key) - s.stopBackendExact(key) - } - } - - // Parse galleries from request (override local config if provided) - galleries := s.galleries - if req.BackendGalleries != "" { - var reqGalleries []config.Gallery - if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil { - galleries = reqGalleries - } - } - - // On upgrade, run the gallery install path even if the binary already - // exists on disk: findBackend would otherwise short-circuit and we'd - // restart the same stale binary. The force flag passed to - // InstallBackendFromGallery makes it overwrite the existing artifact. 
- backendPath := "" - if !req.Force { - backendPath = s.findBackend(req.Backend) - } - if backendPath == "" { - if req.URI != "" { - xlog.Info("Installing backend from external URI", "backend", req.Backend, "uri", req.URI, "force", req.Force) - if err := galleryop.InstallExternalBackend( - context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, - ); err != nil { - return "", fmt.Errorf("installing backend from gallery: %w", err) - } - } else { - xlog.Info("Installing backend from gallery", "backend", req.Backend, "force", req.Force) - if err := gallery.InstallBackendFromGallery( - context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, req.Force, - ); err != nil { - return "", fmt.Errorf("installing backend from gallery: %w", err) - } - } - // Re-register after install and retry - gallery.RegisterBackends(s.systemState, s.ml) - backendPath = s.findBackend(req.Backend) - } - - if backendPath == "" { - return "", fmt.Errorf("backend %q not found after install attempt", req.Backend) - } - - xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey) - - // Start the gRPC process on a new port (keyed by model, not just backend) - return s.startBackend(processKey, backendPath) -} - -// findBackend looks for the backend binary in the backends path and system path. 
-func (s *backendSupervisor) findBackend(backend string) string { - candidates := []string{ - filepath.Join(s.cmd.BackendsPath, backend), - filepath.Join(s.cmd.BackendsPath, backend, backend), - filepath.Join(s.cmd.BackendsSystemPath, backend), - filepath.Join(s.cmd.BackendsSystemPath, backend, backend), - } - if uri := s.ml.GetExternalBackend(backend); uri != "" { - if fi, err := os.Stat(uri); err == nil && !fi.IsDir() { - return uri - } - } - for _, path := range candidates { - fi, err := os.Stat(path) - if err == nil && !fi.IsDir() { - return path - } - } - return "" -} - -// subscribeLifecycleEvents subscribes to NATS backend lifecycle events. -func (s *backendSupervisor) subscribeLifecycleEvents() { - // backend.install — install backend + start gRPC process (request-reply) - s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), func(data []byte, reply func([]byte)) { - xlog.Info("Received NATS backend.install event") - var req messaging.BackendInstallRequest - if err := json.Unmarshal(data, &req); err != nil { - resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} - replyJSON(reply, resp) - return - } - - addr, err := s.installBackend(req) - if err != nil { - xlog.Error("Failed to install backend via NATS", "error", err) - resp := messaging.BackendInstallReply{Success: false, Error: err.Error()} - replyJSON(reply, resp) - return - } - - // Return the gRPC address so the router knows which port to use - advertiseAddr := addr - advAddr := s.cmd.advertiseAddr() - if advAddr != addr { // only remap if advertise differs from bind - _, port, _ := net.SplitHostPort(addr) - advertiseHost, _, _ := net.SplitHostPort(advAddr) - advertiseAddr = net.JoinHostPort(advertiseHost, port) - } - resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr} - replyJSON(reply, resp) - }) - - // backend.stop — stop a specific backend process - s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), 
func(data []byte) { - // Try to parse backend name from payload; if empty, stop all - var req struct { - Backend string `json:"backend"` - } - if json.Unmarshal(data, &req) == nil && req.Backend != "" { - xlog.Info("Received NATS backend.stop event", "backend", req.Backend) - s.stopBackend(req.Backend) - } else { - xlog.Info("Received NATS backend.stop event (all)") - s.stopAllBackends() - } - }) - - // backend.delete — stop backend + delete files (request-reply) - s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), func(data []byte, reply func([]byte)) { - var req messaging.BackendDeleteRequest - if err := json.Unmarshal(data, &req); err != nil { - resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} - replyJSON(reply, resp) - return - } - xlog.Info("Received NATS backend.delete event", "backend", req.Backend) - - // Stop if running this backend - if s.isRunning(req.Backend) { - s.stopBackend(req.Backend) - } - - // Delete the backend files - if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil { - xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err) - resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()} - replyJSON(reply, resp) - return - } - - // Re-register backends after deletion - gallery.RegisterBackends(s.systemState, s.ml) - - resp := messaging.BackendDeleteReply{Success: true} - replyJSON(reply, resp) - }) - - // backend.list — list installed backends (request-reply) - s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), func(data []byte, reply func([]byte)) { - xlog.Info("Received NATS backend.list event") - backends, err := gallery.ListSystemBackends(s.systemState) - if err != nil { - resp := messaging.BackendListReply{Error: err.Error()} - replyJSON(reply, resp) - return - } - - var infos []messaging.NodeBackendInfo - for name, b := range backends { - info := messaging.NodeBackendInfo{ - Name: name, 
- IsSystem: b.IsSystem, - IsMeta: b.IsMeta, - } - if b.Metadata != nil { - info.InstalledAt = b.Metadata.InstalledAt - info.GalleryURL = b.Metadata.GalleryURL - info.Version = b.Metadata.Version - info.URI = b.Metadata.URI - info.Digest = b.Metadata.Digest - } - infos = append(infos, info) - } - - resp := messaging.BackendListReply{Backends: infos} - replyJSON(reply, resp) - }) - - // model.unload — call gRPC Free() to release GPU memory (request-reply) - s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), func(data []byte, reply func([]byte)) { - xlog.Info("Received NATS model.unload event") - var req messaging.ModelUnloadRequest - if err := json.Unmarshal(data, &req); err != nil { - resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} - replyJSON(reply, resp) - return - } - - // Find the backend address for this model's backend type - // The request includes an Address field if the router knows which process to target - targetAddr := req.Address - if targetAddr == "" { - // Fallback: try all running backends - s.mu.Lock() - for _, bp := range s.processes { - targetAddr = bp.addr - break - } - s.mu.Unlock() - } - - if targetAddr != "" { - // Best-effort gRPC Free() - client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cmd.RegistrationToken) - if err := client.Free(context.Background()); err != nil { - xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr) - } - } - - resp := messaging.ModelUnloadReply{Success: true} - replyJSON(reply, resp) - }) - - // model.delete — remove model files from disk (request-reply) - s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), func(data []byte, reply func([]byte)) { - xlog.Info("Received NATS model.delete event") - var req messaging.ModelDeleteRequest - if err := json.Unmarshal(data, &req); err != nil { - replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"}) - return - } - - if err 
:= gallery.DeleteStagedModelFiles(s.cmd.ModelsPath, req.ModelName); err != nil { - xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err) - replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()}) - return - } - - replyJSON(reply, messaging.ModelDeleteReply{Success: true}) - }) - - // stop — trigger the normal shutdown path via sigCh so deferred cleanup runs - s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), func(data []byte) { - xlog.Info("Received NATS stop event — signaling shutdown") - select { - case s.sigCh <- syscall.SIGTERM: - default: - xlog.Debug("Shutdown already signaled, ignoring duplicate stop") - } - }) -} - -// effectiveBasePort returns the port used as base for gRPC backend processes. -// Priority: Addr port → ServeAddr port → 50051 -func (cmd *WorkerCMD) effectiveBasePort() int { - for _, addr := range []string{cmd.Addr, cmd.ServeAddr} { - if addr != "" { - if _, portStr, ok := strings.Cut(addr, ":"); ok { - if p, _ := strconv.Atoi(portStr); p > 0 { - return p - } - } - } - } - return 50051 -} - -// advertiseAddr returns the address the frontend should use to reach this node. -func (cmd *WorkerCMD) advertiseAddr() string { - if cmd.AdvertiseAddr != "" { - return cmd.AdvertiseAddr - } - if cmd.Addr != "" { - return cmd.Addr - } - hostname, _ := os.Hostname() - return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cmd.effectiveBasePort()) -} - -// resolveHTTPAddr returns the address to bind the HTTP file transfer server to. -// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports -// which grow upward from basePort. -func (cmd *WorkerCMD) resolveHTTPAddr() string { - if cmd.HTTPAddr != "" { - return cmd.HTTPAddr - } - return fmt.Sprintf("0.0.0.0:%d", cmd.effectiveBasePort()-1) -} - -// advertiseHTTPAddr returns the HTTP address the frontend should use to reach -// this node for file transfer. 
-func (cmd *WorkerCMD) advertiseHTTPAddr() string { - if cmd.AdvertiseHTTPAddr != "" { - return cmd.AdvertiseHTTPAddr - } - advHost, _, _ := strings.Cut(cmd.advertiseAddr(), ":") - httpPort := cmd.effectiveBasePort() - 1 - return fmt.Sprintf("%s:%d", advHost, httpPort) -} - -// registrationBody builds the JSON body for node registration. -func (cmd *WorkerCMD) registrationBody() map[string]any { - nodeName := cmd.NodeName - if nodeName == "" { - hostname, err := os.Hostname() - if err != nil { - nodeName = fmt.Sprintf("node-%d", os.Getpid()) - } else { - nodeName = hostname - } - } - - // Detect GPU info for VRAM-aware scheduling - totalVRAM, _ := xsysinfo.TotalAvailableVRAM() - gpuVendor, _ := xsysinfo.DetectGPUVendor() - - maxReplicas := cmd.MaxReplicasPerModel - if maxReplicas < 1 { - maxReplicas = 1 - } - body := map[string]any{ - "name": nodeName, - "address": cmd.advertiseAddr(), - "http_address": cmd.advertiseHTTPAddr(), - "total_vram": totalVRAM, - "available_vram": totalVRAM, // initially all VRAM is available - "gpu_vendor": gpuVendor, - "max_replicas_per_model": maxReplicas, - } - - // If no GPU detected, report system RAM so the scheduler/UI has capacity info - if totalVRAM == 0 { - if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { - body["total_ram"] = ramInfo.Total - body["available_ram"] = ramInfo.Available - } - } - if cmd.RegistrationToken != "" { - body["token"] = cmd.RegistrationToken - } - - // Parse and add static node labels. Always include the auto-label - // `node.replica-slots=N` so AND-selectors in ModelSchedulingConfig can - // target high-capacity nodes (e.g. {"node.replica-slots":"4"}). 
- labels := make(map[string]string) - if cmd.NodeLabels != "" { - for _, pair := range strings.Split(cmd.NodeLabels, ",") { - pair = strings.TrimSpace(pair) - if k, v, ok := strings.Cut(pair, "="); ok { - labels[strings.TrimSpace(k)] = strings.TrimSpace(v) - } - } - } - labels["node.replica-slots"] = strconv.Itoa(maxReplicas) - body["labels"] = labels - - return body -} - -// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads. -// -// When aggregate VRAM usage is unknown (no GPU, or temporary detection -// failure), we deliberately OMIT available_vram so the frontend keeps its -// last good value — overwriting with 0 makes the UI show the node as "fully -// used", while reporting total-as-available lies to the scheduler about -// free capacity. -func (cmd *WorkerCMD) heartbeatBody() map[string]any { - body := map[string]any{} - aggregate := xsysinfo.GetGPUAggregateInfo() - if aggregate.TotalVRAM > 0 { - body["available_vram"] = aggregate.FreeVRAM - } - - // CPU-only workers (or workers that lost GPU visibility momentarily): - // report system RAM so the scheduler still has capacity info. - if aggregate.TotalVRAM == 0 { - if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { - body["available_ram"] = ramInfo.Available - } - } - return body + return worker.Run(ctx, &cmd.Config) } diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go index 4cc8643bf..9b622acf5 100644 --- a/core/http/endpoints/localai/nodes.go +++ b/core/http/endpoints/localai/nodes.go @@ -407,10 +407,10 @@ func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.Handler } // Admin-driven backend install: not tied to a specific replica slot // (no model is being loaded). Pass replica 0 to match the worker's - // admin process-key convention (`backend#0`). force=false so the - // worker's fast path takes over if the backend is already running — - // upgrades go through the dedicated /api/backends/upgrade path. 
- reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias, 0, false) + // admin process-key convention (`backend#0`). The worker's fast path + // takes over if the backend is already running — upgrades go through + // the dedicated /api/backends/upgrade path on backend.upgrade. + reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias, 0) if err != nil { xlog.Error("Failed to install backend on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", err) return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to install backend on node")) diff --git a/core/services/messaging/subjects.go b/core/services/messaging/subjects.go index de2e7dcc0..25080caee 100644 --- a/core/services/messaging/subjects.go +++ b/core/services/messaging/subjects.go @@ -137,13 +137,12 @@ type BackendInstallRequest struct { // (single-replica behavior — no collision because the controller never // asks for replica > 0 on a node whose MaxReplicasPerModel is 1). ReplicaIndex int32 `json:"replica_index,omitempty"` - // Force skips the "already running" short-circuit and re-runs the gallery - // install. UpgradeBackend sets this so the worker actually re-downloads the - // artifact, stops the live process, and starts a fresh one — without it, - // the install handler's early return makes upgrades a silent no-op while - // the coordinator's drift detection keeps re-flagging the backend forever. - // Older workers that don't know this field treat it as false (current - // behavior preserved). + // Force is retained on the wire only for backward compatibility with + // pre-2026-05-08 masters that did not know about backend.upgrade. New + // callers MUST send to SubjectNodeBackendUpgrade instead. 
Workers continue + // to honor Force=true here so a rolling update with new master + old + // worker still works (the master's install fallback path also uses this + // when backend.upgrade returns nats.ErrNoResponders). Force bool `json:"force,omitempty"` } @@ -154,6 +153,41 @@ type BackendInstallReply struct { Error string `json:"error,omitempty"` } +// SubjectNodeBackendUpgrade tells a worker node to force-reinstall a backend +// from the gallery, stop every running process for that backend, and restart. +// Uses NATS request-reply with a long deadline (gallery image pulls can take +// many minutes on slow links). Routine model loads use SubjectNodeBackendInstall +// instead — this subject exists so the slow path doesn't head-of-line-block +// the fast one through a shared subscription goroutine. +func SubjectNodeBackendUpgrade(nodeID string) string { + return subjectNodePrefix + sanitizeSubjectToken(nodeID) + ".backend.upgrade" +} + +// BackendUpgradeRequest is the payload for a backend.upgrade NATS request. +// It is intentionally a strict subset of BackendInstallRequest — there is no +// Force field because the upgrade subject IS the force semantics; no ModelID +// because upgrade is backend-scoped (it stops every replica using the binary +// before re-installing). Per-replica restart happens on the next routine load. +type BackendUpgradeRequest struct { + Backend string `json:"backend"` + BackendGalleries string `json:"backend_galleries,omitempty"` + URI string `json:"uri,omitempty"` + Name string `json:"name,omitempty"` + Alias string `json:"alias,omitempty"` + // ReplicaIndex is informational — upgrade stops all replicas regardless, + // but the field lets future per-replica metadata (e.g. progress reporting + // scoped to a slot) ride the same wire without a v3 type. 
+ ReplicaIndex int32 `json:"replica_index,omitempty"` +} + +// BackendUpgradeReply mirrors BackendInstallReply minus Address — upgrade does +// not start a process, so there is no port to advertise. The subsequent +// routine load will re-bind via backend.install and learn the new address. +type BackendUpgradeReply struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + // SubjectNodeBackendList queries a worker node for its installed backends. // Uses NATS request-reply. func SubjectNodeBackendList(nodeID string) string { diff --git a/core/services/messaging/subjects_upgrade_test.go b/core/services/messaging/subjects_upgrade_test.go new file mode 100644 index 000000000..e60369cfc --- /dev/null +++ b/core/services/messaging/subjects_upgrade_test.go @@ -0,0 +1,32 @@ +package messaging_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/messaging" +) + +var _ = Describe("SubjectNodeBackendUpgrade", func() { + It("returns the per-node upgrade subject", func() { + Expect(messaging.SubjectNodeBackendUpgrade("abc")). + To(Equal("nodes.abc.backend.upgrade")) + }) + + It("sanitizes reserved NATS tokens in the node id", func() { + Expect(messaging.SubjectNodeBackendUpgrade("a.b*c")). 
+ To(Equal("nodes.a-b-c.backend.upgrade")) + }) +}) + +var _ = Describe("BackendUpgradeRequest", func() { + It("carries backend name, galleries JSON, and replica index", func() { + req := messaging.BackendUpgradeRequest{ + Backend: "llama-cpp", + BackendGalleries: `[{"name":"x"}]`, + ReplicaIndex: 2, + } + Expect(req.Backend).To(Equal("llama-cpp")) + Expect(req.ReplicaIndex).To(BeEquivalentTo(2)) + }) +}) diff --git a/core/services/nodes/managers_distributed.go b/core/services/nodes/managers_distributed.go index c8d7bf48b..e5c99d9b7 100644 --- a/core/services/nodes/managers_distributed.go +++ b/core/services/nodes/managers_distributed.go @@ -339,7 +339,7 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall // Admin-driven backend install: not tied to a specific replica slot. // Pass replica 0 — the worker's processKey is "backend#0" when no // modelID is supplied, matching pre-PR4 behavior. - reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0, false) + reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0) if err != nil { return err } @@ -354,18 +354,18 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall return result.Err() } -// UpgradeBackend reuses the install NATS subject (the worker re-downloads -// from the gallery). Unlike Install/Delete, upgrade only targets the nodes -// that already report this backend as installed — fanning out to every node -// would ask workers to "upgrade" something they never had, which fails at -// the gallery (e.g. a darwin/arm64 worker has no platform variant for a -// linux-only backend) and leaves a forever-retrying pending_backend_ops row. 
+// UpgradeBackend uses a separate NATS subject (backend.upgrade) so the slow +// force-reinstall path doesn't head-of-line-block routine model loads on +// the same worker. Only nodes that already report this backend as installed +// are targeted — fanning out to every node would ask workers to "upgrade" +// something they never had, which fails at the gallery (e.g. a darwin/arm64 +// worker has no platform variant for a linux-only backend) and leaves a +// forever-retrying pending_backend_ops row. // -// force=true on the install call is what distinguishes upgrade from install: -// the worker stops the live process for this backend, overwrites the on-disk -// artifact, and restarts. Without it, the worker's "already running" fast -// path turns every backend.install into a no-op and the gallery's drift -// detection never converges. +// Rolling-update fallback: when a worker returns nats.ErrNoResponders on +// backend.upgrade, we try the legacy backend.install Force=true path so a +// new master + old worker still converges. Drop the fallback once every +// worker in the fleet is on 2026-05-08 or newer. func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name string, progressCb galleryop.ProgressCallback) error { galleriesJSON, _ := json.Marshal(d.backendGalleries) @@ -383,8 +383,20 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str } result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendUpgrade, name, galleriesJSON, targetNodeIDs, func(node BackendNode) error { - reply, err := d.adapter.InstallBackend(node.ID, name, "", string(galleriesJSON), "", "", "", 0, true) + reply, err := d.adapter.UpgradeBackend(node.ID, name, string(galleriesJSON), "", "", "", 0) if err != nil { + // Rolling-update fallback: an older worker doesn't know + // backend.upgrade. Try the legacy install-with-force path. 
+ if errors.Is(err, nats.ErrNoResponders) { + instReply, instErr := d.adapter.installWithForceFallback(node.ID, name, string(galleriesJSON), "", "", "", 0) + if instErr != nil { + return instErr + } + if !instReply.Success { + return fmt.Errorf("upgrade (legacy fallback) failed: %s", instReply.Error) + } + return nil + } return err } if !reply.Success { diff --git a/core/services/nodes/managers_distributed_test.go b/core/services/nodes/managers_distributed_test.go index 77ac55808..79ae8c7b9 100644 --- a/core/services/nodes/managers_distributed_test.go +++ b/core/services/nodes/managers_distributed_test.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/nats-io/nats.go" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gorm.io/gorm" @@ -21,10 +22,22 @@ import ( // (or error). Used so each fan-out request can simulate a different worker // outcome without spinning up real NATS. type scriptedMessagingClient struct { - mu sync.Mutex - replies map[string][]byte - errs map[string]error - calls []requestCall + mu sync.Mutex + replies map[string][]byte + errs map[string]error + calls []requestCall + matchedReplies map[string][]matchedReply +} + +// matchedReply lets a test script a canned reply that only fires when the +// inbound request matches a predicate. Used by scriptReplyMatching to +// distinguish "install Force=true" (the fallback) from "install Force=false" +// on the same subject. +type matchedReply struct { + pred func(messaging.BackendInstallRequest) bool + reply []byte + fallback []byte + fallbackErr error } func newScriptedMessagingClient() *scriptedMessagingClient { @@ -48,10 +61,69 @@ func (s *scriptedMessagingClient) scriptErr(subject string, err error) { s.errs[subject] = err } +// scriptNoResponders scripts a nats.ErrNoResponders error for `subject` so +// tests can simulate "old worker without backend.upgrade subscription" +// scenarios. Uses the real nats sentinel so errors.Is(...) 
works at the +// caller (the manager's NoResponders fallback path). +func (s *scriptedMessagingClient) scriptNoResponders(subject string) { + s.mu.Lock() + defer s.mu.Unlock() + s.errs[subject] = nats.ErrNoResponders +} + +// scriptReplyMatching is like scriptReply but the canned reply only fires +// when the inbound request payload matches `pred(req)`. Lets tests +// differentiate "install with Force=true" from "install Force=false" on +// the same subject — useful for asserting the rolling-update fallback +// path actually sets Force=true on its retry. +// +// If `pred` returns false (or the unmarshal of the payload into the +// predicate's expected type fails), the subject falls through to whatever +// was scripted before (or to the unscripted default ErrNoResponders). +func (s *scriptedMessagingClient) scriptReplyMatching(subject string, pred func(messaging.BackendInstallRequest) bool, reply messaging.BackendInstallReply) { + raw, err := json.Marshal(reply) + Expect(err).ToNot(HaveOccurred()) + s.mu.Lock() + defer s.mu.Unlock() + prev := s.replies[subject] // may be nil — that's fine + prevErr := s.errs[subject] // may be nil — that's fine + if s.matchedReplies == nil { + s.matchedReplies = map[string][]matchedReply{} + } + s.matchedReplies[subject] = append(s.matchedReplies[subject], matchedReply{ + pred: pred, + reply: raw, + fallback: prev, + fallbackErr: prevErr, + }) +} + func (s *scriptedMessagingClient) Request(subject string, data []byte, _ time.Duration) ([]byte, error) { s.mu.Lock() defer s.mu.Unlock() s.calls = append(s.calls, requestCall{Subject: subject, Data: data}) + + // Predicate-matched replies take precedence over flat scriptReply. 
+ if matchers, ok := s.matchedReplies[subject]; ok { + var req messaging.BackendInstallRequest + _ = json.Unmarshal(data, &req) + for _, m := range matchers { + if m.pred(req) { + return m.reply, nil + } + } + // No predicate matched — fall through to the recorded fallback + // (whatever was scripted before scriptReplyMatching took over). + if matchers[0].fallback != nil { + return matchers[0].fallback, nil + } + if matchers[0].fallbackErr != nil { + return nil, matchers[0].fallbackErr + } + // No fallback either — default to ErrNoResponders. + return nil, nats.ErrNoResponders + } + if err, ok := s.errs[subject]; ok && err != nil { return nil, err } @@ -79,10 +151,12 @@ func (s *scriptedMessagingClient) SubscribeReply(_ string, _ func([]byte, func([ func (s *scriptedMessagingClient) IsConnected() bool { return true } func (s *scriptedMessagingClient) Close() {} -// fakeNoRespondersErr matches nats.ErrNoResponders by name only — we don't -// import nats here to avoid pulling the whole client. The distributed -// manager treats it via errors.Is, so the concrete type matters for the -// "mark unhealthy" path; here we just want a non-nil error. +// fakeNoRespondersErr is the unscripted-subject default. It matches +// nats.ErrNoResponders by string only — used when a test forgets to script +// a node so the failure is loud but doesn't tickle errors.Is(...) sentinel +// paths the test wasn't deliberately exercising. Tests that DO want the +// real sentinel (e.g. to drive the manager's NoResponders fallback) call +// scriptNoResponders instead, which scripts nats.ErrNoResponders directly. 
type fakeNoRespondersErr struct{} func (e *fakeNoRespondersErr) Error() string { return "no responders" } @@ -264,10 +338,10 @@ var _ = Describe("DistributedBackendManager", func() { n2 := registerHealthyBackend("worker-b", "10.0.0.2:50051") scriptInstalled("vllm-development", n1.ID, n2.ID) - mc.scriptReply(messaging.SubjectNodeBackendInstall(n1.ID), - messaging.BackendInstallReply{Success: false, Error: "image manifest not found"}) - mc.scriptReply(messaging.SubjectNodeBackendInstall(n2.ID), - messaging.BackendInstallReply{Success: false, Error: "registry unauthorized"}) + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n1.ID), + messaging.BackendUpgradeReply{Success: false, Error: "image manifest not found"}) + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n2.ID), + messaging.BackendUpgradeReply{Success: false, Error: "registry unauthorized"}) err := mgr.UpgradeBackend(ctx, "vllm-development", nil) Expect(err).To(HaveOccurred()) @@ -282,8 +356,8 @@ var _ = Describe("DistributedBackendManager", func() { It("returns nil", func() { n1 := registerHealthyBackend("worker-a", "10.0.0.1:50051") scriptInstalled("vllm-development", n1.ID) - mc.scriptReply(messaging.SubjectNodeBackendInstall(n1.ID), - messaging.BackendInstallReply{Success: true}) + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n1.ID), + messaging.BackendUpgradeReply{Success: true}) Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed()) }) }) @@ -300,9 +374,9 @@ var _ = Describe("DistributedBackendManager", func() { scriptInstalled("cpu-insightface-development", has.ID) scriptNoBackends(lacks.ID) - mc.scriptReply(messaging.SubjectNodeBackendInstall(has.ID), - messaging.BackendInstallReply{Success: true}) - // Deliberately don't script SubjectNodeBackendInstall for `lacks`: + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(has.ID), + messaging.BackendUpgradeReply{Success: true}) + // Deliberately don't script SubjectNodeBackendUpgrade for `lacks`: // if the manager attempts it, 
the scripted-client default returns // fakeNoRespondersErr and the assertion below fails loudly. @@ -311,7 +385,7 @@ var _ = Describe("DistributedBackendManager", func() { mc.mu.Lock() defer mc.mu.Unlock() for _, call := range mc.calls { - Expect(call.Subject).ToNot(Equal(messaging.SubjectNodeBackendInstall(lacks.ID)), + Expect(call.Subject).ToNot(Equal(messaging.SubjectNodeBackendUpgrade(lacks.ID)), "upgrade leaked to %s which does not have the backend installed", lacks.Name) } }) @@ -329,10 +403,44 @@ var _ = Describe("DistributedBackendManager", func() { mc.mu.Lock() defer mc.mu.Unlock() for _, call := range mc.calls { + Expect(call.Subject).ToNot(Equal(messaging.SubjectNodeBackendUpgrade(n1.ID))) Expect(call.Subject).ToNot(Equal(messaging.SubjectNodeBackendInstall(n1.ID))) } }) }) + + // Rolling-update fallback: pre-2026-05-08 workers don't subscribe to + // backend.upgrade, so the manager catches nats.ErrNoResponders and + // re-fires the legacy backend.install Force=true on the same node. + // Drop these specs once the fallback path itself is removed (see + // managers_distributed.go UpgradeBackend godoc for the deprecation). + Context("rolling-update fallback", func() { + It("falls back to backend.install Force=true when upgrade returns ErrNoResponders", func() { + n := registerHealthyBackend("worker-old", "10.0.0.1:50051") + scriptInstalled("vllm-development", n.ID) + + // Old worker: no subscriber on backend.upgrade. + mc.scriptNoResponders(messaging.SubjectNodeBackendUpgrade(n.ID)) + // Fallback re-fires legacy backend.install with Force=true. 
+ mc.scriptReplyMatching(messaging.SubjectNodeBackendInstall(n.ID), + func(req messaging.BackendInstallRequest) bool { return req.Force }, + messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"}) + + Expect(mgr.UpgradeBackend(ctx, "vllm-development", nil)).To(Succeed()) + }) + + It("returns the upgrade error when it is not ErrNoResponders", func() { + n := registerHealthyBackend("worker-bad", "10.0.0.1:50051") + scriptInstalled("vllm-development", n.ID) + + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(n.ID), + messaging.BackendUpgradeReply{Success: false, Error: "disk full"}) + + err := mgr.UpgradeBackend(ctx, "vllm-development", nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("disk full")) + }) + }) }) Describe("DeleteBackend", func() { diff --git a/core/services/nodes/reconciler.go b/core/services/nodes/reconciler.go index 298d4e752..72e1c441c 100644 --- a/core/services/nodes/reconciler.go +++ b/core/services/nodes/reconciler.go @@ -187,18 +187,36 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) { switch op.Op { case OpBackendDelete: _, applyErr = rc.adapter.DeleteBackend(op.NodeID, op.Backend) - case OpBackendInstall, OpBackendUpgrade: - // Pending-op drain for admin install/upgrade — not a per-replica - // load. Replica 0 is the conventional admin slot. Upgrade ops set - // force=true so the worker reinstalls the artifact and restarts - // the live process; install ops keep the existing fast-path - // semantics for the case where the backend is already running. - force := op.Op == OpBackendUpgrade - reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0, force) + case OpBackendInstall: + // Pending-op drain for admin install — not a per-replica load. + // Replica 0 is the conventional admin slot. Install is idempotent: + // the worker short-circuits if the backend is already running. 
+ reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0) if err != nil { applyErr = err } else if !reply.Success { - applyErr = fmt.Errorf("%s failed: %s", op.Op, reply.Error) + applyErr = fmt.Errorf("install failed: %s", reply.Error) + } + case OpBackendUpgrade: + // Pending-op drain for admin upgrade — fires backend.upgrade so + // the slow re-pull doesn't head-of-line-block install traffic on + // the same worker. Falls back to the legacy backend.install + // Force=true path on nats.ErrNoResponders for old workers that + // don't subscribe to backend.upgrade yet (rolling-update window). + reply, err := rc.adapter.UpgradeBackend(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0) + if err != nil { + if errors.Is(err, nats.ErrNoResponders) { + instReply, instErr := rc.adapter.installWithForceFallback(op.NodeID, op.Backend, string(op.Galleries), "", "", "", 0) + if instErr != nil { + applyErr = instErr + } else if !instReply.Success { + applyErr = fmt.Errorf("upgrade (legacy fallback) failed: %s", instReply.Error) + } + } else { + applyErr = err + } + } else if !reply.Success { + applyErr = fmt.Errorf("upgrade failed: %s", reply.Error) } default: xlog.Warn("Reconciler: unknown pending op", "op", op.Op, "id", op.ID) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index 751b33f0e..2d9101f6d 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -16,6 +16,7 @@ import ( pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/vram" "github.com/mudler/xlog" + "golang.org/x/sync/singleflight" "google.golang.org/protobuf/proto" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -55,6 +56,11 @@ type SmartRouter struct { db *gorm.DB // for advisory locks during routing stagingTracker *StagingTracker // tracks file staging progress for UI visibility conflictResolver ConcurrencyConflictResolver + // installFlight coalesces concurrent identical NATS install 
requests + // (same nodeID + backend + modelID + replica) so 6 simultaneous chat + // completions for one not-yet-loaded model produce ONE round-trip, not + // six. Avoids amplifying head-of-line blocking on the worker side. + installFlight singleflight.Group } // NewSmartRouter creates a new SmartRouter backed by the given ModelRouter. @@ -664,31 +670,42 @@ func (r *SmartRouter) estimateModelVRAM(ctx context.Context, opts *pb.ModelOptio return result.VRAMForContext(ctxSize) } -// installBackendOnNode sends a NATS backend.install request-reply to the node. -// The worker installs the backend from gallery (if not already installed), -// starts the gRPC process, and replies when ready. -// installBackendOnNode installs a backend on a node and returns the gRPC address. +// installBackendOnNode sends a NATS backend.install request-reply to the node +// and returns the gRPC address. Concurrent identical calls (same nodeID + +// backend + modelID + replica) coalesce via singleflight: 6 chat completions +// for the same not-yet-loaded model produce 1 NATS round-trip and 6 callers +// share the result. This kills the load-amplification we saw in the live +// cluster where 6× simultaneous BackendLoader logs sat behind one slow +// install in the worker's NATS callback queue. +// +// Routine load: the worker's fast-path "already running → return current +// address" is correct here. Upgrades go through +// DistributedBackendManager.UpgradeBackend on the backend.upgrade subject. func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNode, backendType, modelID string, replicaIndex int) (string, error) { if r.unloader == nil { return "", fmt.Errorf("no NATS connection for backend installation") } - // force=false: routine load, the worker's fast-path "already running → - // return current address" is correct here. Upgrades go through - // DistributedBackendManager.UpgradeBackend which sets force=true. 
- reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex, false) + key := fmt.Sprintf("%s|%s|%s|%d", node.ID, backendType, modelID, replicaIndex) + v, err, _ := r.installFlight.Do(key, func() (any, error) { + reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex) + if err != nil { + return "", err + } + if !reply.Success { + return "", fmt.Errorf("worker replied with error: %s", reply.Error) + } + // Return the backend's gRPC address (per-replica port from worker) + addr := reply.Address + if addr == "" { + addr = node.Address // fallback to node base address + } + return addr, nil + }) if err != nil { return "", err } - if !reply.Success { - return "", fmt.Errorf("worker replied with error: %s", reply.Error) - } - // Return the backend's gRPC address (per-replica port from worker) - addr := reply.Address - if addr == "" { - addr = node.Address // fallback to node base address - } - return addr, nil + return v.(string), nil } func (r *SmartRouter) buildClientForAddr(node *BackendNode, addr string, parallel bool) grpc.Backend { diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go index a63c0521e..966df2f33 100644 --- a/core/services/nodes/router_test.go +++ b/core/services/nodes/router_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "runtime" + "sync" "time" . "github.com/onsi/ginkgo/v2" @@ -289,13 +290,27 @@ func (f *stubClientFactory) NewClient(_ string, _ bool) grpc.Backend { // --------------------------------------------------------------------------- type fakeUnloader struct { + // mu guards installCalls and upgradeCalls so concurrent test + // goroutines (e.g. singleflight specs) don't race the slice appends. 
+ mu sync.Mutex + installReply *messaging.BackendInstallReply installErr error installCalls []installCall // every InstallBackend invocation, in order - stopCalls []string // "nodeID:model" - stopErr error - unloadCalls []string - unloadErr error + // installHook, if non-nil, runs at the start of InstallBackend before + // the call is recorded. Used by concurrency tests as a deterministic + // "block here" seam — set installHook to a function that sleeps or + // blocks on a channel to overlap two callers. + installHook func() + + upgradeReply *messaging.BackendUpgradeReply + upgradeErr error + upgradeCalls []upgradeCall // every UpgradeBackend invocation, in order + + stopCalls []string // "nodeID:model" + stopErr error + unloadCalls []string + unloadErr error } // installCall captures the args we care about when asserting that the @@ -307,14 +322,34 @@ type installCall struct { backend string modelID string replica int - force bool } -func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, force bool) (*messaging.BackendInstallReply, error) { - f.installCalls = append(f.installCalls, installCall{nodeID, backend, modelID, replica, force}) +type upgradeCall struct { + nodeID string + backend string + replica int +} + +func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int) (*messaging.BackendInstallReply, error) { + // installHook intentionally runs OUTSIDE the mutex: the hook may block + // on a channel and we don't want to serialize concurrent callers, + // which would defeat the singleflight-overlap test. 
+ if f.installHook != nil { + f.installHook() + } + f.mu.Lock() + f.installCalls = append(f.installCalls, installCall{nodeID, backend, modelID, replica}) + f.mu.Unlock() return f.installReply, f.installErr } +func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int) (*messaging.BackendUpgradeReply, error) { + f.mu.Lock() + f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica}) + f.mu.Unlock() + return f.upgradeReply, f.upgradeErr +} + func (f *fakeUnloader) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) { return &messaging.BackendDeleteReply{Success: true}, nil } @@ -951,4 +986,68 @@ var _ = Describe("SmartRouter", func() { Expect(out).To(Equal(candidates)) }) }) + + Describe("installBackendOnNode singleflight", func() { + It("coalesces concurrent identical installs into one NATS call", func() { + node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"} + + // Slow install reply so concurrent calls overlap deterministically. + started := make(chan struct{}, 5) + release := make(chan struct{}) + unloader := &fakeUnloader{ + installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"}, + } + unloader.installHook = func() { + started <- struct{}{} + <-release + } + + router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: &stubClientFactory{client: &stubBackend{}}, + }) + + // Fire 5 concurrent identical installBackendOnNode calls. + done := make(chan error, 5) + for i := 0; i < 5; i++ { + go func() { + _, err := router.installBackendOnNode(context.Background(), node, "llama-cpp", "my-model", 0) + done <- err + }() + } + + // Only ONE call should have entered the unloader hook (the + // singleflight leader). The other 4 are coalesced and waiting on + // the leader's result. 
+ Eventually(started).Should(Receive()) + Consistently(started, 100*time.Millisecond).ShouldNot(Receive()) + + // Release the leader; the other 4 callers receive the same result. + close(release) + for i := 0; i < 5; i++ { + Expect(<-done).ToNot(HaveOccurred()) + } + Expect(unloader.installCalls).To(HaveLen(1), + "singleflight should coalesce 5 concurrent identical loads into 1 NATS call") + }) + + It("does NOT coalesce installs for different (modelID, replica) keys", func() { + node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"} + unloader := &fakeUnloader{ + installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"}, + } + router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: &stubClientFactory{client: &stubBackend{}}, + }) + + _, err1 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 0) + _, err2 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-B", 0) + _, err3 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 1) + Expect(err1).ToNot(HaveOccurred()) + Expect(err2).ToNot(HaveOccurred()) + Expect(err3).ToNot(HaveOccurred()) + Expect(unloader.installCalls).To(HaveLen(3)) + }) + }) }) diff --git a/core/services/nodes/unloader.go b/core/services/nodes/unloader.go index 677964481..611a34dee 100644 --- a/core/services/nodes/unloader.go +++ b/core/services/nodes/unloader.go @@ -17,12 +17,19 @@ type backendStopRequest struct { // NodeCommandSender abstracts NATS-based commands to worker nodes. // Used by HTTP endpoint handlers to avoid coupling to the concrete RemoteUnloaderAdapter. // -// The `force` parameter on InstallBackend is set by the upgrade path to make -// the worker re-run the gallery install (overwriting the on-disk artifact) and -// restart any live process for that backend. 
Routine installs and load events -// pass force=false so an already-running process short-circuits as before. +// InstallBackend is idempotent: the worker short-circuits if the backend is +// already running for the requested (modelID, replica) slot. Routine model +// loads and admin installs both call this. +// +// UpgradeBackend is the destructive force-reinstall path: the worker stops +// every live process for the backend, re-pulls the gallery artifact, and +// replies. Caller (DistributedBackendManager.UpgradeBackend) handles +// rolling-update fallback to the legacy install Force=true path on +// nats.ErrNoResponders for old workers that don't subscribe to the new +// backend.upgrade subject. type NodeCommandSender interface { - InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, force bool) (*messaging.BackendInstallReply, error) + InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) + UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error) ListBackends(nodeID string) (*messaging.BackendListReply, error) StopBackend(nodeID, backend string) error @@ -75,22 +82,21 @@ func (a *RemoteUnloaderAdapter) UnloadRemoteModel(modelName string) error { } // InstallBackend sends a backend.install request-reply to a worker node. -// The worker installs the backend from gallery (if not already installed), -// starts the gRPC process, and replies when ready. +// Idempotent on the worker: if the (modelID, replica) process is already +// running, the worker short-circuits and returns its address; if the binary +// is on disk, the worker just spawns a process; only a missing binary +// triggers a full gallery pull. 
// -// replicaIndex selects which replica slot the worker should use as its -// process key — distinct slots run on distinct ports so multiple replicas of -// the same model can coexist on a fat node. Pass 0 for single-replica. +// Timeout: 3 minutes. Most calls return in under 2 seconds (process already +// running). The 3-minute ceiling covers the cold-binary spawn-after-download +// case while still failing fast enough to surface real worker hangs. // -// force=true is the upgrade path: the worker stops any live process for this -// backend, overwrites the on-disk artifact via gallery install, and restarts. -// Routine installs and load events pass force=false to keep the existing -// "already running → return current address" fast path. -// -// Timeout: 5 minutes (gallery install can take a while). -func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int, force bool) (*messaging.BackendInstallReply, error) { +// For force-reinstall (admin-driven Upgrade), use UpgradeBackend instead — +// it lives on a different NATS subject so it cannot head-of-line-block +// routine load traffic on the same worker. 
+func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) { subject := messaging.SubjectNodeBackendInstall(nodeID) - xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex, "force", force) + xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex) return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{ Backend: backendType, @@ -100,8 +106,50 @@ func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, gal Name: name, Alias: alias, ReplicaIndex: int32(replicaIndex), - Force: force, - }, 5*time.Minute) + }, 3*time.Minute) +} + +// UpgradeBackend sends a backend.upgrade request-reply to a worker node. +// The worker stops every live process for this backend, force-reinstalls +// from the gallery (overwriting the on-disk artifact), and replies. The +// next routine InstallBackend call spawns a fresh process with the new +// binary — upgrade itself does not start a process. +// +// Timeout: 15 minutes. Real-world worst case observed: 8–10 minutes for +// large CUDA-l4t backend images on Jetson over WiFi. 
+func (a *RemoteUnloaderAdapter) UpgradeBackend(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendUpgradeReply, error) { + subject := messaging.SubjectNodeBackendUpgrade(nodeID) + xlog.Info("Sending NATS backend.upgrade", "nodeID", nodeID, "backend", backendType, "replica", replicaIndex) + + return messaging.RequestJSON[messaging.BackendUpgradeRequest, messaging.BackendUpgradeReply](a.nats, subject, messaging.BackendUpgradeRequest{ + Backend: backendType, + BackendGalleries: galleriesJSON, + URI: uri, + Name: name, + Alias: alias, + ReplicaIndex: int32(replicaIndex), + }, 15*time.Minute) +} + +// installWithForceFallback is the rolling-update fallback used by +// DistributedBackendManager.UpgradeBackend when backend.upgrade returns +// nats.ErrNoResponders (the worker is on a pre-2026-05-08 build that +// doesn't subscribe to the new subject). It re-fires the legacy +// backend.install with Force=true. Drop this once every worker is on +// 2026-05-08 or newer. +func (a *RemoteUnloaderAdapter) installWithForceFallback(nodeID, backendType, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) { + subject := messaging.SubjectNodeBackendInstall(nodeID) + xlog.Warn("Falling back to legacy backend.install Force=true (old worker)", "nodeID", nodeID, "backend", backendType) + + return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{ + Backend: backendType, + BackendGalleries: galleriesJSON, + URI: uri, + Name: name, + Alias: alias, + ReplicaIndex: int32(replicaIndex), + Force: true, + }, 15*time.Minute) } // ListBackends queries a worker node for its installed backends via NATS request-reply. 
diff --git a/core/services/nodes/unloader_upgrade_test.go b/core/services/nodes/unloader_upgrade_test.go new file mode 100644 index 000000000..2261dd02e --- /dev/null +++ b/core/services/nodes/unloader_upgrade_test.go @@ -0,0 +1,31 @@ +package nodes + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/messaging" +) + +var _ = Describe("RemoteUnloaderAdapter.UpgradeBackend", func() { + It("fires a NATS request to the backend.upgrade subject and returns the reply", func() { + mc := newScriptedMessagingClient() + nodeID := "node-x" + + mc.scriptReply(messaging.SubjectNodeBackendUpgrade(nodeID), + messaging.BackendUpgradeReply{Success: true}) + + adapter := NewRemoteUnloaderAdapter(nil, mc) + reply, err := adapter.UpgradeBackend(nodeID, "llama-cpp", `[{"name":"x"}]`, "", "", "", 0) + Expect(err).ToNot(HaveOccurred()) + Expect(reply.Success).To(BeTrue()) + }) + + It("returns the underlying error when the subject has no responders", func() { + mc := newScriptedMessagingClient() // unscripted subject => fakeNoRespondersErr by harness convention + + adapter := NewRemoteUnloaderAdapter(nil, mc) + _, err := adapter.UpgradeBackend("missing-node", "llama-cpp", "", "", "", "", 0) + Expect(err).To(HaveOccurred()) + }) +}) diff --git a/core/cli/worker_addr_test.go b/core/services/worker/addr_test.go similarity index 77% rename from core/cli/worker_addr_test.go rename to core/services/worker/addr_test.go index 6dd7b0ab9..5265a287a 100644 --- a/core/cli/worker_addr_test.go +++ b/core/services/worker/addr_test.go @@ -1,4 +1,4 @@ -package cli +package worker import ( "os" @@ -8,12 +8,12 @@ import ( . 
"github.com/onsi/gomega" ) -var _ = Describe("WorkerCMD address resolution", func() { +var _ = Describe("Worker address resolution", func() { Describe("effectiveBasePort", func() { DescribeTable("returns the correct port", func(addr, serve string, want int) { - cmd := &WorkerCMD{Addr: addr, ServeAddr: serve} - Expect(cmd.effectiveBasePort()).To(Equal(want)) + cfg := &Config{Addr: addr, ServeAddr: serve} + Expect(cfg.effectiveBasePort()).To(Equal(want)) }, Entry("Addr takes priority", "worker1.example.com:60000", "0.0.0.0:50051", 60000), Entry("falls back to ServeAddr", "", "0.0.0.0:50051", 50051), @@ -25,21 +25,21 @@ var _ = Describe("WorkerCMD address resolution", func() { Describe("advertiseAddr", func() { It("returns AdvertiseAddr when set", func() { - cmd := &WorkerCMD{ + cfg := &Config{ AdvertiseAddr: "public.example.com:50051", Addr: "10.0.0.5:60000", } - Expect(cmd.advertiseAddr()).To(Equal("public.example.com:50051")) + Expect(cfg.advertiseAddr()).To(Equal("public.example.com:50051")) }) It("returns Addr when set", func() { - cmd := &WorkerCMD{Addr: "worker1.example.com:60000"} - Expect(cmd.advertiseAddr()).To(Equal("worker1.example.com:60000")) + cfg := &Config{Addr: "worker1.example.com:60000"} + Expect(cfg.advertiseAddr()).To(Equal("worker1.example.com:60000")) }) It("falls back to hostname:basePort", func() { - cmd := &WorkerCMD{ServeAddr: "0.0.0.0:50051"} - got := cmd.advertiseAddr() + cfg := &Config{ServeAddr: "0.0.0.0:50051"} + got := cfg.advertiseAddr() _, port, _ := strings.Cut(got, ":") Expect(port).To(Equal("50051")) @@ -54,8 +54,8 @@ var _ = Describe("WorkerCMD address resolution", func() { Describe("resolveHTTPAddr", func() { DescribeTable("returns the correct address", func(httpAddr, addr, serve, want string) { - cmd := &WorkerCMD{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve} - Expect(cmd.resolveHTTPAddr()).To(Equal(want)) + cfg := &Config{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve} + Expect(cfg.resolveHTTPAddr()).To(Equal(want)) }, 
Entry("HTTPAddr takes priority", "0.0.0.0:8080", "", "", "0.0.0.0:8080"), Entry("derives from Addr port minus 1", "", "worker1:60000", "0.0.0.0:50051", "0.0.0.0:59999"), @@ -67,13 +67,13 @@ var _ = Describe("WorkerCMD address resolution", func() { Describe("advertiseHTTPAddr", func() { DescribeTable("returns the correct address", func(advertiseHTTP, advertise, addr, serve, want string) { - cmd := &WorkerCMD{ + cfg := &Config{ AdvertiseHTTPAddr: advertiseHTTP, AdvertiseAddr: advertise, Addr: addr, ServeAddr: serve, } - Expect(cmd.advertiseHTTPAddr()).To(Equal(want)) + Expect(cfg.advertiseHTTPAddr()).To(Equal(want)) }, Entry("AdvertiseHTTPAddr takes priority", "public.example.com:8080", "", "", "", "public.example.com:8080"), Entry("derives from advertiseAddr host + basePort-1", "", "", "worker1.example.com:60000", "", "worker1.example.com:59999"), diff --git a/core/services/worker/concurrency_test.go b/core/services/worker/concurrency_test.go new file mode 100644 index 000000000..c3b0b8900 --- /dev/null +++ b/core/services/worker/concurrency_test.go @@ -0,0 +1,105 @@ +package worker + +import ( + "sync" + "sync/atomic" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("backendSupervisor.lockBackend", func() { + It("serializes operations on the same backend name", func() { + s := &backendSupervisor{processes: map[string]*backendProcess{}} + + var inflight, peak int32 + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + release := s.lockBackend("llama-cpp") + defer release() + + now := atomic.AddInt32(&inflight, 1) + for { + p := atomic.LoadInt32(&peak) + if now <= p || atomic.CompareAndSwapInt32(&peak, p, now) { + break + } + } + time.Sleep(20 * time.Millisecond) + atomic.AddInt32(&inflight, -1) + }() + } + wg.Wait() + + Expect(atomic.LoadInt32(&peak)).To(Equal(int32(1)), + "only one goroutine should hold the per-backend lock at a time") + }) + + It("allows different backend names to run in parallel", func() { + s := &backendSupervisor{processes: map[string]*backendProcess{}} + + var inflight, peak int32 + var wg sync.WaitGroup + names := []string{"llama-cpp", "vllm", "whisper", "speaker-recognition"} + for _, n := range names { + n := n + wg.Add(1) + go func() { + defer wg.Done() + release := s.lockBackend(n) + defer release() + + now := atomic.AddInt32(&inflight, 1) + for { + p := atomic.LoadInt32(&peak) + if now <= p || atomic.CompareAndSwapInt32(&peak, p, now) { + break + } + } + time.Sleep(50 * time.Millisecond) + atomic.AddInt32(&inflight, -1) + }() + } + wg.Wait() + + Expect(atomic.LoadInt32(&peak)).To(BeNumerically(">=", int32(2)), + "distinct backends should be able to run concurrently") + }) +}) + +var _ = Describe("backendSupervisor upgrade handler", func() { + It("serializes upgrade against install for the same backend name", func() { + s := &backendSupervisor{processes: map[string]*backendProcess{}} + + var inflight, peak int32 + var wg sync.WaitGroup + + // Simulate one install + one upgrade on the same backend name. + // The two handlers each acquire lockBackend("llama-cpp"); only one + // should hold the lock at a time. 
+ acquire := func() { + defer wg.Done() + release := s.lockBackend("llama-cpp") + defer release() + now := atomic.AddInt32(&inflight, 1) + for { + p := atomic.LoadInt32(&peak) + if now <= p || atomic.CompareAndSwapInt32(&peak, p, now) { + break + } + } + time.Sleep(20 * time.Millisecond) + atomic.AddInt32(&inflight, -1) + } + wg.Add(2) + go acquire() + go acquire() + wg.Wait() + + Expect(atomic.LoadInt32(&peak)).To(Equal(int32(1))) + }) +}) diff --git a/core/services/worker/config.go b/core/services/worker/config.go new file mode 100644 index 000000000..002306217 --- /dev/null +++ b/core/services/worker/config.go @@ -0,0 +1,59 @@ +package worker + +// Config is the configuration for the distributed agent worker. +// +// Field tags are kong/kong-env metadata read by core/cli/worker.go's WorkerCMD, +// which embeds Config; this package does NOT import kong and the tags are inert +// here. +// +// Workers are backend-agnostic — they wait for backend.install NATS events +// from the SmartRouter to install and start the required backend. +// +// NATS is required. The worker acts as a process supervisor: +// - Receives backend.install → installs backend from gallery, starts gRPC process, replies success +// - Receives backend.stop → stops the gRPC process +// - Receives stop → full shutdown (deregister + exit) +// +// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that. +type Config struct { + // Primary address — the reachable address of this worker. + // Host is used for advertise, port is the base for gRPC backends. + // HTTP file transfer runs on port-1. + Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." 
 group:"server"`
+	ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""`
+
+	BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
+	BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
+	BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
+	ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
+
+	// HTTP file transfer
+	HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC base port - 1)" group:"server" hidden:""`
+	AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
+
+	// Registration (required)
+	AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
+	RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
+	NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
+	RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
+	HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
+	NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. 
tier=fast,gpu=a100)" group:"registration"` + // MaxReplicasPerModel caps how many replicas of any one model can run on + // this worker concurrently. Default 1 = historical single-replica + // behavior. Set higher when a node has enough VRAM to host multiple + // copies of the same model (e.g. a fat 128 GiB box running 4× of a + // 24 GiB model for throughput). The auto-label `node.replica-slots=N` + // is published so model schedulers can target high-capacity nodes via + // the existing label selector. + MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"` + + // NATS (required) + NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` + + // S3 storage for distributed file transfer + StorageURL string `env:"LOCALAI_STORAGE_URL" help:"S3 endpoint URL" group:"distributed"` + StorageBucket string `env:"LOCALAI_STORAGE_BUCKET" help:"S3 bucket name" group:"distributed"` + StorageRegion string `env:"LOCALAI_STORAGE_REGION" help:"S3 region" group:"distributed"` + StorageAccessKey string `env:"LOCALAI_STORAGE_ACCESS_KEY" help:"S3 access key" group:"distributed"` + StorageSecretKey string `env:"LOCALAI_STORAGE_SECRET_KEY" help:"S3 secret key" group:"distributed"` +} diff --git a/core/services/worker/file_staging.go b/core/services/worker/file_staging.go new file mode 100644 index 000000000..9d834a712 --- /dev/null +++ b/core/services/worker/file_staging.go @@ -0,0 +1,185 @@ +package worker + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/xlog" +) + +// isPathAllowed checks if path is within one of the allowed directories. 
+func isPathAllowed(path string, allowedDirs []string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + resolved, err := filepath.EvalSymlinks(absPath) + if err != nil { + // Path may not exist yet; use the absolute path + resolved = absPath + } + for _, dir := range allowedDirs { + absDir, err := filepath.Abs(dir) + if err != nil { + continue + } + if strings.HasPrefix(resolved, absDir+string(filepath.Separator)) || resolved == absDir { + return true + } + } + return false +} + +// subscribeFileStaging subscribes to NATS file staging subjects for this node. +func (cfg *Config) subscribeFileStaging(natsClient messaging.MessagingClient, nodeID string) error { + // Create FileManager with same S3 config as the frontend + // TODO: propagate a caller-provided context once Config carries one + s3Store, err := storage.NewS3Store(context.Background(), storage.S3Config{ + Endpoint: cfg.StorageURL, + Region: cfg.StorageRegion, + Bucket: cfg.StorageBucket, + AccessKeyID: cfg.StorageAccessKey, + SecretAccessKey: cfg.StorageSecretKey, + ForcePathStyle: true, + }) + if err != nil { + return fmt.Errorf("initializing S3 store: %w", err) + } + + cacheDir := filepath.Join(cfg.ModelsPath, "..", "cache") + fm, err := storage.NewFileManager(s3Store, cacheDir) + if err != nil { + return fmt.Errorf("initializing file manager: %w", err) + } + + // Subscribe: files.ensure — download S3 key to local, reply with local path + natsClient.SubscribeReply(messaging.SubjectNodeFilesEnsure(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + Key string `json:"key"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]string{"error": "invalid request"}) + return + } + + localPath, err := fm.Download(context.Background(), req.Key) + if err != nil { + xlog.Error("File ensure failed", "key", req.Key, "error", err) + replyJSON(reply, map[string]string{"error": err.Error()}) + return + } + + xlog.Debug("File ensured 
locally", "key", req.Key, "path", localPath) + replyJSON(reply, map[string]string{"local_path": localPath}) + }) + + // Subscribe: files.stage — upload local path to S3, reply with key + natsClient.SubscribeReply(messaging.SubjectNodeFilesStage(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + LocalPath string `json:"local_path"` + Key string `json:"key"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]string{"error": "invalid request"}) + return + } + + allowedDirs := []string{cacheDir} + if cfg.ModelsPath != "" { + allowedDirs = append(allowedDirs, cfg.ModelsPath) + } + if !isPathAllowed(req.LocalPath, allowedDirs) { + replyJSON(reply, map[string]string{"error": "path outside allowed directories"}) + return + } + + if err := fm.Upload(context.Background(), req.Key, req.LocalPath); err != nil { + xlog.Error("File stage failed", "path", req.LocalPath, "key", req.Key, "error", err) + replyJSON(reply, map[string]string{"error": err.Error()}) + return + } + + xlog.Debug("File staged to S3", "path", req.LocalPath, "key", req.Key) + replyJSON(reply, map[string]string{"key": req.Key}) + }) + + // Subscribe: files.temp — allocate temp file, reply with local path + natsClient.SubscribeReply(messaging.SubjectNodeFilesTemp(nodeID), func(data []byte, reply func([]byte)) { + tmpDir := filepath.Join(cacheDir, "staging-tmp") + if err := os.MkdirAll(tmpDir, 0750); err != nil { + replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp dir: %v", err)}) + return + } + + f, err := os.CreateTemp(tmpDir, "localai-staging-*.tmp") + if err != nil { + replyJSON(reply, map[string]string{"error": fmt.Sprintf("creating temp file: %v", err)}) + return + } + localPath := f.Name() + f.Close() + + xlog.Debug("Allocated temp file", "path", localPath) + replyJSON(reply, map[string]string{"local_path": localPath}) + }) + + // Subscribe: files.listdir — list files in a local directory, reply with relative paths + 
natsClient.SubscribeReply(messaging.SubjectNodeFilesListDir(nodeID), func(data []byte, reply func([]byte)) { + var req struct { + KeyPrefix string `json:"key_prefix"` + } + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, map[string]any{"error": "invalid request"}) + return + } + + // Resolve key prefix to local directory + dirPath := filepath.Join(cacheDir, req.KeyPrefix) + if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.ModelKeyPrefix); ok && cfg.ModelsPath != "" { + dirPath = filepath.Join(cfg.ModelsPath, rel) + } else if rel, ok := strings.CutPrefix(req.KeyPrefix, storage.DataKeyPrefix); ok { + dirPath = filepath.Join(cacheDir, "..", "data", rel) + } + + // Sanitize to prevent directory traversal via crafted key_prefix + dirPath = filepath.Clean(dirPath) + cleanCache := filepath.Clean(cacheDir) + cleanModels := filepath.Clean(cfg.ModelsPath) + cleanData := filepath.Clean(filepath.Join(cacheDir, "..", "data")) + if !(strings.HasPrefix(dirPath, cleanCache+string(filepath.Separator)) || + dirPath == cleanCache || + (cleanModels != "." 
&& strings.HasPrefix(dirPath, cleanModels+string(filepath.Separator))) || + dirPath == cleanModels || + strings.HasPrefix(dirPath, cleanData+string(filepath.Separator)) || + dirPath == cleanData) { + replyJSON(reply, map[string]any{"error": "invalid key prefix"}) + return + } + + var files []string + filepath.WalkDir(dirPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if !d.IsDir() { + rel, err := filepath.Rel(dirPath, path) + if err == nil { + files = append(files, rel) + } + } + return nil + }) + + xlog.Debug("Listed remote dir", "keyPrefix", req.KeyPrefix, "dirPath", dirPath, "fileCount", len(files)) + replyJSON(reply, map[string]any{"files": files}) + }) + + xlog.Info("Subscribed to file staging NATS subjects", "nodeID", nodeID) + return nil +} diff --git a/core/services/worker/install.go b/core/services/worker/install.go new file mode 100644 index 000000000..f078ee5ef --- /dev/null +++ b/core/services/worker/install.go @@ -0,0 +1,225 @@ +package worker + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/xlog" +) + +// buildProcessKey is the supervisor's stable identifier for a backend gRPC +// process. It includes the replica index so the same model can run multiple +// processes on a worker simultaneously without colliding on the same map slot +// or port. The "#N" suffix is purely internal — the controller never reads it. +func buildProcessKey(modelID, backend string, replicaIndex int) string { + base := modelID + if base == "" { + base = backend + } + return fmt.Sprintf("%s#%d", base, replicaIndex) +} + +// installBackend handles the backend.install flow. force=true is the +// upgrade path; force=false is the routine load path. 
+// +// The caller is responsible for holding s.lockBackend(req.Backend) for +// the duration of the call so the gallery directory isn't raced. +// +// 1. If already running for this (model, replica) slot AND force is false, +// return existing address (the fast path used by routine load events that +// just want to know which port a backend already serves on). +// 2. If force is true, stop any process(es) currently using this backend +// so the gallery install can replace the on-disk artifact and the freshly +// started process picks up the new binary. This is the upgrade path — +// without it, every backend.install we receive after the first hits the +// fast path and silently no-ops, leaving the cluster on a stale build. +// 3. Install backend from gallery (force passed through so existing artifacts +// get overwritten on upgrade). +// 4. Find backend binary +// 5. Start gRPC process on a new port +// +// Returns the gRPC address of the backend process. +// +// ProcessKey includes the replica index so a worker with MaxReplicasPerModel>1 +// can host multiple processes for the same model on distinct ports. Old +// controllers (no replica_index in the request) implicitly target replica 0, +// which preserves single-replica behavior. +func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest, force bool) (string, error) { + processKey := buildProcessKey(req.ModelID, req.Backend, int(req.ReplicaIndex)) + + if !force { + // Fast path: already running for this model+replica → return existing + // address. Verify liveness before trusting the cached entry: a process + // that died without the supervisor noticing leaves a stale (key, addr) + // pair, and getAddr would otherwise hand the controller an address + // that immediately ECONNREFUSEDs. The reconciler then marks the + // replica failed, retries the install, the supervisor says "already + // running" again, and the cluster loops on a dead replica forever. 
+ if addr := s.getAddr(processKey); addr != "" { + if s.isRunning(processKey) { + xlog.Info("Backend already running for model replica", "backend", req.Backend, "model", req.ModelID, "replica", req.ReplicaIndex, "addr", addr) + return addr, nil + } + xlog.Warn("Stale process entry for backend (dead process); cleaning up before reinstall", + "backend", req.Backend, "model", req.ModelID, "replica", req.ReplicaIndex, "addr", addr) + s.stopBackendExact(processKey) + } + } else { + // Upgrade path: stop every live process that shares this backend so the + // gallery install can overwrite the on-disk artifact and the restarted + // process picks up the new binary. resolveProcessKeys catches peer + // replicas of the same backend (whisper#0, whisper#1, ...) on workers + // configured with MaxReplicasPerModel>1. We also stop the exact + // processKey from the request tuple — keys created with an explicit + // modelID don't share the bare-name prefix the resolver matches, but + // they're still using the old binary and need to come down. Both calls + // are no-ops on missing keys. + toStop := s.resolveProcessKeys(req.Backend) + toStop = append(toStop, processKey) + for _, key := range toStop { + xlog.Info("Force install: stopping running backend before reinstall", + "backend", req.Backend, "processKey", key) + s.stopBackendExact(key) + } + } + + // Parse galleries from request (override local config if provided) + galleries := s.galleries + if req.BackendGalleries != "" { + var reqGalleries []config.Gallery + if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil { + galleries = reqGalleries + } + } + + // On upgrade, run the gallery install path even if the binary already + // exists on disk: findBackend would otherwise short-circuit and we'd + // restart the same stale binary. The force flag passed to + // InstallBackendFromGallery makes it overwrite the existing artifact. 
+ backendPath := "" + if !force { + backendPath = s.findBackend(req.Backend) + } + if backendPath == "" { + if req.URI != "" { + xlog.Info("Installing backend from external URI", "backend", req.Backend, "uri", req.URI, "force", force) + if err := galleryop.InstallExternalBackend( + context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, + ); err != nil { + return "", fmt.Errorf("installing backend from gallery: %w", err) + } + } else { + xlog.Info("Installing backend from gallery", "backend", req.Backend, "force", force) + if err := gallery.InstallBackendFromGallery( + context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, force, + ); err != nil { + return "", fmt.Errorf("installing backend from gallery: %w", err) + } + } + // Re-register after install and retry + gallery.RegisterBackends(s.systemState, s.ml) + backendPath = s.findBackend(req.Backend) + } + + if backendPath == "" { + return "", fmt.Errorf("backend %q not found after install attempt", req.Backend) + } + + xlog.Info("Found backend binary", "path", backendPath, "processKey", processKey) + + // Start the gRPC process on a new port (keyed by model, not just backend) + return s.startBackend(processKey, backendPath) +} + +// upgradeBackend stops every running process for `backend`, force-reinstalls +// from gallery (overwriting the on-disk artifact), and re-registers backends. +// It does NOT start any new gRPC process — the next routine model load via +// backend.install will spawn a fresh process picking up the new binary. +// +// The caller is responsible for holding s.lockBackend(req.Backend). +func (s *backendSupervisor) upgradeBackend(req messaging.BackendUpgradeRequest) error { + // Stop every live process for this backend (peer replicas + the bare + // processKey). Same logic as the force branch in installBackend. 
+ toStop := s.resolveProcessKeys(req.Backend) + toStop = append(toStop, buildProcessKey("", req.Backend, int(req.ReplicaIndex))) + for _, key := range toStop { + xlog.Info("Upgrade: stopping running backend before reinstall", + "backend", req.Backend, "processKey", key) + s.stopBackendExact(key) + } + + galleries := s.galleries + if req.BackendGalleries != "" { + var reqGalleries []config.Gallery + if err := json.Unmarshal([]byte(req.BackendGalleries), &reqGalleries); err == nil { + galleries = reqGalleries + } + } + + if req.URI != "" { + xlog.Info("Upgrading backend from external URI", "backend", req.Backend, "uri", req.URI) + if err := galleryop.InstallExternalBackend( + context.Background(), galleries, s.systemState, s.ml, nil, req.URI, req.Name, req.Alias, + ); err != nil { + return fmt.Errorf("upgrading backend from external URI: %w", err) + } + } else { + xlog.Info("Upgrading backend from gallery", "backend", req.Backend) + if err := gallery.InstallBackendFromGallery( + context.Background(), galleries, s.systemState, s.ml, req.Backend, nil, true, /* force */ + ); err != nil { + return fmt.Errorf("upgrading backend from gallery: %w", err) + } + } + + gallery.RegisterBackends(s.systemState, s.ml) + return nil +} + +// findBackend looks for the backend binary in the backends path and system path. +func (s *backendSupervisor) findBackend(backend string) string { + candidates := []string{ + filepath.Join(s.cfg.BackendsPath, backend), + filepath.Join(s.cfg.BackendsPath, backend, backend), + filepath.Join(s.cfg.BackendsSystemPath, backend), + filepath.Join(s.cfg.BackendsSystemPath, backend, backend), + } + if uri := s.ml.GetExternalBackend(backend); uri != "" { + if fi, err := os.Stat(uri); err == nil && !fi.IsDir() { + return uri + } + } + for _, path := range candidates { + fi, err := os.Stat(path) + if err == nil && !fi.IsDir() { + return path + } + } + return "" +} + +// lockBackend returns a release function for a per-backend mutex. 
Different +// backend names lock independently. The first caller for a name allocates +// the mutex under s.mu; subsequent callers for the same name reuse it. +func (s *backendSupervisor) lockBackend(name string) func() { + s.mu.Lock() + if s.backendLocks == nil { + s.backendLocks = make(map[string]*sync.Mutex) + } + m, ok := s.backendLocks[name] + if !ok { + m = &sync.Mutex{} + s.backendLocks[name] = m + } + s.mu.Unlock() + m.Lock() + return m.Unlock +} diff --git a/core/services/worker/lifecycle.go b/core/services/worker/lifecycle.go new file mode 100644 index 000000000..3c78b004d --- /dev/null +++ b/core/services/worker/lifecycle.go @@ -0,0 +1,247 @@ +package worker + +import ( + "context" + "encoding/json" + "fmt" + "net" + "syscall" + + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/services/messaging" + grpc "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/xlog" +) + +// subscribeLifecycleEvents wires every NATS subject this worker accepts to its +// per-event handler method. Each handler lives on *backendSupervisor below; +// keeping the dispatcher to a single line per subject makes adding a new +// subject a 2-line patch (one line here, one new method) instead of grafting +// onto a monolith. 
+func (s *backendSupervisor) subscribeLifecycleEvents() { + s.nats.SubscribeReply(messaging.SubjectNodeBackendInstall(s.nodeID), s.handleBackendInstall) + s.nats.SubscribeReply(messaging.SubjectNodeBackendUpgrade(s.nodeID), s.handleBackendUpgrade) + s.nats.Subscribe(messaging.SubjectNodeBackendStop(s.nodeID), s.handleBackendStop) + s.nats.SubscribeReply(messaging.SubjectNodeBackendDelete(s.nodeID), s.handleBackendDelete) + s.nats.SubscribeReply(messaging.SubjectNodeBackendList(s.nodeID), s.handleBackendList) + s.nats.SubscribeReply(messaging.SubjectNodeModelUnload(s.nodeID), s.handleModelUnload) + s.nats.SubscribeReply(messaging.SubjectNodeModelDelete(s.nodeID), s.handleModelDelete) + s.nats.Subscribe(messaging.SubjectNodeStop(s.nodeID), s.handleNodeStop) +} + +// handleBackendInstall is the NATS callback for backend.install — install +// backend (idempotent: skips download if binary exists on disk) + start gRPC +// process (request-reply). +// +// Each request runs in its own goroutine so that a slow install on one +// backend does NOT head-of-line-block install requests for unrelated +// backends arriving on the same subscription. Per-backend serialization +// is provided by lockBackend so two requests targeting the same on-disk +// artifact don't race the gallery directory. +func (s *backendSupervisor) handleBackendInstall(data []byte, reply func([]byte)) { + go func() { + xlog.Info("Received NATS backend.install event") + var req messaging.BackendInstallRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.BackendInstallReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + release := s.lockBackend(req.Backend) + defer release() + + // req.Force=true is the legacy path used by pre-2026-05-08 masters + // that don't know about backend.upgrade. Honor it so a rolling + // update with new worker + old master keeps working; new masters + // send to backend.upgrade instead. 
+ addr, err := s.installBackend(req, req.Force) + if err != nil { + xlog.Error("Failed to install backend via NATS", "error", err) + resp := messaging.BackendInstallReply{Success: false, Error: err.Error()} + replyJSON(reply, resp) + return + } + + advertiseAddr := addr + advAddr := s.cfg.advertiseAddr() + if advAddr != addr { + _, port, _ := net.SplitHostPort(addr) + advertiseHost, _, _ := net.SplitHostPort(advAddr) + advertiseAddr = net.JoinHostPort(advertiseHost, port) + } + resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr} + replyJSON(reply, resp) + }() +} + +// handleBackendUpgrade is the NATS callback for backend.upgrade — force-reinstall +// a backend (request-reply). Lives on its own subscription so a multi-minute +// download here does NOT block the install fast-path subscription on the same +// worker. +func (s *backendSupervisor) handleBackendUpgrade(data []byte, reply func([]byte)) { + go func() { + xlog.Info("Received NATS backend.upgrade event") + var req messaging.BackendUpgradeRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.BackendUpgradeReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + release := s.lockBackend(req.Backend) + defer release() + + if err := s.upgradeBackend(req); err != nil { + xlog.Error("Failed to upgrade backend via NATS", "error", err) + replyJSON(reply, messaging.BackendUpgradeReply{Success: false, Error: err.Error()}) + return + } + replyJSON(reply, messaging.BackendUpgradeReply{Success: true}) + }() +} + +// handleBackendStop is the NATS callback for backend.stop — stop a specific +// backend process (fire-and-forget, no reply expected). 
+func (s *backendSupervisor) handleBackendStop(data []byte) { + // Try to parse backend name from payload; if empty, stop all + var req struct { + Backend string `json:"backend"` + } + if json.Unmarshal(data, &req) == nil && req.Backend != "" { + xlog.Info("Received NATS backend.stop event", "backend", req.Backend) + s.stopBackend(req.Backend) + } else { + xlog.Info("Received NATS backend.stop event (all)") + s.stopAllBackends() + } +} + +// handleBackendDelete is the NATS callback for backend.delete — stop the +// backend process if running, then remove its files from disk (request-reply). +func (s *backendSupervisor) handleBackendDelete(data []byte, reply func([]byte)) { + var req messaging.BackendDeleteRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.BackendDeleteReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + xlog.Info("Received NATS backend.delete event", "backend", req.Backend) + + // Stop if running this backend + if s.isRunning(req.Backend) { + s.stopBackend(req.Backend) + } + + // Delete the backend files + if err := gallery.DeleteBackendFromSystem(s.systemState, req.Backend); err != nil { + xlog.Warn("Failed to delete backend files", "backend", req.Backend, "error", err) + resp := messaging.BackendDeleteReply{Success: false, Error: err.Error()} + replyJSON(reply, resp) + return + } + + // Re-register backends after deletion + gallery.RegisterBackends(s.systemState, s.ml) + + resp := messaging.BackendDeleteReply{Success: true} + replyJSON(reply, resp) +} + +// handleBackendList is the NATS callback for backend.list — reply with the +// installed backends from this node's gallery (request-reply). 
+func (s *backendSupervisor) handleBackendList(data []byte, reply func([]byte)) { + xlog.Info("Received NATS backend.list event") + backends, err := gallery.ListSystemBackends(s.systemState) + if err != nil { + resp := messaging.BackendListReply{Error: err.Error()} + replyJSON(reply, resp) + return + } + + var infos []messaging.NodeBackendInfo + for name, b := range backends { + info := messaging.NodeBackendInfo{ + Name: name, + IsSystem: b.IsSystem, + IsMeta: b.IsMeta, + } + if b.Metadata != nil { + info.InstalledAt = b.Metadata.InstalledAt + info.GalleryURL = b.Metadata.GalleryURL + info.Version = b.Metadata.Version + info.URI = b.Metadata.URI + info.Digest = b.Metadata.Digest + } + infos = append(infos, info) + } + + resp := messaging.BackendListReply{Backends: infos} + replyJSON(reply, resp) +} + +// handleModelUnload is the NATS callback for model.unload — call gRPC Free() +// to release GPU memory without killing the backend process (request-reply). +func (s *backendSupervisor) handleModelUnload(data []byte, reply func([]byte)) { + xlog.Info("Received NATS model.unload event") + var req messaging.ModelUnloadRequest + if err := json.Unmarshal(data, &req); err != nil { + resp := messaging.ModelUnloadReply{Success: false, Error: fmt.Sprintf("invalid request: %v", err)} + replyJSON(reply, resp) + return + } + + // Find the backend address for this model's backend type + // The request includes an Address field if the router knows which process to target + targetAddr := req.Address + if targetAddr == "" { + // Fallback: try all running backends + s.mu.Lock() + for _, bp := range s.processes { + targetAddr = bp.addr + break + } + s.mu.Unlock() + } + + if targetAddr != "" { + // Best-effort gRPC Free() + client := grpc.NewClientWithToken(targetAddr, false, nil, false, s.cfg.RegistrationToken) + if err := client.Free(context.Background()); err != nil { + xlog.Warn("Free() failed during model.unload", "error", err, "addr", targetAddr) + } + } + + resp := 
messaging.ModelUnloadReply{Success: true} + replyJSON(reply, resp) +} + +// handleModelDelete is the NATS callback for model.delete — remove model +// files from disk (request-reply). +func (s *backendSupervisor) handleModelDelete(data []byte, reply func([]byte)) { + xlog.Info("Received NATS model.delete event") + var req messaging.ModelDeleteRequest + if err := json.Unmarshal(data, &req); err != nil { + replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: "invalid request"}) + return + } + + if err := gallery.DeleteStagedModelFiles(s.cfg.ModelsPath, req.ModelName); err != nil { + xlog.Warn("Failed to delete model files", "model", req.ModelName, "error", err) + replyJSON(reply, messaging.ModelDeleteReply{Success: false, Error: err.Error()}) + return + } + + replyJSON(reply, messaging.ModelDeleteReply{Success: true}) +} + +// handleNodeStop is the NATS callback for node.stop — trigger the normal +// shutdown path via sigCh so deferred cleanup runs (fire-and-forget). +func (s *backendSupervisor) handleNodeStop(data []byte) { + xlog.Info("Received NATS stop event — signaling shutdown") + select { + case s.sigCh <- syscall.SIGTERM: + default: + xlog.Debug("Shutdown already signaled, ignoring duplicate stop") + } +} diff --git a/core/services/worker/process_helpers.go b/core/services/worker/process_helpers.go new file mode 100644 index 000000000..68d3ac6de --- /dev/null +++ b/core/services/worker/process_helpers.go @@ -0,0 +1,20 @@ +package worker + +import ( + "os" + "strings" +) + +// readLastLinesFromFile reads the last n lines from a file. +// Returns an empty string if the file cannot be read. 
+func readLastLinesFromFile(path string, n int) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n") + if len(lines) > n { + lines = lines[len(lines)-n:] + } + return strings.Join(lines, "\n") +} diff --git a/core/services/worker/registration.go b/core/services/worker/registration.go new file mode 100644 index 000000000..87a8a7966 --- /dev/null +++ b/core/services/worker/registration.go @@ -0,0 +1,142 @@ +package worker + +import ( + "cmp" + "fmt" + "os" + "strconv" + "strings" + + "github.com/mudler/LocalAI/pkg/xsysinfo" +) + +// effectiveBasePort returns the port used as base for gRPC backend processes. +// Priority: Addr port → ServeAddr port → 50051 +func (cfg *Config) effectiveBasePort() int { + for _, addr := range []string{cfg.Addr, cfg.ServeAddr} { + if addr != "" { + if _, portStr, ok := strings.Cut(addr, ":"); ok { + if p, _ := strconv.Atoi(portStr); p > 0 { + return p + } + } + } + } + return 50051 +} + +// advertiseAddr returns the address the frontend should use to reach this node. +func (cfg *Config) advertiseAddr() string { + if cfg.AdvertiseAddr != "" { + return cfg.AdvertiseAddr + } + if cfg.Addr != "" { + return cfg.Addr + } + hostname, _ := os.Hostname() + return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cfg.effectiveBasePort()) +} + +// resolveHTTPAddr returns the address to bind the HTTP file transfer server to. +// Uses basePort-1 so it doesn't conflict with dynamically allocated gRPC ports +// which grow upward from basePort. +func (cfg *Config) resolveHTTPAddr() string { + if cfg.HTTPAddr != "" { + return cfg.HTTPAddr + } + return fmt.Sprintf("0.0.0.0:%d", cfg.effectiveBasePort()-1) +} + +// advertiseHTTPAddr returns the HTTP address the frontend should use to reach +// this node for file transfer. 
+func (cfg *Config) advertiseHTTPAddr() string { + if cfg.AdvertiseHTTPAddr != "" { + return cfg.AdvertiseHTTPAddr + } + advHost, _, _ := strings.Cut(cfg.advertiseAddr(), ":") + httpPort := cfg.effectiveBasePort() - 1 + return fmt.Sprintf("%s:%d", advHost, httpPort) +} + +// registrationBody builds the JSON body for node registration. +func (cfg *Config) registrationBody() map[string]any { + nodeName := cfg.NodeName + if nodeName == "" { + hostname, err := os.Hostname() + if err != nil { + nodeName = fmt.Sprintf("node-%d", os.Getpid()) + } else { + nodeName = hostname + } + } + + // Detect GPU info for VRAM-aware scheduling + totalVRAM, _ := xsysinfo.TotalAvailableVRAM() + gpuVendor, _ := xsysinfo.DetectGPUVendor() + + maxReplicas := cfg.MaxReplicasPerModel + if maxReplicas < 1 { + maxReplicas = 1 + } + body := map[string]any{ + "name": nodeName, + "address": cfg.advertiseAddr(), + "http_address": cfg.advertiseHTTPAddr(), + "total_vram": totalVRAM, + "available_vram": totalVRAM, // initially all VRAM is available + "gpu_vendor": gpuVendor, + "max_replicas_per_model": maxReplicas, + } + + // If no GPU detected, report system RAM so the scheduler/UI has capacity info + if totalVRAM == 0 { + if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { + body["total_ram"] = ramInfo.Total + body["available_ram"] = ramInfo.Available + } + } + if cfg.RegistrationToken != "" { + body["token"] = cfg.RegistrationToken + } + + // Parse and add static node labels. Always include the auto-label + // `node.replica-slots=N` so AND-selectors in ModelSchedulingConfig can + // target high-capacity nodes (e.g. {"node.replica-slots":"4"}). 
+ labels := make(map[string]string) + if cfg.NodeLabels != "" { + for _, pair := range strings.Split(cfg.NodeLabels, ",") { + pair = strings.TrimSpace(pair) + if k, v, ok := strings.Cut(pair, "="); ok { + labels[strings.TrimSpace(k)] = strings.TrimSpace(v) + } + } + } + labels["node.replica-slots"] = strconv.Itoa(maxReplicas) + body["labels"] = labels + + return body +} + +// heartbeatBody returns the current VRAM/RAM stats for heartbeat payloads. +// +// When aggregate VRAM usage is unknown (no GPU, or temporary detection +// failure), we deliberately OMIT available_vram so the frontend keeps its +// last good value — overwriting with 0 makes the UI show the node as "fully +// used", while reporting total-as-available lies to the scheduler about +// free capacity. +func (cfg *Config) heartbeatBody() map[string]any { + body := map[string]any{} + aggregate := xsysinfo.GetGPUAggregateInfo() + if aggregate.TotalVRAM > 0 { + body["available_vram"] = aggregate.FreeVRAM + } + + // CPU-only workers (or workers that lost GPU visibility momentarily): + // report system RAM so the scheduler still has capacity info. + if aggregate.TotalVRAM == 0 { + if ramInfo, err := xsysinfo.GetSystemRAMInfo(); err == nil { + body["available_ram"] = ramInfo.Available + } + } + return body +} diff --git a/core/cli/worker_replica_test.go b/core/services/worker/replica_test.go similarity index 95% rename from core/cli/worker_replica_test.go rename to core/services/worker/replica_test.go index 36e88a67a..a8c0cbb05 100644 --- a/core/cli/worker_replica_test.go +++ b/core/services/worker/replica_test.go @@ -1,4 +1,4 @@ -package cli +package worker import ( . 
"github.com/onsi/ginkgo/v2" @@ -30,11 +30,11 @@ var _ = Describe("Worker per-replica process keying", func() { Describe("registrationBody", func() { It("includes max_replicas_per_model and the auto-label", func() { - cmd := &WorkerCMD{ + cfg := &Config{ Addr: "worker.example.com:50051", MaxReplicasPerModel: 4, } - body := cmd.registrationBody() + body := cfg.registrationBody() Expect(body).To(HaveKey("max_replicas_per_model")) Expect(body["max_replicas_per_model"]).To(Equal(4)) @@ -45,8 +45,8 @@ var _ = Describe("Worker per-replica process keying", func() { }) It("coerces zero/unset MaxReplicasPerModel to 1", func() { - cmd := &WorkerCMD{Addr: "worker.example.com:50051"} - body := cmd.registrationBody() + cfg := &Config{Addr: "worker.example.com:50051"} + body := cfg.registrationBody() Expect(body["max_replicas_per_model"]).To(Equal(1), "unset must default to single-replica behavior, not capacity 0") @@ -55,12 +55,12 @@ var _ = Describe("Worker per-replica process keying", func() { }) It("preserves user-provided labels alongside the auto-label", func() { - cmd := &WorkerCMD{ + cfg := &Config{ Addr: "worker.example.com:50051", MaxReplicasPerModel: 2, NodeLabels: "tier=fast,gpu=a100", } - body := cmd.registrationBody() + body := cfg.registrationBody() labels := body["labels"].(map[string]string) Expect(labels).To(HaveKeyWithValue("tier", "fast")) Expect(labels).To(HaveKeyWithValue("gpu", "a100")) diff --git a/core/services/worker/reply.go b/core/services/worker/reply.go new file mode 100644 index 000000000..9700f19fa --- /dev/null +++ b/core/services/worker/reply.go @@ -0,0 +1,17 @@ +package worker + +import ( + "encoding/json" + + "github.com/mudler/xlog" +) + +// replyJSON marshals v to JSON and calls the reply function. 
+func replyJSON(reply func([]byte), v any) { + data, err := json.Marshal(v) + if err != nil { + xlog.Error("Failed to marshal NATS reply", "error", err) + data = []byte(`{"error":"internal marshal error"}`) + } + reply(data) +} diff --git a/core/services/worker/supervisor.go b/core/services/worker/supervisor.go new file mode 100644 index 000000000..5fda9bae8 --- /dev/null +++ b/core/services/worker/supervisor.go @@ -0,0 +1,272 @@ +package worker + +import ( + "context" + "fmt" + "maps" + "net" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/messaging" + grpc "github.com/mudler/LocalAI/pkg/grpc" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/system" + process "github.com/mudler/go-processmanager" + "github.com/mudler/xlog" +) + +// backendProcess represents a single gRPC backend process. +type backendProcess struct { + proc *process.Process + backend string + addr string // gRPC address (host:port) +} + +// backendSupervisor manages multiple backend gRPC processes on different ports. +// Each backend type (e.g., llama-cpp, bert-embeddings) gets its own process and port. +type backendSupervisor struct { + cfg *Config + ml *model.ModelLoader + systemState *system.SystemState + galleries []config.Gallery + nodeID string + nats messaging.MessagingClient + sigCh chan<- os.Signal // send shutdown signal instead of os.Exit + + mu sync.Mutex + processes map[string]*backendProcess // key: backend name + nextPort int // next available port for new backends + freePorts []int // ports freed by stopBackend, reused before nextPort + + // backendLocks serializes gallery operations against the same on-disk + // artifact. 
Two installs of different backends on the same worker run + // concurrently (their handlers are each in a goroutine); two operations + // on the same backend (install vs upgrade, or two parallel installs of + // the same not-yet-cached backend) are serialized here so the gallery + // download path doesn't race itself on the same directory. + backendLocks map[string]*sync.Mutex +} + +// startBackend starts a gRPC backend process on a dynamically allocated port. +// Returns the gRPC address. +func (s *backendSupervisor) startBackend(backend, backendPath string) (string, error) { + s.mu.Lock() + + // Already running? + if bp, ok := s.processes[backend]; ok { + if bp.proc != nil && bp.proc.IsAlive() { + s.mu.Unlock() + return bp.addr, nil + } + // Process died — clean up and restart + xlog.Warn("Backend process died unexpectedly, restarting", "backend", backend) + delete(s.processes, backend) + } + + // Allocate port — recycle freed ports first, then grow upward from basePort + var port int + if len(s.freePorts) > 0 { + port = s.freePorts[len(s.freePorts)-1] + s.freePorts = s.freePorts[:len(s.freePorts)-1] + } else { + port = s.nextPort + s.nextPort++ + } + bindAddr := fmt.Sprintf("0.0.0.0:%d", port) + clientAddr := fmt.Sprintf("127.0.0.1:%d", port) + + proc, err := s.ml.StartProcess(backendPath, backend, bindAddr) + if err != nil { + s.mu.Unlock() + return "", fmt.Errorf("starting backend process: %w", err) + } + + s.processes[backend] = &backendProcess{ + proc: proc, + backend: backend, + addr: clientAddr, + } + xlog.Info("Backend process started", "backend", backend, "addr", clientAddr) + + // Capture reference before unlocking for race-safe health check. + // Another goroutine could stopBackend and recycle the port while we poll. + bp := s.processes[backend] + s.mu.Unlock() + + // Wait for the gRPC server to be ready before reporting success. 
+ // Slow nodes (Jetson Orin doing first-boot CUDA init, large CGO libs) + // can take 10-15s before the gRPC port accepts connections; the previous + // 4s window made the worker reply Success on a not-yet-listening port, + // which manifested upstream as "connect: connection refused" on the + // frontend's first LoadModel dial. + client := grpc.NewClientWithToken(clientAddr, false, nil, false, s.cfg.RegistrationToken) + const ( + readinessPollInterval = 200 * time.Millisecond + readinessTimeout = 30 * time.Second + ) + deadline := time.Now().Add(readinessTimeout) + for time.Now().Before(deadline) { + time.Sleep(readinessPollInterval) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + if ok, _ := client.HealthCheck(ctx); ok { + cancel() + // Verify the process wasn't stopped/replaced while health-checking + s.mu.Lock() + currentBP, exists := s.processes[backend] + s.mu.Unlock() + if !exists || currentBP != bp { + return "", fmt.Errorf("backend %s was stopped during startup", backend) + } + xlog.Debug("Backend gRPC server is ready", "backend", backend, "addr", clientAddr) + return clientAddr, nil + } + cancel() + + // Check if the process died (e.g. OOM, CUDA error, missing libs) + if !proc.IsAlive() { + stderrTail := readLastLinesFromFile(proc.StderrPath(), 20) + xlog.Warn("Backend process died during startup", "backend", backend, "stderr", stderrTail) + s.mu.Lock() + delete(s.processes, backend) + s.freePorts = append(s.freePorts, port) + s.mu.Unlock() + return "", fmt.Errorf("backend process %s died during startup. Last stderr:\n%s", backend, stderrTail) + } + } + + // Readiness deadline exceeded. Returning success here would leave the + // frontend with an unbound address (it dials, gets ECONNREFUSED, and + // the operator sees a misleading "connection refused" instead of the + // real cause). Stop the half-started process, recycle the port, and + // surface the failure to the caller with the backend's stderr tail. 
+ stderrTail := readLastLinesFromFile(proc.StderrPath(), 20) + xlog.Error("Backend gRPC server not ready before deadline; aborting install", "backend", backend, "addr", clientAddr, "timeout", readinessTimeout, "stderr", stderrTail) + if killErr := proc.Stop(); killErr != nil { + xlog.Warn("Failed to stop unready backend process", "backend", backend, "error", killErr) + } + s.mu.Lock() + if cur, ok := s.processes[backend]; ok && cur == bp { + delete(s.processes, backend) + s.freePorts = append(s.freePorts, port) + } + s.mu.Unlock() + return "", fmt.Errorf("backend %s did not become ready within %s. Last stderr:\n%s", backend, readinessTimeout, stderrTail) +} + +// resolveProcessKeys turns a caller-supplied identifier into the set of +// process map keys it refers to. PR #9583 changed s.processes to be keyed by +// `modelID#replicaIndex`, but external NATS handlers still pass the bare +// model ID — without this resolver, those lookups silently no-op'd, so +// admin "Unload model" / "Delete backend" left the worker process alive. +// +// - Exact match wins. Callers that already know the full processKey +// (stopAllBackends iterating its own map) get exactly that entry. +// - Else, an identifier without `#` is treated as a model prefix and +// every `id#N` replica is returned. +// - An identifier that contains `#` but doesn't match anything returns +// nothing — no spurious prefix fallback when the caller was explicit. +func (s *backendSupervisor) resolveProcessKeys(id string) []string { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.processes[id]; ok { + return []string{id} + } + if strings.Contains(id, "#") { + return nil + } + prefix := id + "#" + var keys []string + for k := range s.processes { + if strings.HasPrefix(k, prefix) { + keys = append(keys, k) + } + } + return keys +} + +// stopBackend stops the backend process(es) matching the given identifier. +// Accepts a bare modelID (stops every replica) or a full processKey +// (stops just that replica). 
+func (s *backendSupervisor) stopBackend(id string) {
+	for _, key := range s.resolveProcessKeys(id) {
+		s.stopBackendExact(key)
+	}
+}
+
+// stopBackendExact stops the process under exactly this key. Locking and
+// network I/O are deliberately split: the map mutation and port recycling
+// run under the lock, while the best-effort gRPC Free() and proc.Stop()
+// calls run after release so they don't block other supervisor operations.
+func (s *backendSupervisor) stopBackendExact(key string) {
+	s.mu.Lock()
+	bp, ok := s.processes[key]
+	if !ok || bp.proc == nil {
+		s.mu.Unlock()
+		return
+	}
+	delete(s.processes, key)
+	if _, portStr, err := net.SplitHostPort(bp.addr); err == nil {
+		if p, err := strconv.Atoi(portStr); err == nil {
+			s.freePorts = append(s.freePorts, p)
+		}
+	}
+	s.mu.Unlock()
+
+	client := grpc.NewClientWithToken(bp.addr, false, nil, false, s.cfg.RegistrationToken)
+	xlog.Debug("Calling Free() before stopping backend", "backend", key)
+	if err := client.Free(context.Background()); err != nil {
+		xlog.Warn("Free() failed (best-effort)", "backend", key, "error", err)
+	}
+
+	xlog.Info("Stopping backend process", "backend", key, "addr", bp.addr)
+	if err := bp.proc.Stop(); err != nil {
+		xlog.Error("Error stopping backend process", "backend", key, "error", err)
+	}
+}
+
+// stopAllBackends stops all running backend processes.
+func (s *backendSupervisor) stopAllBackends() {
+	s.mu.Lock()
+	backends := slices.Collect(maps.Keys(s.processes))
+	s.mu.Unlock()
+
+	for _, b := range backends {
+		s.stopBackend(b)
+	}
+}
+
+// isRunning returns whether at least one backend process matching the given
+// identifier is currently running. Accepts a bare modelID (matches any
+// replica) or a full processKey (exact match). Callers like the
+// backend.delete pre-check rely on the bare-name path.
+func (s *backendSupervisor) isRunning(id string) bool {
+	keys := s.resolveProcessKeys(id)
+	if len(keys) == 0 {
+		// No key matched the identifier, so no matching process can be running.
+		return false
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	for _, key := range keys {
+		if bp, ok := s.processes[key]; ok && bp.proc != nil && bp.proc.IsAlive() {
+			return true
+		}
+	}
+	return false
+}
+
+// getAddr returns the gRPC address for a running backend, or empty string.
+func (s *backendSupervisor) getAddr(backend string) string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if bp, ok := s.processes[backend]; ok {
+		return bp.addr
+	}
+	return ""
+}
diff --git a/core/services/worker/worker.go b/core/services/worker/worker.go
new file mode 100644
index 000000000..744e368ac
--- /dev/null
+++ b/core/services/worker/worker.go
@@ -0,0 +1,154 @@
+package worker
+
+import (
+	"cmp"
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"syscall"
+	"time"
+
+	cliContext "github.com/mudler/LocalAI/core/cli/context"
+	"github.com/mudler/LocalAI/core/cli/workerregistry"
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/gallery"
+	"github.com/mudler/LocalAI/core/services/messaging"
+	"github.com/mudler/LocalAI/core/services/nodes"
+	grpc "github.com/mudler/LocalAI/pkg/grpc"
+	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/LocalAI/pkg/sanitize"
+	"github.com/mudler/LocalAI/pkg/system"
+	"github.com/mudler/xlog"
+)
+
+// Run starts the distributed agent worker: registers with the frontend,
+// subscribes to NATS lifecycle subjects, and blocks on signals.
+func Run(ctx *cliContext.Context, cfg *Config) error {
+	xlog.Info("Starting worker", "advertise", cfg.advertiseAddr(), "basePort", cfg.effectiveBasePort())
+
+	systemState, err := system.GetSystemState(
+		system.WithModelPath(cfg.ModelsPath),
+		system.WithBackendPath(cfg.BackendsPath),
+		system.WithBackendSystemPath(cfg.BackendsSystemPath),
+	)
+	if err != nil {
+		return fmt.Errorf("getting system state: %w", err)
+	}
+
+	ml := model.NewModelLoader(systemState)
+	ml.SetBackendLoggingEnabled(true)
+
+	// Register backends already installed on disk with the model loader
+	gallery.RegisterBackends(systemState, ml)
+
+	// Parse the JSON gallery list; best-effort — bad input only disables gallery installs
+	var galleries []config.Gallery
+	if err := json.Unmarshal([]byte(cfg.BackendGalleries), &galleries); err != nil {
+		xlog.Warn("Failed to parse backend galleries", "error", err)
+	}
+
+	// Self-registration with frontend (with retry)
+	regClient := &workerregistry.RegistrationClient{
+		FrontendURL:       cfg.RegisterTo,
+		RegistrationToken: cfg.RegistrationToken,
+	}
+
+	registrationBody := cfg.registrationBody()
+	nodeID, _, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
+	if err != nil {
+		return fmt.Errorf("failed to register with frontend: %w", err)
+	}
+
+	xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cfg.RegisterTo)
+	heartbeatInterval, err := time.ParseDuration(cfg.HeartbeatInterval)
+	if err != nil && cfg.HeartbeatInterval != "" {
+		xlog.Warn("invalid heartbeat interval, using default 10s", "input", cfg.HeartbeatInterval, "error", err)
+	}
+	heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
+	// Context cancelled on shutdown — used by heartbeat and other background goroutines
+	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
+	defer shutdownCancel()
+
+	// Start the HTTP file transfer server (staging/data dirs live beside ModelsPath)
+	httpAddr := cfg.resolveHTTPAddr()
+	stagingDir := filepath.Join(cfg.ModelsPath, "..", "staging")
+	dataDir := filepath.Join(cfg.ModelsPath, "..", "data")
+	httpServer, err := nodes.StartFileTransferServer(httpAddr, stagingDir, cfg.ModelsPath, dataDir, cfg.RegistrationToken, config.DefaultMaxUploadSize, ml.BackendLogs())
+	if err != nil {
+		return fmt.Errorf("starting HTTP file transfer server: %w", err)
+	}
+
+	// Connect to NATS (URL is sanitized before logging to avoid leaking credentials)
+	xlog.Info("Connecting to NATS", "url", sanitize.URL(cfg.NatsURL))
+	natsClient, err := messaging.New(cfg.NatsURL)
+	if err != nil {
+		nodes.ShutdownFileTransferServer(httpServer)
+		return fmt.Errorf("connecting to NATS: %w", err)
+	}
+	defer natsClient.Close()
+
+	// Start heartbeat goroutine (after NATS is connected so IsConnected check works)
+	go func() {
+		ticker := time.NewTicker(heartbeatInterval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-shutdownCtx.Done():
+				return
+			case <-ticker.C:
+				if !natsClient.IsConnected() {
+					xlog.Warn("Skipping heartbeat: NATS disconnected")
+					continue
+				}
+				body := cfg.heartbeatBody()
+				if err := regClient.Heartbeat(shutdownCtx, nodeID, body); err != nil {
+					xlog.Warn("Heartbeat failed", "error", err)
+				}
+			}
+		}
+	}()
+
+	// Process supervisor — manages multiple backend gRPC processes on different ports
+	basePort := cfg.effectiveBasePort()
+	// Buffered so the NATS stop handler can send without blocking
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+
+	// Set the registration token once before any backends are started
+	if cfg.RegistrationToken != "" {
+		os.Setenv(grpc.AuthTokenEnvVar, cfg.RegistrationToken)
+	}
+
+	supervisor := &backendSupervisor{
+		cfg:         cfg,
+		ml:          ml,
+		systemState: systemState,
+		galleries:   galleries,
+		nodeID:      nodeID,
+		nats:        natsClient,
+		sigCh:       sigCh,
+		processes:   make(map[string]*backendProcess),
+		nextPort:    basePort,
+	}
+	supervisor.subscribeLifecycleEvents()
+
+	// Subscribe to file staging NATS subjects only when S3 storage is configured
+	if cfg.StorageURL != "" {
+		if err := cfg.subscribeFileStaging(natsClient, nodeID); err != nil {
+			xlog.Error("Failed to subscribe to file staging subjects", "error", err)
+		}
+	}
+
+	xlog.Info("Worker ready, waiting for backend.install events")
+	<-sigCh
+
+	xlog.Info("Shutting down worker")
+	shutdownCancel() // stop heartbeat loop immediately
+	regClient.GracefulDeregister(nodeID)
+	supervisor.stopAllBackends()
+	nodes.ShutdownFileTransferServer(httpServer)
+	return nil
+}
diff --git a/core/services/worker/worker_suite_test.go b/core/services/worker/worker_suite_test.go
new file mode 100644
index 000000000..64186d88f
--- /dev/null
+++ b/core/services/worker/worker_suite_test.go
@@ -0,0 +1,13 @@
+package worker
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestWorker(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Worker Suite")
+}
diff --git a/tests/e2e/distributed/node_lifecycle_test.go b/tests/e2e/distributed/node_lifecycle_test.go
index 83e4c0b72..ea69e2a2d 100644
--- a/tests/e2e/distributed/node_lifecycle_test.go
+++ b/tests/e2e/distributed/node_lifecycle_test.go
@@ -57,7 +57,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
 		FlushNATS(infra.NC)
 		adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
-		installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "", 0, false)
+		installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "", 0)
 		Expect(err).ToNot(HaveOccurred())
 		Expect(installReply.Success).To(BeTrue())
 	})
@@ -78,7 +78,7 @@
 		FlushNATS(infra.NC)
 		adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
-		installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "", 0, false)
+		installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "", 0)
 		Expect(err).ToNot(HaveOccurred())
 		Expect(installReply.Success).To(BeFalse())
 		Expect(installReply.Error).To(ContainSubstring("backend not found"))