From 8862e3ce601fc7a446c0a532cba7d72ea825a4a7 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 31 Mar 2026 08:28:56 +0200 Subject: [PATCH] feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler (#9186) * always enable parallel requests Signed-off-by: Ettore Di Giacinto * feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler Signed-off-by: Ettore Di Giacinto * chore: move tests to ginkgo Signed-off-by: Ettore Di Giacinto * chore(smart router): order by available vram Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/application/config_file_watcher.go | 4 - core/application/distributed.go | 13 + core/application/startup.go | 9 +- core/backend/options.go | 4 +- core/cli/agent_test.go | 311 +++++++++----------- core/cli/cli_suite_test.go | 13 + core/cli/completion_test.go | 116 +++----- core/cli/run.go | 5 - core/cli/worker.go | 90 +++--- core/cli/worker_addr_test.go | 83 ++++++ core/config/application_config.go | 11 - core/config/application_config_test.go | 6 - core/config/runtime_settings.go | 2 - core/http/endpoints/localai/nodes.go | 186 +++++++++++- core/http/react-ui/src/pages/Nodes.jsx | 280 +++++++++++++++++- core/http/react-ui/src/pages/Settings.jsx | 3 - core/http/react-ui/src/utils/api.js | 7 + core/http/react-ui/src/utils/config.js | 5 + core/http/routes/nodes.go | 13 + core/services/nodes/interfaces.go | 6 + core/services/nodes/model_router_test.go | 18 ++ core/services/nodes/reconciler.go | 236 +++++++++++++++ core/services/nodes/reconciler_test.go | 241 ++++++++++++++++ core/services/nodes/registry.go | 336 +++++++++++++++++++++- core/services/nodes/registry_test.go | 246 ++++++++++++++++ core/services/nodes/router.go | 172 +++++++++-- core/services/nodes/router_test.go | 173 +++++++++++ docs/content/features/distributed-mode.md | 110 +++++++ docs/content/features/runtime-settings.md | 2 - docs/content/reference/cli-reference.md | 1 - 30 files changed, 2337 insertions(+), 365 deletions(-) create mode 100644 core/cli/cli_suite_test.go create mode 100644 core/cli/worker_addr_test.go create mode 100644 core/services/nodes/reconciler.go create mode 100644 core/services/nodes/reconciler_test.go diff --git a/core/application/config_file_watcher.go b/core/application/config_file_watcher.go index 094eebbbe..8eb26355d 100644 --- a/core/application/config_file_watcher.go +++ b/core/application/config_file_watcher.go @@ -199,7 +199,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends - envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold envThreads := appConfig.Threads == startupAppConfig.Threads @@ -271,9 +270,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand appConfig.MaxActiveBackends = 0 } } - if settings.ParallelBackendRequests != nil && !envParallelRequests { - appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests - } if settings.MemoryReclaimerEnabled != nil && 
!envMemoryReclaimerEnabled { appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled if appConfig.MemoryReclaimerEnabled { diff --git a/core/application/distributed.go b/core/application/distributed.go index 257fecdd4..31e87fdab 100644 --- a/core/application/distributed.go +++ b/core/application/distributed.go @@ -7,6 +7,7 @@ import ( "io" "strings" "sync" + "time" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" @@ -28,6 +29,7 @@ type DistributedServices struct { Registry *nodes.NodeRegistry Router *nodes.SmartRouter Health *nodes.HealthMonitor + Reconciler *nodes.ReplicaReconciler JobStore *jobs.JobStore Dispatcher *jobs.Dispatcher AgentStore *agents.AgentStore @@ -240,6 +242,16 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut DB: authDB, }) + // Create ReplicaReconciler for auto-scaling model replicas + reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: router, + Unloader: remoteUnloader, + DB: authDB, + Interval: 30 * time.Second, + ScaleDownDelay: 5 * time.Minute, + }) + // Create ModelRouterAdapter to wire into ModelLoader modelAdapter := nodes.NewModelRouterAdapter(router) @@ -250,6 +262,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut Registry: registry, Router: router, Health: healthMon, + Reconciler: reconciler, JobStore: jobStore, Dispatcher: dispatcher, AgentStore: agentStore, diff --git a/core/application/startup.go b/core/application/startup.go index 026e55226..728c3c972 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -158,6 +158,10 @@ func New(opts ...config.AppOption) (*Application, error) { application.modelLoader.SetModelStore(distStore) // Start health monitor distSvc.Health.Start(options.Context) + // Start replica reconciler for auto-scaling model replicas + if distSvc.Reconciler != nil { + go distSvc.Reconciler.Run(options.Context) + } // In distributed mode, MCP CI jobs are executed by agent workers (not the frontend) // because the frontend can't create MCP sessions (e.g., stdio servers using docker). // The dispatcher still subscribes to jobs.new for persistence (result/progress subs) @@ -439,11 +443,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) { } } } - if settings.ParallelBackendRequests != nil { - if !options.ParallelBackendRequests { - options.ParallelBackendRequests = *settings.ParallelBackendRequests - } - } if settings.MemoryReclaimerEnabled != nil { // Only apply if current value is default (false), suggesting it wasn't set from env var if !options.MemoryReclaimerEnabled { diff --git a/core/backend/options.go b/core/backend/options.go index 2d410c90a..b09782ce2 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -59,9 +59,7 @@ func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...mo grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath) defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts)) - if so.ParallelBackendRequests { - defOpts = append(defOpts, model.EnableParallelRequests) - } + defOpts = append(defOpts, model.EnableParallelRequests) if c.GRPC.Attempts != 0 { defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts)) diff --git a/core/cli/agent_test.go b/core/cli/agent_test.go index d5d29e5d8..2a952d99a 100644 --- a/core/cli/agent_test.go +++ b/core/cli/agent_test.go @@ -4,211 +4,172 @@ import ( "encoding/json" "os" "path/filepath" - "testing" + + . 
"github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" "github.com/mudler/LocalAGI/core/state" ) -func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) { - // Create a temporary agent config file - tmpDir := t.TempDir() - configFile := filepath.Join(tmpDir, "agent.json") +var _ = Describe("AgentRunCMD", func() { + Describe("loadAgentConfig", func() { + It("loads agent config from file", func() { + tmpDir := GinkgoT().TempDir() + configFile := filepath.Join(tmpDir, "agent.json") - cfg := state.AgentConfig{ - Name: "test-agent", - Model: "llama3", - SystemPrompt: "You are a helpful assistant", - } - data, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - t.Fatal(err) - } - if err := os.WriteFile(configFile, data, 0644); err != nil { - t.Fatal(err) - } + cfg := state.AgentConfig{ + Name: "test-agent", + Model: "llama3", + SystemPrompt: "You are a helpful assistant", + } + data, err := json.MarshalIndent(cfg, "", " ") + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(configFile, data, 0644)).To(Succeed()) - cmd := &AgentRunCMD{ - Config: configFile, - StateDir: tmpDir, - } + cmd := &AgentRunCMD{ + Config: configFile, + StateDir: tmpDir, + } - loaded, err := cmd.loadAgentConfig() - if err != nil { - t.Fatalf("loadAgentConfig() error: %v", err) - } - if loaded.Name != "test-agent" { - t.Errorf("expected name %q, got %q", "test-agent", loaded.Name) - } - if loaded.Model != "llama3" { - t.Errorf("expected model %q, got %q", "llama3", loaded.Model) - } -} + loaded, err := cmd.loadAgentConfig() + Expect(err).ToNot(HaveOccurred()) + Expect(loaded.Name).To(Equal("test-agent")) + Expect(loaded.Model).To(Equal("llama3")) + }) -func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) { - tmpDir := t.TempDir() + It("loads agent config from pool", func() { + tmpDir := GinkgoT().TempDir() - pool := map[string]state.AgentConfig{ - "my-agent": { - Model: "gpt-4", - Description: "A test agent", - SystemPrompt: "Hello", - }, - "other-agent": { - Model: "llama3", - }, - } - data, err := json.MarshalIndent(pool, "", " ") - if err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil { - t.Fatal(err) - } + pool := map[string]state.AgentConfig{ + "my-agent": { + Model: "gpt-4", + Description: "A test agent", + SystemPrompt: "Hello", + }, + "other-agent": { + Model: "llama3", + }, + } + data, err := json.MarshalIndent(pool, "", " ") + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed()) - cmd := &AgentRunCMD{ - Name: "my-agent", - StateDir: tmpDir, - } + cmd := &AgentRunCMD{ + Name: "my-agent", + StateDir: tmpDir, + } - loaded, err := cmd.loadAgentConfig() - if err != nil { - t.Fatalf("loadAgentConfig() error: %v", err) - } - if loaded.Name != "my-agent" { - t.Errorf("expected name %q, got %q", "my-agent", loaded.Name) - } - if loaded.Model != "gpt-4" { - t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model) - } -} + loaded, err := cmd.loadAgentConfig() + Expect(err).ToNot(HaveOccurred()) + Expect(loaded.Name).To(Equal("my-agent")) + Expect(loaded.Model).To(Equal("gpt-4")) + }) -func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) { - tmpDir := t.TempDir() + It("returns error for missing agent in pool", func() { + tmpDir := GinkgoT().TempDir() - pool := map[string]state.AgentConfig{ - "existing-agent": {Model: "llama3"}, - } - data, err := json.MarshalIndent(pool, "", " ") - if err != nil { - t.Fatal(err) - } - if err := 
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil { - t.Fatal(err) - } + pool := map[string]state.AgentConfig{ + "existing-agent": {Model: "llama3"}, + } + data, err := json.MarshalIndent(pool, "", " ") + Expect(err).ToNot(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed()) - cmd := &AgentRunCMD{ - Name: "nonexistent", - StateDir: tmpDir, - } + cmd := &AgentRunCMD{ + Name: "nonexistent", + StateDir: tmpDir, + } - _, err = cmd.loadAgentConfig() - if err == nil { - t.Fatal("expected error for missing agent, got nil") - } -} + _, err = cmd.loadAgentConfig() + Expect(err).To(HaveOccurred()) + }) -func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) { - cmd := &AgentRunCMD{ - StateDir: t.TempDir(), - } + It("returns error when no pool.json exists", func() { + cmd := &AgentRunCMD{ + StateDir: GinkgoT().TempDir(), + } - _, err := cmd.loadAgentConfig() - if err == nil { - t.Fatal("expected error when no pool.json exists, got nil") - } -} + _, err := cmd.loadAgentConfig() + Expect(err).To(HaveOccurred()) + }) -func TestAgentRunCMD_ApplyOverrides(t *testing.T) { - cfg := &state.AgentConfig{ - Name: "test", - } + It("returns error for config with no name", func() { + tmpDir := GinkgoT().TempDir() + configFile := filepath.Join(tmpDir, "agent.json") - cmd := &AgentRunCMD{ - APIURL: "http://localhost:9090", - APIKey: "secret", - DefaultModel: "my-model", - } + cfg := state.AgentConfig{ + Model: "llama3", + } + data, _ := json.MarshalIndent(cfg, "", " ") + Expect(os.WriteFile(configFile, data, 0644)).To(Succeed()) - cmd.applyOverrides(cfg) + cmd := &AgentRunCMD{ + Config: configFile, + StateDir: tmpDir, + } - if cfg.APIURL != "http://localhost:9090" { - t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL) - } - if cfg.APIKey != "secret" { - t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey) - } - if cfg.Model != "my-model" { - t.Errorf("expected Model %q, got %q", "my-model", cfg.Model) - } -} + _, err := cmd.loadAgentConfig() + Expect(err).To(HaveOccurred()) + }) + }) -func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) { - cfg := &state.AgentConfig{ - Name: "test", - Model: "existing-model", - } + Describe("applyOverrides", func() { + It("applies overrides to empty fields", func() { + cfg := &state.AgentConfig{ + Name: "test", + } - cmd := &AgentRunCMD{ - DefaultModel: "override-model", - } + cmd := &AgentRunCMD{ + APIURL: "http://localhost:9090", + APIKey: "secret", + DefaultModel: "my-model", + } - cmd.applyOverrides(cfg) + cmd.applyOverrides(cfg) - if cfg.Model != "existing-model" { - t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model) - } -} + Expect(cfg.APIURL).To(Equal("http://localhost:9090")) + Expect(cfg.APIKey).To(Equal("secret")) + Expect(cfg.Model).To(Equal("my-model")) + }) -func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) { - tmpDir := t.TempDir() - configFile := filepath.Join(tmpDir, "agent.json") + It("does not overwrite existing model", func() { + cfg := &state.AgentConfig{ + Name: "test", + Model: "existing-model", + } - // Agent config with no name - cfg := state.AgentConfig{ - Model: "llama3", - } - data, _ := json.MarshalIndent(cfg, "", " ") - os.WriteFile(configFile, data, 0644) + cmd := &AgentRunCMD{ + DefaultModel: "override-model", + } - cmd := &AgentRunCMD{ - Config: configFile, - StateDir: tmpDir, - } + cmd.applyOverrides(cfg) - _, err := cmd.loadAgentConfig() - if err == nil { - t.Fatal("expected error for 
config with no name, got nil") - } -} + Expect(cfg.Model).To(Equal("existing-model")) + }) + }) +}) -func TestAgentListCMD_NoPoolFile(t *testing.T) { - cmd := &AgentListCMD{ - StateDir: t.TempDir(), - } +var _ = Describe("AgentListCMD", func() { + It("runs without error when no pool file exists", func() { + cmd := &AgentListCMD{ + StateDir: GinkgoT().TempDir(), + } + Expect(cmd.Run(nil)).To(Succeed()) + }) - // Should not error, just print "no agents found" - err := cmd.Run(nil) - if err != nil { - t.Fatalf("expected no error, got: %v", err) - } -} + It("runs without error with agents in pool", func() { + tmpDir := GinkgoT().TempDir() -func TestAgentListCMD_WithAgents(t *testing.T) { - tmpDir := t.TempDir() + pool := map[string]state.AgentConfig{ + "agent-a": {Model: "llama3", Description: "First agent"}, + "agent-b": {Model: "gpt-4"}, + } + data, _ := json.MarshalIndent(pool, "", " ") + Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed()) - pool := map[string]state.AgentConfig{ - "agent-a": {Model: "llama3", Description: "First agent"}, - "agent-b": {Model: "gpt-4"}, - } - data, _ := json.MarshalIndent(pool, "", " ") - os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644) - - cmd := &AgentListCMD{ - StateDir: tmpDir, - } - - err := cmd.Run(nil) - if err != nil { - t.Fatalf("expected no error, got: %v", err) - } -} + cmd := &AgentListCMD{ + StateDir: tmpDir, + } + Expect(cmd.Run(nil)).To(Succeed()) + }) +}) diff --git a/core/cli/cli_suite_test.go b/core/cli/cli_suite_test.go new file mode 100644 index 000000000..0b71fcdb9 --- /dev/null +++ b/core/cli/cli_suite_test.go @@ -0,0 +1,13 @@ +package cli + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestCLI(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "CLI Suite") +} diff --git a/core/cli/completion_test.go b/core/cli/completion_test.go index be625d051..dc5793075 100644 --- a/core/cli/completion_test.go +++ b/core/cli/completion_test.go @@ -1,10 +1,9 @@ package cli import ( - "strings" - "testing" - "github.com/alecthomas/kong" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" ) func getTestApp() *kong.Application { @@ -21,76 +20,55 @@ func getTestApp() *kong.Application { return k.Model } -func TestGenerateBashCompletion(t *testing.T) { - app := getTestApp() - script := generateBashCompletion(app) +var _ = Describe("Shell completions", func() { + var app *kong.Application - if !strings.Contains(script, "complete -F _local_ai_completions local-ai") { - t.Error("bash completion missing complete command registration") - } - if !strings.Contains(script, "run") { - t.Error("bash completion missing 'run' command") - } - if !strings.Contains(script, "models") { - t.Error("bash completion missing 'models' command") - } - if !strings.Contains(script, "completion") { - t.Error("bash completion missing 'completion' command") - } -} + BeforeEach(func() { + app = getTestApp() + }) -func TestGenerateZshCompletion(t *testing.T) { - app := getTestApp() - script := generateZshCompletion(app) + Describe("generateBashCompletion", func() { + It("generates valid bash completion script", func() { + script := generateBashCompletion(app) + Expect(script).To(ContainSubstring("complete -F _local_ai_completions local-ai")) + Expect(script).To(ContainSubstring("run")) + Expect(script).To(ContainSubstring("models")) + Expect(script).To(ContainSubstring("completion")) + }) + }) - if !strings.Contains(script, "#compdef local-ai") { - t.Error("zsh completion missing compdef header") - } - if !strings.Contains(script, "run") { - t.Error("zsh completion missing 'run' command") - } - if !strings.Contains(script, "models") { - t.Error("zsh completion missing 'models' command") - } -} + Describe("generateZshCompletion", func() { + It("generates valid zsh completion script", func() { + script := generateZshCompletion(app) + Expect(script).To(ContainSubstring("#compdef local-ai")) + Expect(script).To(ContainSubstring("run")) + Expect(script).To(ContainSubstring("models")) + }) + }) -func TestGenerateFishCompletion(t *testing.T) { - app := getTestApp() - script := generateFishCompletion(app) + Describe("generateFishCompletion", func() { + It("generates valid fish completion script", func() { + script := generateFishCompletion(app) + Expect(script).To(ContainSubstring("complete -c local-ai")) + Expect(script).To(ContainSubstring("__fish_use_subcommand")) + Expect(script).To(ContainSubstring("run")) + Expect(script).To(ContainSubstring("models")) + }) + }) - if !strings.Contains(script, "complete -c local-ai") { - t.Error("fish completion missing complete command") - } - if !strings.Contains(script, "__fish_use_subcommand") { - t.Error("fish completion missing subcommand detection") - } - if !strings.Contains(script, "run") { - t.Error("fish completion missing 'run' command") - } - if !strings.Contains(script, "models") { - t.Error("fish completion missing 'models' command") - } -} + Describe("collectCommands", func() { + It("collects all commands and subcommands", func() { + cmds := collectCommands(app.Node, "") -func TestCollectCommands(t *testing.T) { - app := getTestApp() - cmds := collectCommands(app.Node, "") + names := make(map[string]bool) + for _, cmd := range cmds { + names[cmd.fullName] = true + } - names := make(map[string]bool) - for _, cmd := range cmds { - names[cmd.fullName] = true - } - - if !names["run"] { - t.Error("missing 'run' command") - } - if !names["models"] { - t.Error("missing 'models' command") - } - if !names["models list"] { - t.Error("missing 'models list' subcommand") - } - if !names["models install"] { - t.Error("missing 'models install' 
subcommand") - } -} + Expect(names).To(HaveKey("run")) + Expect(names).To(HaveKey("models")) + Expect(names).To(HaveKey("models list")) + Expect(names).To(HaveKey("models install")) + }) + }) +}) diff --git a/core/cli/run.go b/core/cli/run.go index d3f1ac103..a478ab1c0 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -74,7 +74,6 @@ type RunCMD struct { Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"` Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"` Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"` - ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"` MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). Least recently used backends are evicted when limit is reached" group:"backends"` PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` @@ -438,10 +437,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold)) } - if r.ParallelRequests { - opts = append(opts, config.EnableParallelBackendRequests) - } - // Handle max active backends (LRU eviction) // MaxActiveBackends takes precedence over SingleActiveBackend if r.MaxActiveBackends > 0 { diff --git a/core/cli/worker.go b/core/cli/worker.go index f4c707e0c..b43d38fbb 100644 --- a/core/cli/worker.go +++ b/core/cli/worker.go @@ -67,22 +67,28 @@ func isPathAllowed(path string, allowedDirs []string) bool { // // Model loading (LoadModel) is always via direct gRPC — no NATS needed for that. type WorkerCMD struct { - Addr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"` + // Primary address — the reachable address of this worker. + // Host is used for advertise, port is the base for gRPC backends. + // HTTP file transfer runs on port-1. + Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." 
group:"server"` + ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""` + BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"` BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"` BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"` // HTTP file transfer - HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"` - AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"` + HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server" hidden:""` + AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""` // Registration (required) - AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"` + AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""` RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"` NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"` RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"` HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"` + NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. 
tier=fast,gpu=a100)" group:"registration"` // NATS (required) NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"` @@ -96,7 +102,7 @@ type WorkerCMD struct { } func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error { - xlog.Info("Starting worker", "addr", cmd.Addr) + xlog.Info("Starting worker", "advertise", cmd.advertiseAddr(), "basePort", cmd.effectiveBasePort()) systemState, err := system.GetSystemState( system.WithModelPath(cmd.ModelsPath), @@ -181,15 +187,7 @@ func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error { }() // Process supervisor — manages multiple backend gRPC processes on different ports - basePort := 50051 - if cmd.Addr != "" { - // Extract port from addr (e.g., "0.0.0.0:50051" → 50051) - if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil { - if p, err := strconv.Atoi(portStr); err == nil { - basePort = p - } - } - } + basePort := cmd.effectiveBasePort() // Buffered so NATS stop handler can send without blocking sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) @@ -667,10 +665,10 @@ func (s *backendSupervisor) subscribeLifecycleEvents() { // Return the gRPC address so the router knows which port to use advertiseAddr := addr - if s.cmd.AdvertiseAddr != "" { - // Replace 0.0.0.0 with the advertised host but keep the dynamic port + advAddr := s.cmd.advertiseAddr() + if advAddr != addr { // only remap if advertise differs from bind _, port, _ := net.SplitHostPort(addr) - advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr) + advertiseHost, _, _ := net.SplitHostPort(advAddr) advertiseAddr = net.JoinHostPort(advertiseHost, port) } resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr} @@ -816,18 +814,31 @@ func (s *backendSupervisor) subscribeLifecycleEvents() { }) } +// effectiveBasePort returns the port used as base for gRPC backend processes. +// Priority: Addr port → ServeAddr port → 50051 +func (cmd *WorkerCMD) effectiveBasePort() int { + for _, addr := range []string{cmd.Addr, cmd.ServeAddr} { + if addr != "" { + if _, portStr, ok := strings.Cut(addr, ":"); ok { + if p, _ := strconv.Atoi(portStr); p > 0 { + return p + } + } + } + } + return 50051 +} + // advertiseAddr returns the address the frontend should use to reach this node. func (cmd *WorkerCMD) advertiseAddr() string { if cmd.AdvertiseAddr != "" { return cmd.AdvertiseAddr } - host, port, ok := strings.Cut(cmd.Addr, ":") - if ok && (host == "0.0.0.0" || host == "") { - if hostname, err := os.Hostname(); err == nil { - return hostname + ":" + port - } + if cmd.Addr != "" { + return cmd.Addr } - return cmd.Addr + hostname, _ := os.Hostname() + return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cmd.effectiveBasePort()) } // resolveHTTPAddr returns the address to bind the HTTP file transfer server to. 
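The new address handling (effectiveBasePort, advertiseAddr, and the HTTP port derivation) can be summarised with a small standalone sketch. This is illustrative only; the type and field names below are invented for the example and are not part of this patch. A worker started with --addr worker1.example.com:60000 uses port 60000 as the gRPC base, serves HTTP file transfer on 59999, and advertises worker1.example.com:60000 to the frontend.

package main

import (
	"fmt"
	"os"
	"strconv"
	"strings"
)

// workerAddrs mirrors the three address knobs of the new worker flags.
type workerAddrs struct {
	Addr          string // primary reachable address (host:port)
	ServeAddr     string // legacy bind address; its port is the fallback base
	AdvertiseAddr string // explicit override for what the frontend should dial
}

// basePort follows the same priority as effectiveBasePort:
// Addr port, then ServeAddr port, then the default 50051.
func (w workerAddrs) basePort() int {
	for _, a := range []string{w.Addr, w.ServeAddr} {
		if _, p, ok := strings.Cut(a, ":"); ok {
			if n, err := strconv.Atoi(p); err == nil && n > 0 {
				return n
			}
		}
	}
	return 50051
}

// advertise follows the same priority as advertiseAddr:
// AdvertiseAddr, then Addr, then hostname:basePort.
func (w workerAddrs) advertise() string {
	if w.AdvertiseAddr != "" {
		return w.AdvertiseAddr
	}
	if w.Addr != "" {
		return w.Addr
	}
	host, _ := os.Hostname()
	if host == "" {
		host = "localhost"
	}
	return fmt.Sprintf("%s:%d", host, w.basePort())
}

func main() {
	w := workerAddrs{Addr: "worker1.example.com:60000"}
	fmt.Println("gRPC base port:", w.basePort())             // 60000
	fmt.Println("HTTP file transfer port:", w.basePort()-1)  // 59999 (base port minus one)
	fmt.Println("advertised address:", w.advertise())        // worker1.example.com:60000
}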
@@ -837,12 +848,7 @@ func (cmd *WorkerCMD) resolveHTTPAddr() string { if cmd.HTTPAddr != "" { return cmd.HTTPAddr } - host, port, ok := strings.Cut(cmd.Addr, ":") - if !ok { - return "0.0.0.0:50050" - } - portNum, _ := strconv.Atoi(port) - return fmt.Sprintf("%s:%d", host, portNum-1) + return fmt.Sprintf("0.0.0.0:%d", cmd.effectiveBasePort()-1) } // advertiseHTTPAddr returns the HTTP address the frontend should use to reach @@ -851,14 +857,9 @@ func (cmd *WorkerCMD) advertiseHTTPAddr() string { if cmd.AdvertiseHTTPAddr != "" { return cmd.AdvertiseHTTPAddr } - httpAddr := cmd.resolveHTTPAddr() - host, port, ok := strings.Cut(httpAddr, ":") - if ok && (host == "0.0.0.0" || host == "") { - if hostname, err := os.Hostname(); err == nil { - return hostname + ":" + port - } - } - return httpAddr + advHost, _, _ := strings.Cut(cmd.advertiseAddr(), ":") + httpPort := cmd.effectiveBasePort() - 1 + return fmt.Sprintf("%s:%d", advHost, httpPort) } // registrationBody builds the JSON body for node registration. @@ -896,6 +897,21 @@ func (cmd *WorkerCMD) registrationBody() map[string]any { if cmd.RegistrationToken != "" { body["token"] = cmd.RegistrationToken } + + // Parse and add static node labels + if cmd.NodeLabels != "" { + labels := make(map[string]string) + for _, pair := range strings.Split(cmd.NodeLabels, ",") { + pair = strings.TrimSpace(pair) + if k, v, ok := strings.Cut(pair, "="); ok { + labels[strings.TrimSpace(k)] = strings.TrimSpace(v) + } + } + if len(labels) > 0 { + body["labels"] = labels + } + } + return body } diff --git a/core/cli/worker_addr_test.go b/core/cli/worker_addr_test.go new file mode 100644 index 000000000..6dd7b0ab9 --- /dev/null +++ b/core/cli/worker_addr_test.go @@ -0,0 +1,83 @@ +package cli + +import ( + "os" + "strings" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("WorkerCMD address resolution", func() { + Describe("effectiveBasePort", func() { + DescribeTable("returns the correct port", + func(addr, serve string, want int) { + cmd := &WorkerCMD{Addr: addr, ServeAddr: serve} + Expect(cmd.effectiveBasePort()).To(Equal(want)) + }, + Entry("Addr takes priority", "worker1.example.com:60000", "0.0.0.0:50051", 60000), + Entry("falls back to ServeAddr", "", "0.0.0.0:50051", 50051), + Entry("returns 50051 when neither set", "", "", 50051), + Entry("Addr with custom port", "10.0.0.5:7000", "", 7000), + Entry("invalid port in Addr falls through to ServeAddr", "host:notanumber", "0.0.0.0:9999", 9999), + ) + }) + + Describe("advertiseAddr", func() { + It("returns AdvertiseAddr when set", func() { + cmd := &WorkerCMD{ + AdvertiseAddr: "public.example.com:50051", + Addr: "10.0.0.5:60000", + } + Expect(cmd.advertiseAddr()).To(Equal("public.example.com:50051")) + }) + + It("returns Addr when set", func() { + cmd := &WorkerCMD{Addr: "worker1.example.com:60000"} + Expect(cmd.advertiseAddr()).To(Equal("worker1.example.com:60000")) + }) + + It("falls back to hostname:basePort", func() { + cmd := &WorkerCMD{ServeAddr: "0.0.0.0:50051"} + got := cmd.advertiseAddr() + _, port, _ := strings.Cut(got, ":") + Expect(port).To(Equal("50051")) + + hostname, _ := os.Hostname() + if hostname != "" { + host, _, _ := strings.Cut(got, ":") + Expect(host).To(Equal(hostname)) + } + }) + }) + + Describe("resolveHTTPAddr", func() { + DescribeTable("returns the correct address", + func(httpAddr, addr, serve, want string) { + cmd := &WorkerCMD{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve} + Expect(cmd.resolveHTTPAddr()).To(Equal(want)) + }, + Entry("HTTPAddr takes priority", "0.0.0.0:8080", "", "", "0.0.0.0:8080"), + Entry("derives from Addr port minus 1", "", "worker1:60000", "0.0.0.0:50051", "0.0.0.0:59999"), + Entry("derives from ServeAddr port minus 1", "", "", "0.0.0.0:50051", "0.0.0.0:50050"), + Entry("default when nothing set", "", "", "", "0.0.0.0:50050"), + ) + }) + + Describe("advertiseHTTPAddr", func() { + DescribeTable("returns the correct address", + func(advertiseHTTP, advertise, addr, serve, want string) { + cmd := &WorkerCMD{ + AdvertiseHTTPAddr: advertiseHTTP, + AdvertiseAddr: advertise, + Addr: addr, + ServeAddr: serve, + } + Expect(cmd.advertiseHTTPAddr()).To(Equal(want)) + }, + Entry("AdvertiseHTTPAddr takes priority", "public.example.com:8080", "", "", "", "public.example.com:8080"), + Entry("derives from advertiseAddr host + basePort-1", "", "", "worker1.example.com:60000", "", "worker1.example.com:59999"), + Entry("uses AdvertiseAddr host with basePort-1", "", "public.example.com:60000", "10.0.0.5:60000", "", "public.example.com:59999"), + ) + }) +}) diff --git a/core/config/application_config.go b/core/config/application_config.go index 2bfb552d7..8d55f83a6 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -59,8 +59,6 @@ type ApplicationConfig struct { SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode) - ParallelBackendRequests bool - WatchDogIdle bool WatchDogBusy bool WatchDog bool @@ -379,10 +377,6 @@ func WithLRUEvictionRetryInterval(interval time.Duration) AppOption { } } -var EnableParallelBackendRequests = func(o *ApplicationConfig) { - o.ParallelBackendRequests = true -} - var EnableGalleriesAutoload = func(o *ApplicationConfig) { o.AutoloadGalleries = 
true } @@ -842,7 +836,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { watchdogBusy := o.WatchDogBusy singleBackend := o.SingleBackend maxActiveBackends := o.MaxActiveBackends - parallelBackendRequests := o.ParallelBackendRequests memoryReclaimerEnabled := o.MemoryReclaimerEnabled memoryReclaimerThreshold := o.MemoryReclaimerThreshold forceEvictionWhenBusy := o.ForceEvictionWhenBusy @@ -915,7 +908,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { WatchdogInterval: &watchdogInterval, SingleBackend: &singleBackend, MaxActiveBackends: &maxActiveBackends, - ParallelBackendRequests: ¶llelBackendRequests, MemoryReclaimerEnabled: &memoryReclaimerEnabled, MemoryReclaimerThreshold: &memoryReclaimerThreshold, ForceEvictionWhenBusy: &forceEvictionWhenBusy, @@ -1008,9 +1000,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req } requireRestart = true } - if settings.ParallelBackendRequests != nil { - o.ParallelBackendRequests = *settings.ParallelBackendRequests - } if settings.MemoryReclaimerEnabled != nil { o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled if *settings.MemoryReclaimerEnabled { diff --git a/core/config/application_config_test.go b/core/config/application_config_test.go index 4ea89e8d6..c5559ce53 100644 --- a/core/config/application_config_test.go +++ b/core/config/application_config_test.go @@ -18,7 +18,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { WatchDogBusyTimeout: 10 * time.Minute, SingleBackend: false, MaxActiveBackends: 5, - ParallelBackendRequests: true, MemoryReclaimerEnabled: true, MemoryReclaimerThreshold: 0.85, Threads: 8, @@ -62,9 +61,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { Expect(rs.MaxActiveBackends).ToNot(BeNil()) Expect(*rs.MaxActiveBackends).To(Equal(5)) - Expect(rs.ParallelBackendRequests).ToNot(BeNil()) - Expect(*rs.ParallelBackendRequests).To(BeTrue()) - Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil()) Expect(*rs.MemoryReclaimerEnabled).To(BeTrue()) @@ -455,7 +451,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { WatchDogBusyTimeout: 12 * time.Minute, SingleBackend: false, MaxActiveBackends: 3, - ParallelBackendRequests: true, MemoryReclaimerEnabled: true, MemoryReclaimerThreshold: 0.92, Threads: 12, @@ -487,7 +482,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout)) Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout)) Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends)) - Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests)) Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled)) Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold)) Expect(target.Threads).To(Equal(original.Threads)) diff --git a/core/config/runtime_settings.go b/core/config/runtime_settings.go index 611759110..6a4117f06 100644 --- a/core/config/runtime_settings.go +++ b/core/config/runtime_settings.go @@ -20,8 +20,6 @@ type RuntimeSettings struct { // Backend management SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode) - ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"` - // 
Memory Reclaimer settings (works with GPU if available, otherwise RAM) MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%) diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go index bbdb33471..0ac828a09 100644 --- a/core/http/endpoints/localai/nodes.go +++ b/core/http/endpoints/localai/nodes.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" @@ -37,7 +38,7 @@ func nodeError(code int, message string) schema.ErrorResponse { func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { return func(c echo.Context) error { ctx := c.Request().Context() - nodeList, err := registry.List(ctx) + nodeList, err := registry.ListWithExtras(ctx) if err != nil { xlog.Error("Failed to list nodes", "error", err) return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes")) @@ -70,7 +71,8 @@ type RegisterNodeRequest struct { AvailableVRAM uint64 `json:"available_vram,omitempty"` TotalRAM uint64 `json:"total_ram,omitempty"` AvailableRAM uint64 `json:"available_ram,omitempty"` - GPUVendor string `json:"gpu_vendor,omitempty"` + GPUVendor string `json:"gpu_vendor,omitempty"` + Labels map[string]string `json:"labels,omitempty"` } // RegisterNodeEndpoint registers a new backend node. @@ -148,6 +150,14 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node")) } + // Store static labels from worker and apply auto-labels + if len(req.Labels) > 0 { + if err := registry.SetNodeLabels(ctx, node.ID, req.Labels); err != nil { + xlog.Warn("Failed to set node labels", "node", node.ID, "error", err) + } + } + registry.ApplyAutoLabels(ctx, node.ID, node) + response := map[string]any{ "id": node.ID, "name": node.Name, @@ -618,6 +628,178 @@ func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken s } } +// GetNodeLabelsEndpoint returns labels for a specific node. +func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + labels, err := registry.GetNodeLabels(ctx, nodeID) + if err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels")) + } + // Convert to map for cleaner API response + result := make(map[string]string) + for _, l := range labels { + result[l.Key] = l.Value + } + return c.JSON(http.StatusOK, result) + } +} + +// SetNodeLabelsEndpoint replaces all labels for a node. 
+func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + if _, err := registry.Get(ctx, nodeID); err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + var labels map[string]string + if err := c.Bind(&labels); err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body")) + } + if err := registry.SetNodeLabels(ctx, nodeID, labels); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels")) + } + return c.JSON(http.StatusOK, labels) + } +} + +// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones. +func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + if _, err := registry.Get(ctx, nodeID); err != nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found")) + } + var labels map[string]string + if err := c.Bind(&labels); err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body")) + } + for k, v := range labels { + if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels")) + } + } + // Return updated labels + updated, _ := registry.GetNodeLabels(ctx, nodeID) + result := make(map[string]string) + for _, l := range updated { + result[l.Key] = l.Value + } + return c.JSON(http.StatusOK, result) + } +} + +// DeleteNodeLabelEndpoint removes a single label from a node. +func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + nodeID := c.Param("id") + key := c.Param("key") + if key == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required")) + } + if err := registry.RemoveNodeLabel(ctx, nodeID, key); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label")) + } + return c.NoContent(http.StatusNoContent) + } +} + +// ListSchedulingEndpoint returns all model scheduling configs. +func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + configs, err := registry.ListModelSchedulings(ctx) + if err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs")) + } + return c.JSON(http.StatusOK, configs) + } +} + +// GetSchedulingEndpoint returns the scheduling config for a specific model. 
+func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + modelName := c.Param("model") + config, err := registry.GetModelScheduling(ctx, modelName) + if err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config")) + } + if config == nil { + return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model")) + } + return c.JSON(http.StatusOK, config) + } +} + +// SetSchedulingRequest is the request body for creating/updating a scheduling config. +type SetSchedulingRequest struct { + ModelName string `json:"model_name"` + NodeSelector map[string]string `json:"node_selector,omitempty"` + MinReplicas int `json:"min_replicas"` + MaxReplicas int `json:"max_replicas"` +} + +// SetSchedulingEndpoint creates or updates a model scheduling config. +func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + var req SetSchedulingRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body")) + } + if req.ModelName == "" { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required")) + } + if req.MinReplicas < 0 { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0")) + } + if req.MaxReplicas < 0 { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0")) + } + if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas")) + } + + // Serialize node selector to JSON + var selectorJSON string + if len(req.NodeSelector) > 0 { + b, err := json.Marshal(req.NodeSelector) + if err != nil { + return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector")) + } + selectorJSON = string(b) + } + + config := &nodes.ModelSchedulingConfig{ + ModelName: req.ModelName, + NodeSelector: selectorJSON, + MinReplicas: req.MinReplicas, + MaxReplicas: req.MaxReplicas, + } + if err := registry.SetModelScheduling(ctx, config); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config")) + } + return c.JSON(http.StatusOK, config) + } +} + +// DeleteSchedulingEndpoint removes a model scheduling config. +func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + modelName := c.Param("model") + if err := registry.DeleteModelScheduling(ctx, modelName); err != nil { + return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config")) + } + return c.NoContent(http.StatusNoContent) + } +} + // proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth. 
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) { reqURL := fmt.Sprintf("http://%s%s", httpAddress, path) diff --git a/core/http/react-ui/src/pages/Nodes.jsx b/core/http/react-ui/src/pages/Nodes.jsx index 24812d6a9..0ec97a7c8 100644 --- a/core/http/react-ui/src/pages/Nodes.jsx +++ b/core/http/react-ui/src/pages/Nodes.jsx @@ -159,6 +159,62 @@ function WorkerHintCard({ addToast, activeTab, hasWorkers }) { ) } +function SchedulingForm({ onSave, onCancel }) { + const [modelName, setModelName] = useState('') + const [selectorText, setSelectorText] = useState('') + const [minReplicas, setMinReplicas] = useState(0) + const [maxReplicas, setMaxReplicas] = useState(0) + + const handleSubmit = () => { + let nodeSelector = null + if (selectorText.trim()) { + const pairs = {} + selectorText.split(',').forEach(p => { + const [k, v] = p.split('=').map(s => s.trim()) + if (k) pairs[k] = v || '' + }) + nodeSelector = pairs + } + onSave({ + model_name: modelName, + node_selector: nodeSelector ? JSON.stringify(nodeSelector) : '', + min_replicas: minReplicas, + max_replicas: maxReplicas, + }) + } + + return ( +
+
+
+ + setModelName(e.target.value)} + placeholder="e.g. llama3" style={{ width: '100%' }} /> +
+
+ + setSelectorText(e.target.value)} + placeholder="e.g. gpu.vendor=nvidia,tier=fast" style={{ width: '100%' }} /> +
+
+ + setMinReplicas(parseInt(e.target.value) || 0)} + style={{ width: '100%' }} /> +
+
+ + setMaxReplicas(parseInt(e.target.value) || 0)} + style={{ width: '100%' }} /> +
+
+
+ + +
+
+ ) +} + export default function Nodes() { const { addToast } = useOutletContext() const navigate = useNavigate() @@ -170,7 +226,9 @@ export default function Nodes() { const [nodeBackends, setNodeBackends] = useState({}) const [confirmDelete, setConfirmDelete] = useState(null) const [showTips, setShowTips] = useState(false) - const [activeTab, setActiveTab] = useState('backend') // 'backend' or 'agent' + const [activeTab, setActiveTab] = useState('backend') // 'backend', 'agent', or 'scheduling' + const [schedulingConfigs, setSchedulingConfigs] = useState([]) + const [showSchedulingForm, setShowSchedulingForm] = useState(false) const fetchNodes = useCallback(async () => { try { @@ -186,11 +244,19 @@ export default function Nodes() { } }, []) + const fetchScheduling = useCallback(async () => { + try { + const data = await nodesApi.listScheduling() + setSchedulingConfigs(Array.isArray(data) ? data : []) + } catch { setSchedulingConfigs([]) } + }, []) + useEffect(() => { fetchNodes() + fetchScheduling() const interval = setInterval(fetchNodes, 5000) return () => clearInterval(interval) - }, [fetchNodes]) + }, [fetchNodes, fetchScheduling]) const fetchModels = useCallback(async (nodeId) => { try { @@ -254,6 +320,36 @@ export default function Nodes() { } } + const handleUnloadModel = async (nodeId, modelName) => { + try { + await nodesApi.unloadModel(nodeId, modelName) + addToast(`Model "${modelName}" unloaded`, 'success') + fetchModels(nodeId) + } catch (err) { + addToast(`Failed to unload model: ${err.message}`, 'error') + } + } + + const handleAddLabel = async (nodeId, key, value) => { + try { + await nodesApi.mergeLabels(nodeId, { [key]: value }) + addToast(`Label "${key}=${value}" added`, 'success') + fetchNodes() + } catch (err) { + addToast(`Failed to add label: ${err.message}`, 'error') + } + } + + const handleDeleteLabel = async (nodeId, key) => { + try { + await nodesApi.deleteLabel(nodeId, key) + addToast(`Label "${key}" removed`, 'success') + fetchNodes() + } catch (err) { + addToast(`Failed to remove label: ${err.message}`, 'error') + } + } + const handleDelete = async (nodeId) => { try { await nodesApi.delete(nodeId) @@ -422,8 +518,23 @@ export default function Nodes() { Agent Workers ({agentNodes.length}) + + {activeTab !== 'scheduling' && <> {/* Stat cards */}
@@ -433,6 +544,23 @@ export default function Nodes() { {pending > 0 && ( )} + {activeTab === 'backend' && (() => { + const clusterTotalVRAM = backendNodes.reduce((sum, n) => sum + (n.total_vram || 0), 0) + const clusterUsedVRAM = backendNodes.reduce((sum, n) => { + if (n.total_vram && n.available_vram != null) return sum + (n.total_vram - n.available_vram) + return sum + }, 0) + const totalModelsLoaded = backendNodes.reduce((sum, n) => sum + (n.model_count || 0), 0) + return ( + <> + {clusterTotalVRAM > 0 && ( + + )} + + + ) + })()}
{/* Worker tips */} @@ -506,6 +634,22 @@ export default function Nodes() {
{node.address}
+ {node.labels && Object.keys(node.labels).length > 0 && ( +
+ {Object.entries(node.labels).slice(0, 5).map(([k, v]) => ( + {k}={v} + ))} + {Object.keys(node.labels).length > 5 && ( + + +{Object.keys(node.labels).length - 5} more + + )} +
+ )} @@ -593,6 +737,7 @@ export default function Nodes() { State In-Flight Logs + Actions @@ -628,6 +773,21 @@ export default function Nodes() { + + + ) })} @@ -689,6 +849,50 @@ export default function Nodes() { )} + + {/* Labels */} +
+

+ + Labels +

+
+ {node.labels && Object.entries(node.labels).map(([k, v]) => ( + + {k}={v} + + + ))} +
+ {/* Add label form */} +
+ + + +
+
@@ -700,6 +904,78 @@ export default function Nodes() { )} + } + + {activeTab === 'scheduling' && ( +
+ + {showSchedulingForm && { + try { + await nodesApi.setScheduling(config) + fetchScheduling() + setShowSchedulingForm(false) + addToast('Scheduling rule saved', 'success') + } catch (err) { + addToast(`Failed to save rule: ${err.message}`, 'error') + } + }} onCancel={() => setShowSchedulingForm(false)} />} + {schedulingConfigs.length === 0 && !showSchedulingForm ? ( +

+ No scheduling rules configured. Add a rule to control how models are placed on nodes. +

+ ) : schedulingConfigs.length > 0 && ( +
+ + + + + + + + + + {schedulingConfigs.map(cfg => ( + + + + + + + + ))} + +
Model | Node Selector | Min Replicas | Max Replicas | Actions
{cfg.model_name} + {cfg.node_selector ? (() => { + try { + const sel = typeof cfg.node_selector === 'string' ? JSON.parse(cfg.node_selector) : cfg.node_selector + return Object.entries(sel).map(([k,v]) => ( + {k}={v} + )) + } catch { return {cfg.node_selector} } + })() : Any node} + {cfg.min_replicas || '-'}{cfg.max_replicas || 'unlimited'} + +
+
+ )} +
+ )} update('max_active_backends', parseInt(e.target.value) || 0)} placeholder="0" /> - - update('parallel_backend_requests', v)} /> - diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index dc25f051d..ee53bc377 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -423,6 +423,13 @@ export const nodesApi = { deleteBackend: (id, backend) => postJSON(API_CONFIG.endpoints.nodeBackendsDelete(id), { backend }), getBackendLogs: (id) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogs(id)), getBackendLogLines: (id, modelId) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogsModel(id, modelId)), + unloadModel: (id, modelName) => postJSON(API_CONFIG.endpoints.nodeModelsUnload(id), { model_name: modelName }), + getLabels: (id) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id)), + mergeLabels: (id, labels) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id), { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(labels) }), + deleteLabel: (id, key) => fetchJSON(API_CONFIG.endpoints.nodeLabelKey(id, key), { method: 'DELETE' }), + listScheduling: () => fetchJSON(API_CONFIG.endpoints.nodesScheduling), + setScheduling: (config) => postJSON(API_CONFIG.endpoints.nodesScheduling, config), + deleteScheduling: (model) => fetchJSON(API_CONFIG.endpoints.nodesSchedulingModel(model), { method: 'DELETE' }), } // File to base64 helper diff --git a/core/http/react-ui/src/utils/config.js b/core/http/react-ui/src/utils/config.js index 76776bfce..06c84d509 100644 --- a/core/http/react-ui/src/utils/config.js +++ b/core/http/react-ui/src/utils/config.js @@ -107,5 +107,10 @@ export const API_CONFIG = { nodeBackendsDelete: (id) => `/api/nodes/${id}/backends/delete`, nodeBackendLogs: (id) => `/api/nodes/${id}/backend-logs`, nodeBackendLogsModel: (id, modelId) => `/api/nodes/${id}/backend-logs/${encodeURIComponent(modelId)}`, + nodeModelsUnload: (id) => `/api/nodes/${id}/models/unload`, + nodeLabels: (id) => `/api/nodes/${id}/labels`, + nodeLabelKey: (id, key) => `/api/nodes/${id}/labels/${key}`, + nodesScheduling: '/api/nodes/scheduling', + nodesSchedulingModel: (model) => `/api/nodes/scheduling/${encodeURIComponent(model)}`, }, } diff --git a/core/http/routes/nodes.go b/core/http/routes/nodes.go index 5ebd87ce6..2f23e3036 100644 --- a/core/http/routes/nodes.go +++ b/core/http/routes/nodes.go @@ -61,6 +61,13 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade admin := e.Group("/api/nodes", readyMw, adminMw) admin.GET("", localai.ListNodesEndpoint(registry)) + + // Model scheduling (registered before /:id to avoid route conflicts) + admin.GET("/scheduling", localai.ListSchedulingEndpoint(registry)) + admin.GET("/scheduling/:model", localai.GetSchedulingEndpoint(registry)) + admin.POST("/scheduling", localai.SetSchedulingEndpoint(registry)) + admin.DELETE("/scheduling/:model", localai.DeleteSchedulingEndpoint(registry)) + admin.GET("/:id", localai.GetNodeEndpoint(registry)) admin.GET("/:id/models", localai.GetNodeModelsEndpoint(registry)) admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry)) @@ -80,6 +87,12 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade admin.GET("/:id/backend-logs", localai.NodeBackendLogsListEndpoint(registry, registrationToken)) admin.GET("/:id/backend-logs/:modelId", localai.NodeBackendLogsLinesEndpoint(registry, registrationToken)) + // Label management + admin.GET("/:id/labels", localai.GetNodeLabelsEndpoint(registry)) + 
admin.PUT("/:id/labels", localai.SetNodeLabelsEndpoint(registry)) + admin.PATCH("/:id/labels", localai.MergeNodeLabelsEndpoint(registry)) + admin.DELETE("/:id/labels/:key", localai.DeleteNodeLabelEndpoint(registry)) + // WebSocket proxy for real-time log streaming from workers e.GET("/ws/nodes/:id/backend-logs/:modelId", localai.NodeBackendLogsWSEndpoint(registry, registrationToken), readyMw, adminMw) } diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go index 3b3cbf137..53d78ccb3 100644 --- a/core/services/nodes/interfaces.go +++ b/core/services/nodes/interfaces.go @@ -21,6 +21,12 @@ type ModelRouter interface { FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (*NodeModel, error) FindLRUModel(ctx context.Context, nodeID string) (*NodeModel, error) Get(ctx context.Context, nodeID string) (*BackendNode, error) + GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) + FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) + FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) + FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) + FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) + GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) } // NodeHealthStore is used by HealthMonitor for node status management. diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go index 5c98fb0de..ce0165e60 100644 --- a/core/services/nodes/model_router_test.go +++ b/core/services/nodes/model_router_test.go @@ -72,6 +72,24 @@ func (f *fakeModelRouterForSmartRouter) Get(_ context.Context, nodeID string) (* } return nil, nil } +func (f *fakeModelRouterForSmartRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) { + return nil, nil +} +func (f *fakeModelRouterForSmartRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) { + return nil, nil +} +func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) { + return nil, nil +} +func (f *fakeModelRouterForSmartRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) { + return nil, nil +} +func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) { + return nil, nil +} +func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) { + return nil, nil +} // Compile-time check var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil) diff --git a/core/services/nodes/reconciler.go b/core/services/nodes/reconciler.go new file mode 100644 index 000000000..92ba76edc --- /dev/null +++ b/core/services/nodes/reconciler.go @@ -0,0 +1,236 @@ +package nodes + +import ( + "context" + "encoding/json" + "time" + + "github.com/mudler/LocalAI/core/services/advisorylock" + "github.com/mudler/xlog" + "gorm.io/gorm" +) + +// ReplicaReconciler periodically ensures model replica counts match their +// scheduling configs. It scales up replicas when below MinReplicas or when +// all replicas are busy (up to MaxReplicas), and scales down idle replicas +// above MinReplicas. +// +// Only processes models with auto-scaling enabled (MinReplicas > 0 or MaxReplicas > 0). 
+type ReplicaReconciler struct { + registry *NodeRegistry + scheduler ModelScheduler // interface for scheduling new models + unloader NodeCommandSender + db *gorm.DB + interval time.Duration + scaleDownDelay time.Duration +} + +// ModelScheduler abstracts the scheduling logic needed by the reconciler. +// SmartRouter implements this interface. +type ModelScheduler interface { + // ScheduleAndLoadModel picks a node (optionally from candidateNodeIDs), + // installs the backend, and loads the model. Returns the node used. + ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) +} + +// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler. +type ReplicaReconcilerOptions struct { + Registry *NodeRegistry + Scheduler ModelScheduler + Unloader NodeCommandSender + DB *gorm.DB + Interval time.Duration // default 30s + ScaleDownDelay time.Duration // default 5m +} + +// NewReplicaReconciler creates a new ReplicaReconciler. +func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler { + interval := opts.Interval + if interval == 0 { + interval = 30 * time.Second + } + scaleDownDelay := opts.ScaleDownDelay + if scaleDownDelay == 0 { + scaleDownDelay = 5 * time.Minute + } + return &ReplicaReconciler{ + registry: opts.Registry, + scheduler: opts.Scheduler, + unloader: opts.Unloader, + db: opts.DB, + interval: interval, + scaleDownDelay: scaleDownDelay, + } +} + +// Run starts the reconciliation loop. It blocks until ctx is cancelled. +func (rc *ReplicaReconciler) Run(ctx context.Context) { + ticker := time.NewTicker(rc.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rc.reconcileOnce(ctx) + } + } +} + +// reconcileOnce performs a single reconciliation pass. +// Uses an advisory lock so only one frontend instance reconciles at a time. +func (rc *ReplicaReconciler) reconcileOnce(ctx context.Context) { + if rc.db != nil { + lockKey := advisorylock.KeyFromString("replica-reconciler") + _ = advisorylock.WithLockCtx(ctx, rc.db, lockKey, func() error { + rc.reconcile(ctx) + return nil + }) + } else { + rc.reconcile(ctx) + } +} + +func (rc *ReplicaReconciler) reconcile(ctx context.Context) { + configs, err := rc.registry.ListAutoScalingConfigs(ctx) + if err != nil { + xlog.Warn("Reconciler: failed to list auto-scaling configs", "error", err) + return + } + + for _, cfg := range configs { + if err := ctx.Err(); err != nil { + return // context cancelled + } + rc.reconcileModel(ctx, cfg) + } +} + +func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) { + current, err := rc.registry.CountLoadedReplicas(ctx, cfg.ModelName) + if err != nil { + xlog.Warn("Reconciler: failed to count replicas", "model", cfg.ModelName, "error", err) + return + } + + // 1. Ensure minimum replicas + if cfg.MinReplicas > 0 && int(current) < cfg.MinReplicas { + needed := cfg.MinReplicas - int(current) + xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName, + "current", current, "min", cfg.MinReplicas, "adding", needed) + rc.scaleUp(ctx, cfg, needed) + return + } + + // 2. Auto-scale up if all replicas are busy + if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) { + if rc.allReplicasBusy(ctx, cfg.ModelName) { + xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName, + "current", current) + rc.scaleUp(ctx, cfg, 1) + } + } + + // 3. 
Scale down idle replicas above minimum + floor := cfg.MinReplicas + if floor < 1 { + floor = 1 + } + if int(current) > floor { + rc.scaleDownIdle(ctx, cfg, int(current), floor) + } +} + +// scaleUp schedules additional replicas of the model. +func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) { + if rc.scheduler == nil { + xlog.Warn("Reconciler: no scheduler available, cannot scale up") + return + } + + // Determine candidate nodes from selector + var candidateNodeIDs []string + if cfg.NodeSelector != "" { + selector := parseSelector(cfg.NodeSelector) + if len(selector) > 0 { + candidates, err := rc.registry.FindNodesBySelector(ctx, selector) + if err != nil || len(candidates) == 0 { + xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName, + "selector", cfg.NodeSelector) + return + } + candidateNodeIDs = make([]string, len(candidates)) + for i, n := range candidates { + candidateNodeIDs[i] = n.ID + } + } + } + + for i := 0; i < count; i++ { + node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs) + if err != nil { + xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName, + "attempt", i+1, "error", err) + return // stop trying on first failure + } + xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name) + } +} + +// scaleDownIdle removes idle replicas above the floor. +func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedulingConfig, current, floor int) { + if rc.unloader == nil { + return + } + + // Find idle replicas that have been unused for longer than scaleDownDelay + cutoff := time.Now().Add(-rc.scaleDownDelay) + var idleModels []NodeModel + rc.registry.db.WithContext(ctx). + Where("model_name = ? AND state = ? AND in_flight = 0 AND last_used < ?", + cfg.ModelName, "loaded", cutoff). + Order("last_used ASC"). + Find(&idleModels) + + toRemove := current - floor + removed := 0 + for _, nm := range idleModels { + if removed >= toRemove { + break + } + // Remove from registry + if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName); err != nil { + xlog.Warn("Reconciler: failed to remove model record", "error", err) + continue + } + // Unload from worker + if err := rc.unloader.UnloadModelOnNode(nm.NodeID, nm.ModelName); err != nil { + xlog.Warn("Reconciler: unload failed (model already removed from registry)", "error", err) + } + xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID) + removed++ + } +} + +// allReplicasBusy returns true if all loaded replicas of a model have in-flight requests. +func (rc *ReplicaReconciler) allReplicasBusy(ctx context.Context, modelName string) bool { + var idleCount int64 + rc.registry.db.WithContext(ctx).Model(&NodeModel{}). + Where("model_name = ? AND state = ? AND in_flight = 0", modelName, "loaded"). + Count(&idleCount) + return idleCount == 0 +} + +// parseSelector decodes a JSON node selector string into a map. 
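+// An empty or malformed selector yields nil (malformed JSON is logged as a warning),
+// which callers treat as "no selector constraint".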
+func parseSelector(selectorJSON string) map[string]string { + if selectorJSON == "" { + return nil + } + var selector map[string]string + if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil { + xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err) + return nil + } + return selector +} diff --git a/core/services/nodes/reconciler_test.go b/core/services/nodes/reconciler_test.go new file mode 100644 index 000000000..e95f8bcea --- /dev/null +++ b/core/services/nodes/reconciler_test.go @@ -0,0 +1,241 @@ +package nodes + +import ( + "context" + "runtime" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/services/testutil" + "gorm.io/gorm" +) + +// --------------------------------------------------------------------------- +// Fake ModelScheduler +// --------------------------------------------------------------------------- + +type fakeScheduler struct { + scheduleNode *BackendNode + scheduleErr error + scheduleCalls []scheduleCall +} + +type scheduleCall struct { + modelName string + candidateIDs []string +} + +func (f *fakeScheduler) ScheduleAndLoadModel(_ context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) { + f.scheduleCalls = append(f.scheduleCalls, scheduleCall{modelName, candidateNodeIDs}) + return f.scheduleNode, f.scheduleErr +} + +var _ = Describe("ReplicaReconciler", func() { + var ( + db *gorm.DB + registry *NodeRegistry + ) + + BeforeEach(func() { + if runtime.GOOS == "darwin" { + Skip("testcontainers requires Docker, not available on macOS CI") + } + db = testutil.SetupTestDB() + var err error + registry, err = NewNodeRegistry(db) + Expect(err).ToNot(HaveOccurred()) + }) + + // Helper to register a healthy node. + registerNode := func(name, address string) *BackendNode { + node := &BackendNode{ + Name: name, + NodeType: NodeTypeBackend, + Address: address, + } + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + return node + } + + // Helper to set up a scheduling config. 
+ setSchedulingConfig := func(modelName string, minReplicas, maxReplicas int, nodeSelector string) { + cfg := &ModelSchedulingConfig{ + ModelName: modelName, + MinReplicas: minReplicas, + MaxReplicas: maxReplicas, + NodeSelector: nodeSelector, + } + Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed()) + } + + Context("model below min_replicas", func() { + It("scales up to min_replicas", func() { + node := registerNode("node-1", "10.0.0.1:50051") + setSchedulingConfig("model-a", 2, 4, "") + + scheduler := &fakeScheduler{ + scheduleNode: node, + } + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + }) + + // No replicas loaded — should schedule 2 + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(2)) + Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-a")) + Expect(scheduler.scheduleCalls[1].modelName).To(Equal("model-a")) + }) + }) + + Context("all replicas busy and below max_replicas", func() { + It("scales up by 1", func() { + node := registerNode("node-busy", "10.0.0.2:50051") + setSchedulingConfig("model-b", 1, 4, "") + + // Load 2 replicas, both busy (in_flight > 0) + Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", "loaded", "addr1", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b")).To(Succeed()) + + node2 := registerNode("node-busy-2", "10.0.0.3:50051") + Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", "loaded", "addr2", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b")).To(Succeed()) + + scheduler := &fakeScheduler{ + scheduleNode: node, + } + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + }) + + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(1)) + Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-b")) + }) + }) + + Context("all replicas busy and at max_replicas", func() { + It("does not scale up", func() { + node := registerNode("node-max", "10.0.0.4:50051") + setSchedulingConfig("model-c", 1, 2, "") + + // Load 2 replicas (at max), both busy + Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", "loaded", "addr1", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c")).To(Succeed()) + + node2 := registerNode("node-max-2", "10.0.0.5:50051") + Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", "loaded", "addr2", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c")).To(Succeed()) + + scheduler := &fakeScheduler{ + scheduleNode: node, + } + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + }) + + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(BeEmpty()) + }) + }) + + Context("idle replicas above min_replicas", func() { + It("scales down after idle delay", func() { + node1 := registerNode("node-idle-1", "10.0.0.6:50051") + node2 := registerNode("node-idle-2", "10.0.0.7:50051") + node3 := registerNode("node-idle-3", "10.0.0.8:50051") + setSchedulingConfig("model-d", 1, 4, "") + + // Load 3 replicas, all idle with last_used in the past + pastTime := time.Now().Add(-10 * time.Minute) + for _, n := range []*BackendNode{node1, node2, node3} { + 
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", "loaded", "", 0)).To(Succeed()) + // Set last_used to past time to trigger scale-down + db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-d"). + Update("last_used", pastTime) + } + + unloader := &fakeUnloader{} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Unloader: unloader, + DB: db, + ScaleDownDelay: 1 * time.Minute, // short delay for test + }) + + reconciler.reconcile(context.Background()) + + // Should scale down 2 replicas (3 - floor of 1) + Expect(unloader.unloadCalls).To(HaveLen(2)) + }) + }) + + Context("idle replicas at min_replicas", func() { + It("does not scale down", func() { + node1 := registerNode("node-keep-1", "10.0.0.9:50051") + node2 := registerNode("node-keep-2", "10.0.0.10:50051") + setSchedulingConfig("model-e", 2, 4, "") + + // Load exactly 2 replicas (at min), both idle with past last_used + pastTime := time.Now().Add(-10 * time.Minute) + for _, n := range []*BackendNode{node1, node2} { + Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", "loaded", "", 0)).To(Succeed()) + db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-e"). + Update("last_used", pastTime) + } + + unloader := &fakeUnloader{} + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Unloader: unloader, + DB: db, + ScaleDownDelay: 1 * time.Minute, + }) + + reconciler.reconcile(context.Background()) + + Expect(unloader.unloadCalls).To(BeEmpty()) + }) + }) + + Context("model with node_selector", func() { + It("passes candidate node IDs to scheduler", func() { + node1 := registerNode("gpu-node", "10.0.0.11:50051") + node2 := registerNode("cpu-node", "10.0.0.12:50051") + + // Add labels — only node1 matches the selector + Expect(registry.SetNodeLabel(context.Background(), node1.ID, "gpu.vendor", "nvidia")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), node2.ID, "gpu.vendor", "none")).To(Succeed()) + + setSchedulingConfig("model-f", 1, 2, `{"gpu.vendor":"nvidia"}`) + + scheduler := &fakeScheduler{ + scheduleNode: node1, + } + reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{ + Registry: registry, + Scheduler: scheduler, + DB: db, + }) + + // No replicas loaded — should schedule 1 with candidate node IDs + reconciler.reconcile(context.Background()) + + Expect(scheduler.scheduleCalls).To(HaveLen(1)) + Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-f")) + Expect(scheduler.scheduleCalls[0].candidateIDs).To(ContainElement(node1.ID)) + Expect(scheduler.scheduleCalls[0].candidateIDs).ToNot(ContainElement(node2.ID)) + }) + }) +}) diff --git a/core/services/nodes/registry.go b/core/services/nodes/registry.go index c7e6ccd3d..5d7404f5f 100644 --- a/core/services/nodes/registry.go +++ b/core/services/nodes/registry.go @@ -68,6 +68,39 @@ type NodeModel struct { UpdatedAt time.Time `json:"updated_at"` } +// NodeLabel is a key-value label on a node (like K8s labels). +type NodeLabel struct { + ID string `gorm:"primaryKey;size:36" json:"id"` + NodeID string `gorm:"uniqueIndex:idx_node_label;size:36" json:"node_id"` + Key string `gorm:"uniqueIndex:idx_node_label;size:128" json:"key"` + Value string `gorm:"size:255" json:"value"` +} + +// ModelSchedulingConfig defines how a model should be scheduled across the cluster. 
+// All fields are optional: +// - NodeSelector only → constrain nodes, single replica +// - MinReplicas/MaxReplicas only → auto-scale on any node +// - Both → auto-scale on matching nodes +// - Neither → no-op (default behavior) +// +// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0. +type ModelSchedulingConfig struct { + ID string `gorm:"primaryKey;size:36" json:"id"` + ModelName string `gorm:"uniqueIndex;size:255" json:"model_name"` + NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...} + MinReplicas int `gorm:"default:0" json:"min_replicas"` + MaxReplicas int `gorm:"default:0" json:"max_replicas"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// NodeWithExtras extends BackendNode with computed fields for list views. +type NodeWithExtras struct { + BackendNode + ModelCount int `json:"model_count"` + Labels map[string]string `json:"labels,omitempty"` +} + // NodeRegistry manages backend node registration and lookup in PostgreSQL. type NodeRegistry struct { db *gorm.DB @@ -78,7 +111,7 @@ type NodeRegistry struct { // when multiple instances (frontend + workers) start at the same time. func NewNodeRegistry(db *gorm.DB) (*NodeRegistry, error) { if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error { - return db.AutoMigrate(&BackendNode{}, &NodeModel{}) + return db.AutoMigrate(&BackendNode{}, &NodeModel{}, &NodeLabel{}, &ModelSchedulingConfig{}) }); err != nil { return nil, fmt.Errorf("migrating node tables: %w", err) } @@ -207,18 +240,19 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (* Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight"). Group("node_id") - // Try idle nodes with enough VRAM first + // Try idle nodes with enough VRAM first, prefer the one with most free VRAM var node BackendNode err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels). + Order("available_vram DESC"). First(&node).Error if err == nil { return &node, nil } - // Fall back to least-loaded nodes with enough VRAM + // Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes). Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery). - Order("COALESCE(load.total_inflight, 0) ASC"). + Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC"). First(&node).Error if err != nil { return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err) @@ -407,9 +441,12 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s var node BackendNode err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Order by in_flight ASC (least busy replica), then by available_vram DESC + // (prefer nodes with more free VRAM to spread load across the cluster). if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). - Where("model_name = ? AND state = ?", modelName, "loaded"). - Order("in_flight ASC"). + Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id"). + Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded"). + Order("node_models.in_flight ASC, backend_nodes.available_vram DESC"). 
First(&nm).Error; err != nil { return err } @@ -463,7 +500,7 @@ func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, e Group("node_id") err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery). - Order("COALESCE(load.total_inflight, 0) ASC"). + Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC"). First(&node).Error if err != nil { return nil, fmt.Errorf("finding least loaded node: %w", err) @@ -482,6 +519,7 @@ func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) { Where("state = ?", "loaded"). Group("node_id") err := db.Where("status = ? AND node_type = ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, loadedModels). + Order("available_vram DESC"). First(&node).Error if err != nil { return nil, err @@ -578,3 +616,287 @@ func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) ( } return &nm, nil } + +// --- NodeLabel operations --- + +// SetNodeLabel upserts a single label on a node. +func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error { + label := NodeLabel{ + ID: uuid.New().String(), + NodeID: nodeID, + Key: key, + Value: value, + } + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "node_id"}, {Name: "key"}}, + DoUpdates: clause.AssignmentColumns([]string{"value"}), + }). + Create(&label).Error +} + +// SetNodeLabels replaces all labels for a node with the given map. +func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil { + return err + } + for k, v := range labels { + label := NodeLabel{ID: uuid.New().String(), NodeID: nodeID, Key: k, Value: v} + if err := tx.Create(&label).Error; err != nil { + return err + } + } + return nil + }) +} + +// RemoveNodeLabel removes a single label from a node. +func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error { + return r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error +} + +// GetNodeLabels returns all labels for a node. +func (r *NodeRegistry) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) { + var labels []NodeLabel + err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Find(&labels).Error + return labels, err +} + +// GetAllNodeLabelsMap returns all labels grouped by node ID. +func (r *NodeRegistry) GetAllNodeLabelsMap(ctx context.Context) (map[string]map[string]string, error) { + var labels []NodeLabel + if err := r.db.WithContext(ctx).Find(&labels).Error; err != nil { + return nil, err + } + result := make(map[string]map[string]string) + for _, l := range labels { + if result[l.NodeID] == nil { + result[l.NodeID] = make(map[string]string) + } + result[l.NodeID][l.Key] = l.Value + } + return result, nil +} + +// --- Selector-based queries --- + +// FindNodesBySelector returns healthy backend nodes matching ALL key-value pairs in the selector. +func (r *NodeRegistry) FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) { + if len(selector) == 0 { + // Empty selector matches all healthy backend nodes + var nodes []BackendNode + err := r.db.WithContext(ctx).Where("status = ? 
AND node_type = ?", StatusHealthy, NodeTypeBackend).Find(&nodes).Error + return nodes, err + } + + db := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend) + for k, v := range selector { + db = db.Where("EXISTS (SELECT 1 FROM node_labels WHERE node_labels.node_id = backend_nodes.id AND node_labels.key = ? AND node_labels.value = ?)", k, v) + } + + var nodes []BackendNode + err := db.Find(&nodes).Error + return nodes, err +} + +// FindNodeWithVRAMFromSet is like FindNodeWithVRAM but restricted to the given node IDs. +func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) { + db := r.db.WithContext(ctx) + + loadedModels := db.Model(&NodeModel{}). + Select("node_id"). + Where("state = ?", "loaded"). + Group("node_id") + + subquery := db.Model(&NodeModel{}). + Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight"). + Group("node_id") + + // Try idle nodes with enough VRAM first, prefer the one with most free VRAM + var node BackendNode + err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs). + Order("available_vram DESC"). + First(&node).Error + if err == nil { + return &node, nil + } + + // Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker + err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs). + Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery). + Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC"). + First(&node).Error + if err != nil { + return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err) + } + return &node, nil +} + +// FindIdleNodeFromSet is like FindIdleNode but restricted to the given node IDs. +func (r *NodeRegistry) FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) { + db := r.db.WithContext(ctx) + + var node BackendNode + loadedModels := db.Model(&NodeModel{}). + Select("node_id"). + Where("state = ?", "loaded"). + Group("node_id") + err := db.Where("status = ? AND node_type = ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, loadedModels, nodeIDs). + Order("available_vram DESC"). + First(&node).Error + if err != nil { + return nil, err + } + return &node, nil +} + +// FindLeastLoadedNodeFromSet is like FindLeastLoadedNode but restricted to the given node IDs. +func (r *NodeRegistry) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) { + db := r.db.WithContext(ctx) + + var node BackendNode + query := db.Where("status = ? AND node_type = ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, nodeIDs) + // Order by total in-flight across all models on the node + subquery := db.Model(&NodeModel{}). + Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight"). + Group("node_id") + + err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery). + Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC"). + First(&node).Error + if err != nil { + return nil, fmt.Errorf("finding least loaded node in set: %w", err) + } + return &node, nil +} + +// --- ModelSchedulingConfig operations --- + +// SetModelScheduling creates or updates a scheduling config for a model. 
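+// The upsert is keyed on model_name, so calling it again for the same model updates the existing row in place.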
+func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSchedulingConfig) error { + if config.ID == "" { + config.ID = uuid.New().String() + } + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "model_name"}}, + DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}), + }). + Create(config).Error +} + +// GetModelScheduling returns the scheduling config for a model, or nil if none exists. +func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) { + var config ModelSchedulingConfig + err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&config).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + if err != nil { + return nil, err + } + return &config, nil +} + +// ListModelSchedulings returns all scheduling configs. +func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedulingConfig, error) { + var configs []ModelSchedulingConfig + err := r.db.WithContext(ctx).Find(&configs).Error + return configs, err +} + +// ListAutoScalingConfigs returns scheduling configs where auto-scaling is enabled. +func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) { + var configs []ModelSchedulingConfig + err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0").Find(&configs).Error + return configs, err +} + +// DeleteModelScheduling removes a scheduling config by model name. +func (r *NodeRegistry) DeleteModelScheduling(ctx context.Context, modelName string) error { + return r.db.WithContext(ctx).Where("model_name = ?", modelName).Delete(&ModelSchedulingConfig{}).Error +} + +// CountLoadedReplicas returns the number of loaded replicas for a model. +func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string) (int64, error) { + var count int64 + err := r.db.WithContext(ctx).Model(&NodeModel{}).Where("model_name = ? AND state = ?", modelName, "loaded").Count(&count).Error + return count, err +} + +// --- Composite queries --- + +// ListWithExtras returns all nodes with model counts and labels. +func (r *NodeRegistry) ListWithExtras(ctx context.Context) ([]NodeWithExtras, error) { + // Get all nodes + var nodes []BackendNode + if err := r.db.WithContext(ctx).Find(&nodes).Error; err != nil { + return nil, err + } + + // Get model counts per node + type modelCount struct { + NodeID string + Count int + } + var counts []modelCount + if err := r.db.WithContext(ctx).Model(&NodeModel{}). + Select("node_id, COUNT(*) as count"). + Where("state = ?", "loaded"). + Group("node_id"). + Find(&counts).Error; err != nil { + xlog.Warn("ListWithExtras: failed to get model counts", "error", err) + } + + countMap := make(map[string]int) + for _, c := range counts { + countMap[c.NodeID] = c.Count + } + + // Get all labels + labelsMap, err := r.GetAllNodeLabelsMap(ctx) + if err != nil { + xlog.Warn("ListWithExtras: failed to get labels", "error", err) + } + + // Build result + result := make([]NodeWithExtras, len(nodes)) + for i, n := range nodes { + result[i] = NodeWithExtras{ + BackendNode: n, + ModelCount: countMap[n.ID], + Labels: labelsMap[n.ID], + } + } + return result, nil +} + +// ApplyAutoLabels sets automatic labels based on node hardware info. 
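+// It currently sets gpu.vendor, a coarse gpu.vram size bucket, and node.name; label write
+// failures are intentionally ignored (best-effort).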
+func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node *BackendNode) { + if node.GPUVendor != "" { + _ = r.SetNodeLabel(ctx, nodeID, "gpu.vendor", node.GPUVendor) + } + if node.TotalVRAM > 0 { + gb := node.TotalVRAM / (1024 * 1024 * 1024) + var bucket string + switch { + case gb >= 80: + bucket = "80GB+" + case gb >= 48: + bucket = "48GB" + case gb >= 24: + bucket = "24GB" + case gb >= 16: + bucket = "16GB" + case gb >= 8: + bucket = "8GB" + default: + bucket = fmt.Sprintf("%dGB", gb) + } + _ = r.SetNodeLabel(ctx, nodeID, "gpu.vram", bucket) + } + if node.Name != "" { + _ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name) + } +} diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go index a7dbd6a54..7b0a747b1 100644 --- a/core/services/nodes/registry_test.go +++ b/core/services/nodes/registry_test.go @@ -305,6 +305,252 @@ var _ = Describe("NodeRegistry", func() { }) }) + Describe("NodeLabel CRUD", func() { + It("sets and retrieves labels for a node", func() { + node := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + + Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed()) + + labels, err := registry.GetNodeLabels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(labels).To(HaveLen(2)) + + labelMap := make(map[string]string) + for _, l := range labels { + labelMap[l.Key] = l.Value + } + Expect(labelMap["env"]).To(Equal("prod")) + Expect(labelMap["region"]).To(Equal("us-east")) + }) + + It("overwrites existing label with same key", func() { + node := makeNode("label-overwrite", "10.0.0.71:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + + Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "dev")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) + + labels, err := registry.GetNodeLabels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(labels).To(HaveLen(1)) + Expect(labels[0].Value).To(Equal("prod")) + }) + + It("removes a single label by key", func() { + node := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + + Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed()) + + Expect(registry.RemoveNodeLabel(context.Background(), node.ID, "env")).To(Succeed()) + + labels, err := registry.GetNodeLabels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(labels).To(HaveLen(1)) + Expect(labels[0].Key).To(Equal("region")) + }) + + It("SetNodeLabels replaces all labels", func() { + node := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), node, true)).To(Succeed()) + + Expect(registry.SetNodeLabel(context.Background(), node.ID, "old-key", "old-val")).To(Succeed()) + + newLabels := map[string]string{"new-a": "val-a", "new-b": "val-b"} + Expect(registry.SetNodeLabels(context.Background(), node.ID, newLabels)).To(Succeed()) + + labels, err := registry.GetNodeLabels(context.Background(), node.ID) + Expect(err).ToNot(HaveOccurred()) + 
Expect(labels).To(HaveLen(2)) + + labelMap := make(map[string]string) + for _, l := range labels { + labelMap[l.Key] = l.Value + } + Expect(labelMap).To(Equal(newLabels)) + }) + }) + + Describe("FindNodesBySelector", func() { + It("returns nodes matching all labels in selector", func() { + n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000) + n2 := makeNode("sel-nomatch", "10.0.0.81:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) + + Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed()) + + nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"}) + Expect(err).ToNot(HaveOccurred()) + Expect(nodes).To(HaveLen(1)) + Expect(nodes[0].Name).To(Equal("sel-match")) + }) + + It("returns empty when no nodes match", func() { + n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n, true)).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed()) + + nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) + Expect(err).ToNot(HaveOccurred()) + Expect(nodes).To(BeEmpty()) + }) + + It("ignores unhealthy nodes", func() { + n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n, true)).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed()) + Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed()) + + nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) + Expect(err).ToNot(HaveOccurred()) + Expect(nodes).To(BeEmpty()) + }) + + It("matches nodes with more labels than selector requires", func() { + n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n, true)).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed()) + Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed()) + + nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) + Expect(err).ToNot(HaveOccurred()) + Expect(nodes).To(HaveLen(1)) + Expect(nodes[0].Name).To(Equal("sel-superset")) + }) + + It("returns all healthy nodes for empty selector", func() { + n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000) + n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) + + nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(nodes)).To(BeNumerically(">=", 2)) + }) + }) + + Describe("ModelSchedulingConfig CRUD", func() { + It("creates and retrieves a scheduling config", func() { + config := &ModelSchedulingConfig{ + ModelName: "llama-7b", + NodeSelector: `{"gpu.vendor":"nvidia"}`, + MinReplicas: 1, + MaxReplicas: 3, + } + 
Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed()) + Expect(config.ID).ToNot(BeEmpty()) + + fetched, err := registry.GetModelScheduling(context.Background(), "llama-7b") + Expect(err).ToNot(HaveOccurred()) + Expect(fetched).ToNot(BeNil()) + Expect(fetched.ModelName).To(Equal("llama-7b")) + Expect(fetched.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`)) + Expect(fetched.MinReplicas).To(Equal(1)) + Expect(fetched.MaxReplicas).To(Equal(3)) + }) + + It("updates existing config via SetModelScheduling", func() { + config := &ModelSchedulingConfig{ + ModelName: "update-model", + MinReplicas: 1, + MaxReplicas: 2, + } + Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed()) + + config2 := &ModelSchedulingConfig{ + ModelName: "update-model", + MinReplicas: 2, + MaxReplicas: 5, + } + Expect(registry.SetModelScheduling(context.Background(), config2)).To(Succeed()) + + fetched, err := registry.GetModelScheduling(context.Background(), "update-model") + Expect(err).ToNot(HaveOccurred()) + Expect(fetched.MinReplicas).To(Equal(2)) + Expect(fetched.MaxReplicas).To(Equal(5)) + }) + + It("lists all configs", func() { + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed()) + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed()) + + configs, err := registry.ListModelSchedulings(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(configs)).To(BeNumerically(">=", 2)) + }) + + It("lists only auto-scaling configs", func() { + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed()) + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed()) + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed()) + + configs, err := registry.ListAutoScalingConfigs(context.Background()) + Expect(err).ToNot(HaveOccurred()) + + names := make([]string, len(configs)) + for i, c := range configs { + names[i] = c.ModelName + } + Expect(names).To(ContainElement("auto-a")) + Expect(names).To(ContainElement("auto-b")) + Expect(names).ToNot(ContainElement("no-auto")) + }) + + It("deletes a config", func() { + Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed()) + + Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed()) + + fetched, err := registry.GetModelScheduling(context.Background(), "delete-me") + Expect(err).ToNot(HaveOccurred()) + Expect(fetched).To(BeNil()) + }) + + It("returns nil for non-existent model", func() { + fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist") + Expect(err).ToNot(HaveOccurred()) + Expect(fetched).To(BeNil()) + }) + }) + + Describe("CountLoadedReplicas", func() { + It("returns correct count of loaded replicas", func() { + n1 := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000) + n2 := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) + + Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", "loaded", "", 
0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", "loaded", "", 0)).To(Succeed()) + + count, err := registry.CountLoadedReplicas(context.Background(), "counted-model") + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(2))) + }) + + It("excludes non-loaded states", func() { + n1 := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000) + n2 := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) + Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) + + Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", "loaded", "", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", "loading", "", 0)).To(Succeed()) + + count, err := registry.CountLoadedReplicas(context.Background(), "state-model") + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(1))) + }) + }) + Describe("DecrementInFlight", func() { It("does not go below zero", func() { node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index c61a11d91..7d77b62bc 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -2,6 +2,7 @@ package nodes import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -69,6 +70,18 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter // Unloader returns the remote unloader adapter for external use. func (r *SmartRouter) Unloader() NodeCommandSender { return r.unloader } +// ScheduleAndLoadModel implements ModelScheduler for the reconciler. +// It schedules a model on a suitable node (optionally from candidates) and loads it. +func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) { + // Use scheduleNewModel with empty backend type and nil model options. + // The reconciler doesn't know the backend type — it will be determined by the model config. + node, _, err := r.scheduleNewModel(ctx, "", modelName, nil) + if err != nil { + return nil, err + } + return node, nil +} + // RouteResult contains the routing decision. type RouteResult struct { Node *BackendNode @@ -110,18 +123,26 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType xlog.Warn("Backend not reachable for cached model, falling through to reload", "node", node.Name, "model", modelName) } else { - // Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight, - // and Release decrements it. No InFlightTrackingClient to avoid double-counting. - r.registry.TouchNodeModel(ctx, node.ID, trackingKey) - grpcClient := r.buildClientForAddr(node, modelAddr, parallel) - return &RouteResult{ - Node: node, - Client: grpcClient, - Release: func() { - r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) - closeClient(grpcClient) - }, - }, nil + // Verify node still matches scheduling constraints + if !r.nodeMatchesScheduling(ctx, node, trackingKey) { + r.registry.DecrementInFlight(ctx, node.ID, trackingKey) + xlog.Info("Cached model on node that no longer matches selector, falling through", + "node", node.Name, "model", trackingKey) + // Fall through to step 2 (scheduleNewModel) + } else { + // Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight, + // and Release decrements it. 
No InFlightTrackingClient to avoid double-counting. + r.registry.TouchNodeModel(ctx, node.ID, trackingKey) + grpcClient := r.buildClientForAddr(node, modelAddr, parallel) + return &RouteResult{ + Node: node, + Client: grpcClient, + Release: func() { + r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) + closeClient(grpcClient) + }, + }, nil + } } } @@ -143,17 +164,25 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType xlog.Warn("Backend not reachable for cached model inside lock, proceeding to load", "node", node.Name, "model", modelName) } else { - // Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting - r.registry.TouchNodeModel(ctx, node.ID, trackingKey) - grpcClient := r.buildClientForAddr(node, modelAddr, parallel) - return &RouteResult{ - Node: node, - Client: grpcClient, - Release: func() { - r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) - closeClient(grpcClient) - }, - }, nil + // Verify node still matches scheduling constraints + if !r.nodeMatchesScheduling(ctx, node, trackingKey) { + r.registry.DecrementInFlight(ctx, node.ID, trackingKey) + xlog.Info("Cached model on node that no longer matches selector, falling through", + "node", node.Name, "model", trackingKey) + // Fall through to scheduling below + } else { + // Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting + r.registry.TouchNodeModel(ctx, node.ID, trackingKey) + grpcClient := r.buildClientForAddr(node, modelAddr, parallel) + return &RouteResult{ + Node: node, + Client: grpcClient, + Release: func() { + r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey) + closeClient(grpcClient) + }, + }, nil + } } } @@ -223,6 +252,59 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType return loadModel() } +// parseSelectorJSON decodes a JSON node selector string into a map. +func parseSelectorJSON(selectorJSON string) map[string]string { + if selectorJSON == "" { + return nil + } + var selector map[string]string + if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil { + xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err) + return nil + } + return selector +} + +func extractNodeIDs(nodes []BackendNode) []string { + ids := make([]string, len(nodes)) + for i, n := range nodes { + ids[i] = n.ID + } + return ids +} + +// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model. +// Returns true if no constraints exist or the node matches all selector labels. +func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool { + sched, err := r.registry.GetModelScheduling(ctx, modelName) + if err != nil || sched == nil || sched.NodeSelector == "" { + return true // no constraints + } + + selector := parseSelectorJSON(sched.NodeSelector) + if len(selector) == 0 { + return true + } + + labels, err := r.registry.GetNodeLabels(ctx, node.ID) + if err != nil { + xlog.Warn("Failed to get node labels for selector check", "node", node.ID, "error", err) + return true // fail open + } + + labelMap := make(map[string]string) + for _, l := range labels { + labelMap[l.Key] = l.Value + } + + for k, v := range selector { + if labelMap[k] != v { + return false + } + } + return true +} + // scheduleNewModel picks the best node for loading a new model. // Strategy: VRAM-aware → idle-first → least-loaded. 
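+// When the model has a scheduling config with a node selector, candidates are restricted to
+// matching nodes via the *FromSet registry variants.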
// Sends backend.install via NATS so the chosen node has the right backend running. @@ -233,12 +315,30 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID estimatedVRAM = r.estimateModelVRAM(ctx, modelOpts) } + // Check for scheduling constraints (node selector) + sched, _ := r.registry.GetModelScheduling(ctx, modelID) + var candidateNodeIDs []string // nil = all nodes eligible + + if sched != nil && sched.NodeSelector != "" { + selector := parseSelectorJSON(sched.NodeSelector) + if len(selector) > 0 { + candidates, err := r.registry.FindNodesBySelector(ctx, selector) + if err != nil || len(candidates) == 0 { + return nil, "", fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector) + } + candidateNodeIDs = extractNodeIDs(candidates) + } + } + var node *BackendNode var err error if estimatedVRAM > 0 { - // 1. Prefer nodes with enough VRAM (idle-first, then least-loaded) - node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM) + if candidateNodeIDs != nil { + node, err = r.registry.FindNodeWithVRAMFromSet(ctx, estimatedVRAM, candidateNodeIDs) + } else { + node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM) + } if err != nil { xlog.Warn("No nodes with enough VRAM, falling back to standard scheduling", "required_vram", vram.FormatBytes(estimatedVRAM), "error", err) @@ -246,11 +346,16 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID } if node == nil { - // 2. Prefer truly idle nodes (no loaded models, no in-flight) - node, err = r.registry.FindIdleNode(ctx) - if err != nil { - // 3. Fall back to least-loaded node (can run an additional backend process) - node, err = r.registry.FindLeastLoadedNode(ctx) + if candidateNodeIDs != nil { + node, err = r.registry.FindIdleNodeFromSet(ctx, candidateNodeIDs) + if err != nil { + node, err = r.registry.FindLeastLoadedNodeFromSet(ctx, candidateNodeIDs) + } + } else { + node, err = r.registry.FindIdleNode(ctx) + if err != nil { + node, err = r.registry.FindLeastLoadedNode(ctx) + } } } @@ -610,7 +715,12 @@ func (r *SmartRouter) evictLRUAndFreeNode(ctx context.Context) (*BackendNode, er // Lock the row so no other frontend can evict the same model if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id"). - Where("node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?", "loaded", StatusHealthy). + Where(`node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ? + AND ( + NOT EXISTS (SELECT 1 FROM model_scheduling_configs sc WHERE sc.model_name = node_models.model_name AND (sc.min_replicas > 0 OR sc.max_replicas > 0)) + OR (SELECT COUNT(*) FROM node_models nm2 WHERE nm2.model_name = node_models.model_name AND nm2.state = 'loaded') + > COALESCE((SELECT sc2.min_replicas FROM model_scheduling_configs sc2 WHERE sc2.model_name = node_models.model_name), 1) + )`, "loaded", StatusHealthy). Order("node_models.last_used ASC"). 
First(&lru).Error; err != nil { return err diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go index 9f807b34b..8c53531fe 100644 --- a/core/services/nodes/router_test.go +++ b/core/services/nodes/router_test.go @@ -86,6 +86,26 @@ type fakeModelRouter struct { getNode *BackendNode getErr error + // GetModelScheduling returns + getModelScheduling *ModelSchedulingConfig + getModelSchedErr error + + // FindNodesBySelector returns + findBySelectorNodes []BackendNode + findBySelectorErr error + + // *FromSet variants + findVRAMFromSetNode *BackendNode + findVRAMFromSetErr error + findIdleFromSetNode *BackendNode + findIdleFromSetErr error + findLeastLoadedFromSetNode *BackendNode + findLeastLoadedFromSetErr error + + // GetNodeLabels returns + getNodeLabels []NodeLabel + getNodeLabelsErr error + // Track calls for assertions decrementCalls []string // "nodeID:modelName" incrementCalls []string @@ -146,6 +166,30 @@ func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error) return f.getNode, f.getErr } +func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) { + return f.getModelScheduling, f.getModelSchedErr +} + +func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) { + return f.findBySelectorNodes, f.findBySelectorErr +} + +func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) { + return f.findVRAMFromSetNode, f.findVRAMFromSetErr +} + +func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) { + return f.findIdleFromSetNode, f.findIdleFromSetErr +} + +func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) { + return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr +} + +func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) { + return f.getNodeLabels, f.getNodeLabelsErr +} + // --------------------------------------------------------------------------- // Fake BackendClientFactory + Backend // --------------------------------------------------------------------------- @@ -478,6 +522,135 @@ var _ = Describe("SmartRouter", func() { }) }) + Describe("scheduleNewModel with node selector (mock-based, via Route)", func() { + var ( + reg *fakeModelRouter + backend *stubBackend + factory *stubClientFactory + unloader *fakeUnloader + ) + + BeforeEach(func() { + reg = &fakeModelRouter{ + findAndLockErr: errors.New("not found"), + } + backend = &stubBackend{ + loadResult: &pb.Result{Success: true}, + } + factory = &stubClientFactory{client: backend} + unloader = &fakeUnloader{ + installReply: &messaging.BackendInstallReply{ + Success: true, + Address: "10.0.0.1:9001", + }, + } + }) + + It("uses *FromSet methods when model has a node selector", func() { + gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"} + reg.getModelScheduling = &ModelSchedulingConfig{ + ModelName: "selector-model", + NodeSelector: `{"gpu.vendor":"nvidia"}`, + } + reg.findBySelectorNodes = []BackendNode{*gpuNode} + reg.findIdleFromSetNode = gpuNode + + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + }) + + result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(result).ToNot(BeNil()) + 
Expect(result.Node.ID).To(Equal("gpu-1")) + }) + + It("returns error when no nodes match selector", func() { + reg.getModelScheduling = &ModelSchedulingConfig{ + ModelName: "no-match-model", + NodeSelector: `{"gpu.vendor":"tpu"}`, + } + reg.findBySelectorNodes = nil + reg.findBySelectorErr = nil + + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + }) + + _, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector")) + }) + + It("uses regular methods when model has no scheduling config", func() { + reg.getModelScheduling = nil + idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"} + reg.findIdleNode = idleNode + + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + }) + + result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(result).ToNot(BeNil()) + Expect(result.Node.ID).To(Equal("regular-1")) + }) + }) + + Describe("Route with selector validation on cached model (mock-based)", func() { + It("falls through when cached node no longer matches selector", func() { + cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"} + newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"} + + backend := &stubBackend{ + healthResult: true, + loadResult: &pb.Result{Success: true}, + } + factory := &stubClientFactory{client: backend} + unloader := &fakeUnloader{ + installReply: &messaging.BackendInstallReply{ + Success: true, + Address: "10.0.0.71:9001", + }, + } + + reg := &fakeModelRouter{ + // Step 1: cached model found on old node + findAndLockNode: cachedNode, + findAndLockNM: &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"}, + // Scheduling config with selector that old node does NOT match + getModelScheduling: &ModelSchedulingConfig{ + ModelName: "sel-model", + NodeSelector: `{"gpu.vendor":"nvidia"}`, + }, + // Old node has no labels matching the selector + getNodeLabels: []NodeLabel{ + {NodeID: "n-old", Key: "gpu.vendor", Value: "amd"}, + }, + // For scheduling fallthrough: selector matches new node + findBySelectorNodes: []BackendNode{*newNode}, + findIdleFromSetNode: newNode, + } + + router := NewSmartRouter(reg, SmartRouterOptions{ + Unloader: unloader, + ClientFactory: factory, + }) + + result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false) + Expect(err).ToNot(HaveOccurred()) + Expect(result).ToNot(BeNil()) + // Should have fallen through to the new node + Expect(result.Node.ID).To(Equal("n-new")) + // Old node should have had its in-flight decremented + Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model")) + }) + }) + // ----------------------------------------------------------------------- // Integration tests using real PostgreSQL (existing) // ----------------------------------------------------------------------- diff --git a/docs/content/features/distributed-mode.md b/docs/content/features/distributed-mode.md index 0433de3ee..b942c244b 100644 --- a/docs/content/features/distributed-mode.md +++ b/docs/content/features/distributed-mode.md @@ -134,6 +134,47 @@ local-ai worker \ **HTTP file transfer:** Each worker also runs a small HTTP server for 
file transfer (model files, configs). By default it listens on the gRPC base port - 1 (e.g., if gRPC base is 50051, HTTP is on 50050). gRPC ports grow upward from the base port as additional models are loaded. Set `--advertise-http-addr` if the auto-detected address is not routable from the frontend. {{% /notice %}} +### Worker Address Configuration + +The simplest way to configure a worker's network address is with a single variable: + +| Variable | Description | +|----------|-------------| +| `LOCALAI_ADDR` | Reachable address of this worker (`host:port`). The port is used as the base for gRPC backend processes, and `port-1` for the HTTP file transfer server. | + +**Example:** +```yaml +environment: + LOCALAI_ADDR: "192.168.1.100:50051" + LOCALAI_NATS_URL: "nats://frontend:4222" + LOCALAI_REGISTER_TO: "http://frontend:8080" + LOCALAI_REGISTRATION_TOKEN: "my-secret" +``` + +For advanced networking scenarios (NAT, load balancers, separate gRPC/HTTP ports), the following override variables are available: + +| Variable | Description | Default | +|----------|-------------|---------| +| `LOCALAI_SERVE_ADDR` | gRPC base port bind address | `0.0.0.0:50051` | +| `LOCALAI_HTTP_ADDR` | HTTP file transfer bind address | `0.0.0.0:{gRPC port - 1}` | +| `LOCALAI_ADVERTISE_ADDR` | Public gRPC address (if different from `LOCALAI_ADDR`) | Derived from `LOCALAI_ADDR` | +| `LOCALAI_ADVERTISE_HTTP_ADDR` | Public HTTP address (if different from gRPC host) | Derived from advertise host + HTTP port | + +### Node Labels + +Workers can declare labels at startup for scheduling constraints: + +| Variable | Description | Example | +|----------|-------------|---------| +| `LOCALAI_NODE_LABELS` | Comma-separated `key=value` labels | `tier=premium,gpu=a100,zone=us-east` | + +Labels can also be managed via the admin API (see [Label Management API](#label-management-api) below). + +The system automatically applies hardware-detected labels on registration: +- `gpu.vendor` -- GPU vendor (nvidia, amd, intel, vulkan) +- `gpu.vram` -- GPU VRAM bucket (8GB, 16GB, 24GB, 48GB, 80GB+) +- `node.name` -- The node's registered name + ### How Workers Operate Workers start as generic processes with no backend installed. When the SmartRouter needs to load a model on a worker, it sends a NATS `backend.install` event with the backend name and model ID. The worker: @@ -262,6 +303,75 @@ local-ai worker \ **Multiple frontend replicas:** Run multiple LocalAI frontends behind a load balancer. Since all state is in PostgreSQL and coordination is via NATS, frontends are fully stateless and interchangeable. +## Model Scheduling + +Model scheduling controls where models are placed and how many replicas are maintained. It combines two optional features: + +### Node Selectors + +Pin models to nodes with specific labels. Only nodes matching **all** selector labels are eligible: + +```bash +# Only schedule on NVIDIA nodes in the us-east zone +curl -X POST http://frontend:8080/api/nodes/scheduling \ + -H "Content-Type: application/json" \ + -d '{"model_name": "llama3", "node_selector": {"gpu.vendor": "nvidia", "zone": "us-east"}}' +``` + +Without a node selector, models can schedule on any healthy node (default behavior). 
+ +### Replica Auto-Scaling + +Control the number of model replicas across the cluster: + +| Field | Description | +|-------|-------------| +| `min_replicas` | Minimum replicas to maintain (0 = no minimum, single replica default) | +| `max_replicas` | Maximum replicas allowed (0 = unlimited) | + +Auto-scaling is **only active** when `min_replicas > 0` or `max_replicas > 0`. + +```bash +# Scale llama3 between 2 and 4 replicas on NVIDIA nodes +curl -X POST http://frontend:8080/api/nodes/scheduling \ + -H "Content-Type: application/json" \ + -d '{ + "model_name": "llama3", + "node_selector": {"gpu.vendor": "nvidia"}, + "min_replicas": 2, + "max_replicas": 4 + }' +``` + +The **Replica Reconciler** runs as a background process on the frontend: +- **Scale up**: Adds replicas when all existing replicas are busy (have in-flight requests) +- **Scale down**: Removes idle replicas after 5 minutes of inactivity +- **Maintain minimum**: Ensures `min_replicas` are always loaded (recovers from node failures) +- **Eviction protection**: Models with auto-scaling enabled are never evicted below `min_replicas` + +All fields are optional and composable: +- Node selector only: pin model to matching nodes, single replica +- Replicas only: auto-scale across all nodes +- Both: auto-scale on matching nodes only + +## Label Management API + +| Method | Path | Description | +|--------|------|-------------| +| `GET` | `/api/nodes/:id/labels` | Get labels for a node | +| `PUT` | `/api/nodes/:id/labels` | Replace all labels (JSON object) | +| `PATCH` | `/api/nodes/:id/labels` | Merge labels (add/update) | +| `DELETE` | `/api/nodes/:id/labels/:key` | Remove a single label | + +## Scheduling API + +| Method | Path | Description | +|--------|------|-------------| +| `GET` | `/api/nodes/scheduling` | List all scheduling configs | +| `GET` | `/api/nodes/scheduling/:model` | Get config for a model | +| `POST` | `/api/nodes/scheduling` | Create/update config | +| `DELETE` | `/api/nodes/scheduling/:model` | Remove config | + ## Comparison with P2P | | P2P / Federation | Distributed Mode | diff --git a/docs/content/features/runtime-settings.md b/docs/content/features/runtime-settings.md index 79d4ca7c5..1110f1cf7 100644 --- a/docs/content/features/runtime-settings.md +++ b/docs/content/features/runtime-settings.md @@ -28,7 +28,6 @@ Changes to watchdog settings are applied immediately by restarting the watchdog ### Backend Configuration - **Max Active Backends**: Maximum number of active backends (loaded models). When exceeded, the least recently used model is automatically evicted. Set to `0` for unlimited, `1` for single-backend mode -- **Parallel Backend Requests**: Enable backends to handle multiple requests in parallel if supported - **Force Eviction When Busy**: Allow evicting models even when they have active API calls (default: disabled for safety). 
**Warning:** Enabling this can interrupt active requests - **LRU Eviction Max Retries**: Maximum number of retries when waiting for busy models to become idle before eviction (default: 30) - **LRU Eviction Retry Interval**: Interval between retries when waiting for busy models (default: `1s`) @@ -123,7 +122,6 @@ The `runtime_settings.json` file follows this structure: "watchdog_idle_timeout": "15m", "watchdog_busy_timeout": "5m", "max_active_backends": 0, - "parallel_backend_requests": true, "force_eviction_when_busy": false, "lru_eviction_max_retries": 30, "lru_eviction_retry_interval": "1s", diff --git a/docs/content/reference/cli-reference.md b/docs/content/reference/cli-reference.md index 9743c8682..556cf5a99 100644 --- a/docs/content/reference/cli-reference.md +++ b/docs/content/reference/cli-reference.md @@ -39,7 +39,6 @@ Complete reference for all LocalAI command-line interface (CLI) parameters and e | `--external-grpc-backends` | | A list of external gRPC backends (format: `BACKEND_NAME:URI`) | `$LOCALAI_EXTERNAL_GRPC_BACKENDS`, `$EXTERNAL_GRPC_BACKENDS` | | `--backend-galleries` | | JSON list of backend galleries | `$LOCALAI_BACKEND_GALLERIES`, `$BACKEND_GALLERIES` | | `--autoload-backend-galleries` | `true` | Automatically load backend galleries on startup | `$LOCALAI_AUTOLOAD_BACKEND_GALLERIES`, `$AUTOLOAD_BACKEND_GALLERIES` | -| `--parallel-requests` | `false` | Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm) | `$LOCALAI_PARALLEL_REQUESTS`, `$PARALLEL_REQUESTS` | | `--max-active-backends` | `0` | Maximum number of active backends (loaded models). When exceeded, the least recently used model is evicted. Set to `0` for unlimited, `1` for single-backend mode | `$LOCALAI_MAX_ACTIVE_BACKENDS`, `$MAX_ACTIVE_BACKENDS` | | `--single-active-backend` | `false` | **DEPRECATED** - Use `--max-active-backends=1` instead. Allow only one backend to be run at a time | `$LOCALAI_SINGLE_ACTIVE_BACKEND`, `$SINGLE_ACTIVE_BACKEND` | | `--preload-backend-only` | `false` | Do not launch the API services, only the preloaded models/backends are started (useful for multi-node setups) | `$LOCALAI_PRELOAD_BACKEND_ONLY`, `$PRELOAD_BACKEND_ONLY` |