diff --git a/core/application/config_file_watcher.go b/core/application/config_file_watcher.go
index 094eebbbe..8eb26355d 100644
--- a/core/application/config_file_watcher.go
+++ b/core/application/config_file_watcher.go
@@ -199,7 +199,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
- envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold
envThreads := appConfig.Threads == startupAppConfig.Threads
@@ -271,9 +270,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
appConfig.MaxActiveBackends = 0
}
}
- if settings.ParallelBackendRequests != nil && !envParallelRequests {
- appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
- }
if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled {
appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
if appConfig.MemoryReclaimerEnabled {
diff --git a/core/application/distributed.go b/core/application/distributed.go
index 257fecdd4..31e87fdab 100644
--- a/core/application/distributed.go
+++ b/core/application/distributed.go
@@ -7,6 +7,7 @@ import (
"io"
"strings"
"sync"
+ "time"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
@@ -28,6 +29,7 @@ type DistributedServices struct {
Registry *nodes.NodeRegistry
Router *nodes.SmartRouter
Health *nodes.HealthMonitor
+ Reconciler *nodes.ReplicaReconciler
JobStore *jobs.JobStore
Dispatcher *jobs.Dispatcher
AgentStore *agents.AgentStore
@@ -240,6 +242,16 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
DB: authDB,
})
+ // Create ReplicaReconciler for auto-scaling model replicas
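+ // (its reconciliation loop is started later from startup.go with the application context)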
+ reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: router,
+ Unloader: remoteUnloader,
+ DB: authDB,
+ Interval: 30 * time.Second,
+ ScaleDownDelay: 5 * time.Minute,
+ })
+
// Create ModelRouterAdapter to wire into ModelLoader
modelAdapter := nodes.NewModelRouterAdapter(router)
@@ -250,6 +262,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
Registry: registry,
Router: router,
Health: healthMon,
+ Reconciler: reconciler,
JobStore: jobStore,
Dispatcher: dispatcher,
AgentStore: agentStore,
diff --git a/core/application/startup.go b/core/application/startup.go
index 026e55226..728c3c972 100644
--- a/core/application/startup.go
+++ b/core/application/startup.go
@@ -158,6 +158,10 @@ func New(opts ...config.AppOption) (*Application, error) {
application.modelLoader.SetModelStore(distStore)
// Start health monitor
distSvc.Health.Start(options.Context)
+ // Start replica reconciler for auto-scaling model replicas
+ if distSvc.Reconciler != nil {
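+ // Run blocks until the context is cancelled, so it is launched in its own goroutine.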
+ go distSvc.Reconciler.Run(options.Context)
+ }
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
@@ -439,11 +443,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
}
}
}
- if settings.ParallelBackendRequests != nil {
- if !options.ParallelBackendRequests {
- options.ParallelBackendRequests = *settings.ParallelBackendRequests
- }
- }
if settings.MemoryReclaimerEnabled != nil {
// Only apply if current value is default (false), suggesting it wasn't set from env var
if !options.MemoryReclaimerEnabled {
diff --git a/core/backend/options.go b/core/backend/options.go
index 2d410c90a..b09782ce2 100644
--- a/core/backend/options.go
+++ b/core/backend/options.go
@@ -59,9 +59,7 @@ func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...mo
grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath)
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
- if so.ParallelBackendRequests {
- defOpts = append(defOpts, model.EnableParallelRequests)
- }
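+ // Parallel requests are now always enabled; backends honor the option only if they support it (e.g. llama.cpp, vllm).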
+ defOpts = append(defOpts, model.EnableParallelRequests)
if c.GRPC.Attempts != 0 {
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
diff --git a/core/cli/agent_test.go b/core/cli/agent_test.go
index d5d29e5d8..2a952d99a 100644
--- a/core/cli/agent_test.go
+++ b/core/cli/agent_test.go
@@ -4,211 +4,172 @@ import (
"encoding/json"
"os"
"path/filepath"
- "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
"github.com/mudler/LocalAGI/core/state"
)
-func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
- // Create a temporary agent config file
- tmpDir := t.TempDir()
- configFile := filepath.Join(tmpDir, "agent.json")
+var _ = Describe("AgentRunCMD", func() {
+ Describe("loadAgentConfig", func() {
+ It("loads agent config from file", func() {
+ tmpDir := GinkgoT().TempDir()
+ configFile := filepath.Join(tmpDir, "agent.json")
- cfg := state.AgentConfig{
- Name: "test-agent",
- Model: "llama3",
- SystemPrompt: "You are a helpful assistant",
- }
- data, err := json.MarshalIndent(cfg, "", " ")
- if err != nil {
- t.Fatal(err)
- }
- if err := os.WriteFile(configFile, data, 0644); err != nil {
- t.Fatal(err)
- }
+ cfg := state.AgentConfig{
+ Name: "test-agent",
+ Model: "llama3",
+ SystemPrompt: "You are a helpful assistant",
+ }
+ data, err := json.MarshalIndent(cfg, "", " ")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
- cmd := &AgentRunCMD{
- Config: configFile,
- StateDir: tmpDir,
- }
+ cmd := &AgentRunCMD{
+ Config: configFile,
+ StateDir: tmpDir,
+ }
- loaded, err := cmd.loadAgentConfig()
- if err != nil {
- t.Fatalf("loadAgentConfig() error: %v", err)
- }
- if loaded.Name != "test-agent" {
- t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
- }
- if loaded.Model != "llama3" {
- t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
- }
-}
+ loaded, err := cmd.loadAgentConfig()
+ Expect(err).ToNot(HaveOccurred())
+ Expect(loaded.Name).To(Equal("test-agent"))
+ Expect(loaded.Model).To(Equal("llama3"))
+ })
-func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
- tmpDir := t.TempDir()
+ It("loads agent config from pool", func() {
+ tmpDir := GinkgoT().TempDir()
- pool := map[string]state.AgentConfig{
- "my-agent": {
- Model: "gpt-4",
- Description: "A test agent",
- SystemPrompt: "Hello",
- },
- "other-agent": {
- Model: "llama3",
- },
- }
- data, err := json.MarshalIndent(pool, "", " ")
- if err != nil {
- t.Fatal(err)
- }
- if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
- t.Fatal(err)
- }
+ pool := map[string]state.AgentConfig{
+ "my-agent": {
+ Model: "gpt-4",
+ Description: "A test agent",
+ SystemPrompt: "Hello",
+ },
+ "other-agent": {
+ Model: "llama3",
+ },
+ }
+ data, err := json.MarshalIndent(pool, "", " ")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
- cmd := &AgentRunCMD{
- Name: "my-agent",
- StateDir: tmpDir,
- }
+ cmd := &AgentRunCMD{
+ Name: "my-agent",
+ StateDir: tmpDir,
+ }
- loaded, err := cmd.loadAgentConfig()
- if err != nil {
- t.Fatalf("loadAgentConfig() error: %v", err)
- }
- if loaded.Name != "my-agent" {
- t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
- }
- if loaded.Model != "gpt-4" {
- t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
- }
-}
+ loaded, err := cmd.loadAgentConfig()
+ Expect(err).ToNot(HaveOccurred())
+ Expect(loaded.Name).To(Equal("my-agent"))
+ Expect(loaded.Model).To(Equal("gpt-4"))
+ })
-func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
- tmpDir := t.TempDir()
+ It("returns error for missing agent in pool", func() {
+ tmpDir := GinkgoT().TempDir()
- pool := map[string]state.AgentConfig{
- "existing-agent": {Model: "llama3"},
- }
- data, err := json.MarshalIndent(pool, "", " ")
- if err != nil {
- t.Fatal(err)
- }
- if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
- t.Fatal(err)
- }
+ pool := map[string]state.AgentConfig{
+ "existing-agent": {Model: "llama3"},
+ }
+ data, err := json.MarshalIndent(pool, "", " ")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
- cmd := &AgentRunCMD{
- Name: "nonexistent",
- StateDir: tmpDir,
- }
+ cmd := &AgentRunCMD{
+ Name: "nonexistent",
+ StateDir: tmpDir,
+ }
- _, err = cmd.loadAgentConfig()
- if err == nil {
- t.Fatal("expected error for missing agent, got nil")
- }
-}
+ _, err = cmd.loadAgentConfig()
+ Expect(err).To(HaveOccurred())
+ })
-func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
- cmd := &AgentRunCMD{
- StateDir: t.TempDir(),
- }
+ It("returns error when no pool.json exists", func() {
+ cmd := &AgentRunCMD{
+ StateDir: GinkgoT().TempDir(),
+ }
- _, err := cmd.loadAgentConfig()
- if err == nil {
- t.Fatal("expected error when no pool.json exists, got nil")
- }
-}
+ _, err := cmd.loadAgentConfig()
+ Expect(err).To(HaveOccurred())
+ })
-func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
- cfg := &state.AgentConfig{
- Name: "test",
- }
+ It("returns error for config with no name", func() {
+ tmpDir := GinkgoT().TempDir()
+ configFile := filepath.Join(tmpDir, "agent.json")
- cmd := &AgentRunCMD{
- APIURL: "http://localhost:9090",
- APIKey: "secret",
- DefaultModel: "my-model",
- }
+ cfg := state.AgentConfig{
+ Model: "llama3",
+ }
+ data, _ := json.MarshalIndent(cfg, "", " ")
+ Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
- cmd.applyOverrides(cfg)
+ cmd := &AgentRunCMD{
+ Config: configFile,
+ StateDir: tmpDir,
+ }
- if cfg.APIURL != "http://localhost:9090" {
- t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
- }
- if cfg.APIKey != "secret" {
- t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
- }
- if cfg.Model != "my-model" {
- t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
- }
-}
+ _, err := cmd.loadAgentConfig()
+ Expect(err).To(HaveOccurred())
+ })
+ })
-func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
- cfg := &state.AgentConfig{
- Name: "test",
- Model: "existing-model",
- }
+ Describe("applyOverrides", func() {
+ It("applies overrides to empty fields", func() {
+ cfg := &state.AgentConfig{
+ Name: "test",
+ }
- cmd := &AgentRunCMD{
- DefaultModel: "override-model",
- }
+ cmd := &AgentRunCMD{
+ APIURL: "http://localhost:9090",
+ APIKey: "secret",
+ DefaultModel: "my-model",
+ }
- cmd.applyOverrides(cfg)
+ cmd.applyOverrides(cfg)
- if cfg.Model != "existing-model" {
- t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
- }
-}
+ Expect(cfg.APIURL).To(Equal("http://localhost:9090"))
+ Expect(cfg.APIKey).To(Equal("secret"))
+ Expect(cfg.Model).To(Equal("my-model"))
+ })
-func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
- tmpDir := t.TempDir()
- configFile := filepath.Join(tmpDir, "agent.json")
+ It("does not overwrite existing model", func() {
+ cfg := &state.AgentConfig{
+ Name: "test",
+ Model: "existing-model",
+ }
- // Agent config with no name
- cfg := state.AgentConfig{
- Model: "llama3",
- }
- data, _ := json.MarshalIndent(cfg, "", " ")
- os.WriteFile(configFile, data, 0644)
+ cmd := &AgentRunCMD{
+ DefaultModel: "override-model",
+ }
- cmd := &AgentRunCMD{
- Config: configFile,
- StateDir: tmpDir,
- }
+ cmd.applyOverrides(cfg)
- _, err := cmd.loadAgentConfig()
- if err == nil {
- t.Fatal("expected error for config with no name, got nil")
- }
-}
+ Expect(cfg.Model).To(Equal("existing-model"))
+ })
+ })
+})
-func TestAgentListCMD_NoPoolFile(t *testing.T) {
- cmd := &AgentListCMD{
- StateDir: t.TempDir(),
- }
+var _ = Describe("AgentListCMD", func() {
+ It("runs without error when no pool file exists", func() {
+ cmd := &AgentListCMD{
+ StateDir: GinkgoT().TempDir(),
+ }
+ Expect(cmd.Run(nil)).To(Succeed())
+ })
- // Should not error, just print "no agents found"
- err := cmd.Run(nil)
- if err != nil {
- t.Fatalf("expected no error, got: %v", err)
- }
-}
+ It("runs without error with agents in pool", func() {
+ tmpDir := GinkgoT().TempDir()
-func TestAgentListCMD_WithAgents(t *testing.T) {
- tmpDir := t.TempDir()
+ pool := map[string]state.AgentConfig{
+ "agent-a": {Model: "llama3", Description: "First agent"},
+ "agent-b": {Model: "gpt-4"},
+ }
+ data, _ := json.MarshalIndent(pool, "", " ")
+ Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
- pool := map[string]state.AgentConfig{
- "agent-a": {Model: "llama3", Description: "First agent"},
- "agent-b": {Model: "gpt-4"},
- }
- data, _ := json.MarshalIndent(pool, "", " ")
- os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
-
- cmd := &AgentListCMD{
- StateDir: tmpDir,
- }
-
- err := cmd.Run(nil)
- if err != nil {
- t.Fatalf("expected no error, got: %v", err)
- }
-}
+ cmd := &AgentListCMD{
+ StateDir: tmpDir,
+ }
+ Expect(cmd.Run(nil)).To(Succeed())
+ })
+})
diff --git a/core/cli/cli_suite_test.go b/core/cli/cli_suite_test.go
new file mode 100644
index 000000000..0b71fcdb9
--- /dev/null
+++ b/core/cli/cli_suite_test.go
@@ -0,0 +1,13 @@
+package cli
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestCLI(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "CLI Suite")
+}
diff --git a/core/cli/completion_test.go b/core/cli/completion_test.go
index be625d051..dc5793075 100644
--- a/core/cli/completion_test.go
+++ b/core/cli/completion_test.go
@@ -1,10 +1,9 @@
package cli
import (
- "strings"
- "testing"
-
"github.com/alecthomas/kong"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
)
func getTestApp() *kong.Application {
@@ -21,76 +20,55 @@ func getTestApp() *kong.Application {
return k.Model
}
-func TestGenerateBashCompletion(t *testing.T) {
- app := getTestApp()
- script := generateBashCompletion(app)
+var _ = Describe("Shell completions", func() {
+ var app *kong.Application
- if !strings.Contains(script, "complete -F _local_ai_completions local-ai") {
- t.Error("bash completion missing complete command registration")
- }
- if !strings.Contains(script, "run") {
- t.Error("bash completion missing 'run' command")
- }
- if !strings.Contains(script, "models") {
- t.Error("bash completion missing 'models' command")
- }
- if !strings.Contains(script, "completion") {
- t.Error("bash completion missing 'completion' command")
- }
-}
+ BeforeEach(func() {
+ app = getTestApp()
+ })
-func TestGenerateZshCompletion(t *testing.T) {
- app := getTestApp()
- script := generateZshCompletion(app)
+ Describe("generateBashCompletion", func() {
+ It("generates valid bash completion script", func() {
+ script := generateBashCompletion(app)
+ Expect(script).To(ContainSubstring("complete -F _local_ai_completions local-ai"))
+ Expect(script).To(ContainSubstring("run"))
+ Expect(script).To(ContainSubstring("models"))
+ Expect(script).To(ContainSubstring("completion"))
+ })
+ })
- if !strings.Contains(script, "#compdef local-ai") {
- t.Error("zsh completion missing compdef header")
- }
- if !strings.Contains(script, "run") {
- t.Error("zsh completion missing 'run' command")
- }
- if !strings.Contains(script, "models") {
- t.Error("zsh completion missing 'models' command")
- }
-}
+ Describe("generateZshCompletion", func() {
+ It("generates valid zsh completion script", func() {
+ script := generateZshCompletion(app)
+ Expect(script).To(ContainSubstring("#compdef local-ai"))
+ Expect(script).To(ContainSubstring("run"))
+ Expect(script).To(ContainSubstring("models"))
+ })
+ })
-func TestGenerateFishCompletion(t *testing.T) {
- app := getTestApp()
- script := generateFishCompletion(app)
+ Describe("generateFishCompletion", func() {
+ It("generates valid fish completion script", func() {
+ script := generateFishCompletion(app)
+ Expect(script).To(ContainSubstring("complete -c local-ai"))
+ Expect(script).To(ContainSubstring("__fish_use_subcommand"))
+ Expect(script).To(ContainSubstring("run"))
+ Expect(script).To(ContainSubstring("models"))
+ })
+ })
- if !strings.Contains(script, "complete -c local-ai") {
- t.Error("fish completion missing complete command")
- }
- if !strings.Contains(script, "__fish_use_subcommand") {
- t.Error("fish completion missing subcommand detection")
- }
- if !strings.Contains(script, "run") {
- t.Error("fish completion missing 'run' command")
- }
- if !strings.Contains(script, "models") {
- t.Error("fish completion missing 'models' command")
- }
-}
+ Describe("collectCommands", func() {
+ It("collects all commands and subcommands", func() {
+ cmds := collectCommands(app.Node, "")
-func TestCollectCommands(t *testing.T) {
- app := getTestApp()
- cmds := collectCommands(app.Node, "")
+ names := make(map[string]bool)
+ for _, cmd := range cmds {
+ names[cmd.fullName] = true
+ }
- names := make(map[string]bool)
- for _, cmd := range cmds {
- names[cmd.fullName] = true
- }
-
- if !names["run"] {
- t.Error("missing 'run' command")
- }
- if !names["models"] {
- t.Error("missing 'models' command")
- }
- if !names["models list"] {
- t.Error("missing 'models list' subcommand")
- }
- if !names["models install"] {
- t.Error("missing 'models install' subcommand")
- }
-}
+ Expect(names).To(HaveKey("run"))
+ Expect(names).To(HaveKey("models"))
+ Expect(names).To(HaveKey("models list"))
+ Expect(names).To(HaveKey("models install"))
+ })
+ })
+})
diff --git a/core/cli/run.go b/core/cli/run.go
index d3f1ac103..a478ab1c0 100644
--- a/core/cli/run.go
+++ b/core/cli/run.go
@@ -74,7 +74,6 @@ type RunCMD struct {
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
- ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"`
MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). Least recently used backends are evicted when limit is reached" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
@@ -438,10 +437,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold))
}
- if r.ParallelRequests {
- opts = append(opts, config.EnableParallelBackendRequests)
- }
-
// Handle max active backends (LRU eviction)
// MaxActiveBackends takes precedence over SingleActiveBackend
if r.MaxActiveBackends > 0 {
diff --git a/core/cli/worker.go b/core/cli/worker.go
index f4c707e0c..b43d38fbb 100644
--- a/core/cli/worker.go
+++ b/core/cli/worker.go
@@ -67,22 +67,28 @@ func isPathAllowed(path string, allowedDirs []string) bool {
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
type WorkerCMD struct {
- Addr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
+ // Primary address — the reachable address of this worker.
+ // Host is used for advertise, port is the base for gRPC backends.
+ // HTTP file transfer runs on port-1.
+ Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." group:"server"`
+ ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""`
+
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
// HTTP file transfer
- HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
- AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`
+ HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port - 1)" group:"server" hidden:""`
+ AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
// Registration (required)
- AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
+ AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
+ NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
// NATS (required)
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
@@ -96,7 +102,7 @@ type WorkerCMD struct {
}
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
- xlog.Info("Starting worker", "addr", cmd.Addr)
+ xlog.Info("Starting worker", "advertise", cmd.advertiseAddr(), "basePort", cmd.effectiveBasePort())
systemState, err := system.GetSystemState(
system.WithModelPath(cmd.ModelsPath),
@@ -181,15 +187,7 @@ func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
}()
// Process supervisor — manages multiple backend gRPC processes on different ports
- basePort := 50051
- if cmd.Addr != "" {
- // Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
- if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
- if p, err := strconv.Atoi(portStr); err == nil {
- basePort = p
- }
- }
- }
+ basePort := cmd.effectiveBasePort()
// Buffered so NATS stop handler can send without blocking
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
@@ -667,10 +665,10 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
// Return the gRPC address so the router knows which port to use
advertiseAddr := addr
- if s.cmd.AdvertiseAddr != "" {
- // Replace 0.0.0.0 with the advertised host but keep the dynamic port
+ advAddr := s.cmd.advertiseAddr()
+ if advAddr != addr { // only remap if advertise differs from bind
_, port, _ := net.SplitHostPort(addr)
- advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
+ advertiseHost, _, _ := net.SplitHostPort(advAddr)
advertiseAddr = net.JoinHostPort(advertiseHost, port)
}
resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
@@ -816,18 +814,31 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
})
}
+// effectiveBasePort returns the port used as base for gRPC backend processes.
+// Priority: Addr port → ServeAddr port → 50051
+func (cmd *WorkerCMD) effectiveBasePort() int {
+ for _, addr := range []string{cmd.Addr, cmd.ServeAddr} {
+ if addr != "" {
+ if _, portStr, ok := strings.Cut(addr, ":"); ok {
+ if p, _ := strconv.Atoi(portStr); p > 0 {
+ return p
+ }
+ }
+ }
+ }
+ return 50051
+}
+
// advertiseAddr returns the address the frontend should use to reach this node.
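+// Priority: AdvertiseAddr → Addr → hostname:effectiveBasePort()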
func (cmd *WorkerCMD) advertiseAddr() string {
if cmd.AdvertiseAddr != "" {
return cmd.AdvertiseAddr
}
- host, port, ok := strings.Cut(cmd.Addr, ":")
- if ok && (host == "0.0.0.0" || host == "") {
- if hostname, err := os.Hostname(); err == nil {
- return hostname + ":" + port
- }
+ if cmd.Addr != "" {
+ return cmd.Addr
}
- return cmd.Addr
+ hostname, _ := os.Hostname()
+ return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cmd.effectiveBasePort())
}
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
@@ -837,12 +848,7 @@ func (cmd *WorkerCMD) resolveHTTPAddr() string {
if cmd.HTTPAddr != "" {
return cmd.HTTPAddr
}
- host, port, ok := strings.Cut(cmd.Addr, ":")
- if !ok {
- return "0.0.0.0:50050"
- }
- portNum, _ := strconv.Atoi(port)
- return fmt.Sprintf("%s:%d", host, portNum-1)
+ return fmt.Sprintf("0.0.0.0:%d", cmd.effectiveBasePort()-1)
}
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
@@ -851,14 +857,9 @@ func (cmd *WorkerCMD) advertiseHTTPAddr() string {
if cmd.AdvertiseHTTPAddr != "" {
return cmd.AdvertiseHTTPAddr
}
- httpAddr := cmd.resolveHTTPAddr()
- host, port, ok := strings.Cut(httpAddr, ":")
- if ok && (host == "0.0.0.0" || host == "") {
- if hostname, err := os.Hostname(); err == nil {
- return hostname + ":" + port
- }
- }
- return httpAddr
+ advHost, _, _ := strings.Cut(cmd.advertiseAddr(), ":")
+ httpPort := cmd.effectiveBasePort() - 1
+ return fmt.Sprintf("%s:%d", advHost, httpPort)
}
// registrationBody builds the JSON body for node registration.
@@ -896,6 +897,21 @@ func (cmd *WorkerCMD) registrationBody() map[string]any {
if cmd.RegistrationToken != "" {
body["token"] = cmd.RegistrationToken
}
+
+ // Parse and add static node labels
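+ // e.g. "tier=fast,gpu=a100" becomes {"tier": "fast", "gpu": "a100"}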
+ if cmd.NodeLabels != "" {
+ labels := make(map[string]string)
+ for _, pair := range strings.Split(cmd.NodeLabels, ",") {
+ pair = strings.TrimSpace(pair)
+ if k, v, ok := strings.Cut(pair, "="); ok {
+ labels[strings.TrimSpace(k)] = strings.TrimSpace(v)
+ }
+ }
+ if len(labels) > 0 {
+ body["labels"] = labels
+ }
+ }
+
return body
}
diff --git a/core/cli/worker_addr_test.go b/core/cli/worker_addr_test.go
new file mode 100644
index 000000000..6dd7b0ab9
--- /dev/null
+++ b/core/cli/worker_addr_test.go
@@ -0,0 +1,83 @@
+package cli
+
+import (
+ "os"
+ "strings"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("WorkerCMD address resolution", func() {
+ Describe("effectiveBasePort", func() {
+ DescribeTable("returns the correct port",
+ func(addr, serve string, want int) {
+ cmd := &WorkerCMD{Addr: addr, ServeAddr: serve}
+ Expect(cmd.effectiveBasePort()).To(Equal(want))
+ },
+ Entry("Addr takes priority", "worker1.example.com:60000", "0.0.0.0:50051", 60000),
+ Entry("falls back to ServeAddr", "", "0.0.0.0:50051", 50051),
+ Entry("returns 50051 when neither set", "", "", 50051),
+ Entry("Addr with custom port", "10.0.0.5:7000", "", 7000),
+ Entry("invalid port in Addr falls through to ServeAddr", "host:notanumber", "0.0.0.0:9999", 9999),
+ )
+ })
+
+ Describe("advertiseAddr", func() {
+ It("returns AdvertiseAddr when set", func() {
+ cmd := &WorkerCMD{
+ AdvertiseAddr: "public.example.com:50051",
+ Addr: "10.0.0.5:60000",
+ }
+ Expect(cmd.advertiseAddr()).To(Equal("public.example.com:50051"))
+ })
+
+ It("returns Addr when set", func() {
+ cmd := &WorkerCMD{Addr: "worker1.example.com:60000"}
+ Expect(cmd.advertiseAddr()).To(Equal("worker1.example.com:60000"))
+ })
+
+ It("falls back to hostname:basePort", func() {
+ cmd := &WorkerCMD{ServeAddr: "0.0.0.0:50051"}
+ got := cmd.advertiseAddr()
+ _, port, _ := strings.Cut(got, ":")
+ Expect(port).To(Equal("50051"))
+
+ hostname, _ := os.Hostname()
+ if hostname != "" {
+ host, _, _ := strings.Cut(got, ":")
+ Expect(host).To(Equal(hostname))
+ }
+ })
+ })
+
+ Describe("resolveHTTPAddr", func() {
+ DescribeTable("returns the correct address",
+ func(httpAddr, addr, serve, want string) {
+ cmd := &WorkerCMD{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve}
+ Expect(cmd.resolveHTTPAddr()).To(Equal(want))
+ },
+ Entry("HTTPAddr takes priority", "0.0.0.0:8080", "", "", "0.0.0.0:8080"),
+ Entry("derives from Addr port minus 1", "", "worker1:60000", "0.0.0.0:50051", "0.0.0.0:59999"),
+ Entry("derives from ServeAddr port minus 1", "", "", "0.0.0.0:50051", "0.0.0.0:50050"),
+ Entry("default when nothing set", "", "", "", "0.0.0.0:50050"),
+ )
+ })
+
+ Describe("advertiseHTTPAddr", func() {
+ DescribeTable("returns the correct address",
+ func(advertiseHTTP, advertise, addr, serve, want string) {
+ cmd := &WorkerCMD{
+ AdvertiseHTTPAddr: advertiseHTTP,
+ AdvertiseAddr: advertise,
+ Addr: addr,
+ ServeAddr: serve,
+ }
+ Expect(cmd.advertiseHTTPAddr()).To(Equal(want))
+ },
+ Entry("AdvertiseHTTPAddr takes priority", "public.example.com:8080", "", "", "", "public.example.com:8080"),
+ Entry("derives from advertiseAddr host + basePort-1", "", "", "worker1.example.com:60000", "", "worker1.example.com:59999"),
+ Entry("uses AdvertiseAddr host with basePort-1", "", "public.example.com:60000", "10.0.0.5:60000", "", "public.example.com:59999"),
+ )
+ })
+})
diff --git a/core/config/application_config.go b/core/config/application_config.go
index 2bfb552d7..8d55f83a6 100644
--- a/core/config/application_config.go
+++ b/core/config/application_config.go
@@ -59,8 +59,6 @@ type ApplicationConfig struct {
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
- ParallelBackendRequests bool
-
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
@@ -379,10 +377,6 @@ func WithLRUEvictionRetryInterval(interval time.Duration) AppOption {
}
}
-var EnableParallelBackendRequests = func(o *ApplicationConfig) {
- o.ParallelBackendRequests = true
-}
-
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
o.AutoloadGalleries = true
}
@@ -842,7 +836,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
watchdogBusy := o.WatchDogBusy
singleBackend := o.SingleBackend
maxActiveBackends := o.MaxActiveBackends
- parallelBackendRequests := o.ParallelBackendRequests
memoryReclaimerEnabled := o.MemoryReclaimerEnabled
memoryReclaimerThreshold := o.MemoryReclaimerThreshold
forceEvictionWhenBusy := o.ForceEvictionWhenBusy
@@ -915,7 +908,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
WatchdogInterval: &watchdogInterval,
SingleBackend: &singleBackend,
MaxActiveBackends: &maxActiveBackends,
- ParallelBackendRequests: ¶llelBackendRequests,
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
@@ -1008,9 +1000,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
}
requireRestart = true
}
- if settings.ParallelBackendRequests != nil {
- o.ParallelBackendRequests = *settings.ParallelBackendRequests
- }
if settings.MemoryReclaimerEnabled != nil {
o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
if *settings.MemoryReclaimerEnabled {
diff --git a/core/config/application_config_test.go b/core/config/application_config_test.go
index 4ea89e8d6..c5559ce53 100644
--- a/core/config/application_config_test.go
+++ b/core/config/application_config_test.go
@@ -18,7 +18,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
WatchDogBusyTimeout: 10 * time.Minute,
SingleBackend: false,
MaxActiveBackends: 5,
- ParallelBackendRequests: true,
MemoryReclaimerEnabled: true,
MemoryReclaimerThreshold: 0.85,
Threads: 8,
@@ -62,9 +61,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
Expect(rs.MaxActiveBackends).ToNot(BeNil())
Expect(*rs.MaxActiveBackends).To(Equal(5))
- Expect(rs.ParallelBackendRequests).ToNot(BeNil())
- Expect(*rs.ParallelBackendRequests).To(BeTrue())
-
Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil())
Expect(*rs.MemoryReclaimerEnabled).To(BeTrue())
@@ -455,7 +451,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
WatchDogBusyTimeout: 12 * time.Minute,
SingleBackend: false,
MaxActiveBackends: 3,
- ParallelBackendRequests: true,
MemoryReclaimerEnabled: true,
MemoryReclaimerThreshold: 0.92,
Threads: 12,
@@ -487,7 +482,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout))
Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout))
Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends))
- Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests))
Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled))
Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold))
Expect(target.Threads).To(Equal(original.Threads))
diff --git a/core/config/runtime_settings.go b/core/config/runtime_settings.go
index 611759110..6a4117f06 100644
--- a/core/config/runtime_settings.go
+++ b/core/config/runtime_settings.go
@@ -20,8 +20,6 @@ type RuntimeSettings struct {
// Backend management
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
- ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"`
-
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go
index bbdb33471..0ac828a09 100644
--- a/core/http/endpoints/localai/nodes.go
+++ b/core/http/endpoints/localai/nodes.go
@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
+ "encoding/json"
"fmt"
"io"
"net/http"
@@ -37,7 +38,7 @@ func nodeError(code int, message string) schema.ErrorResponse {
func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
- nodeList, err := registry.List(ctx)
+ nodeList, err := registry.ListWithExtras(ctx)
if err != nil {
xlog.Error("Failed to list nodes", "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes"))
@@ -70,7 +71,8 @@ type RegisterNodeRequest struct {
AvailableVRAM uint64 `json:"available_vram,omitempty"`
TotalRAM uint64 `json:"total_ram,omitempty"`
AvailableRAM uint64 `json:"available_ram,omitempty"`
- GPUVendor string `json:"gpu_vendor,omitempty"`
+ GPUVendor string `json:"gpu_vendor,omitempty"`
+ Labels map[string]string `json:"labels,omitempty"`
}
// RegisterNodeEndpoint registers a new backend node.
@@ -148,6 +150,14 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node"))
}
+ // Store static labels from worker and apply auto-labels
+ if len(req.Labels) > 0 {
+ if err := registry.SetNodeLabels(ctx, node.ID, req.Labels); err != nil {
+ xlog.Warn("Failed to set node labels", "node", node.ID, "error", err)
+ }
+ }
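+ // Auto-labels are applied regardless of whether the worker supplied static labels.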
+ registry.ApplyAutoLabels(ctx, node.ID, node)
+
response := map[string]any{
"id": node.ID,
"name": node.Name,
@@ -618,6 +628,178 @@ func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken s
}
}
+// GetNodeLabelsEndpoint returns labels for a specific node.
+func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ nodeID := c.Param("id")
+ labels, err := registry.GetNodeLabels(ctx, nodeID)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
+ }
+ // Convert to map for cleaner API response
+ result := make(map[string]string)
+ for _, l := range labels {
+ result[l.Key] = l.Value
+ }
+ return c.JSON(http.StatusOK, result)
+ }
+}
+
+// SetNodeLabelsEndpoint replaces all labels for a node.
+func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ nodeID := c.Param("id")
+ if _, err := registry.Get(ctx, nodeID); err != nil {
+ return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
+ }
+ var labels map[string]string
+ if err := c.Bind(&labels); err != nil {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
+ }
+ if err := registry.SetNodeLabels(ctx, nodeID, labels); err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels"))
+ }
+ return c.JSON(http.StatusOK, labels)
+ }
+}
+
+// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones.
+func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ nodeID := c.Param("id")
+ if _, err := registry.Get(ctx, nodeID); err != nil {
+ return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
+ }
+ var labels map[string]string
+ if err := c.Bind(&labels); err != nil {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
+ }
+ for k, v := range labels {
+ if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels"))
+ }
+ }
+ // Return updated labels
+ updated, _ := registry.GetNodeLabels(ctx, nodeID)
+ result := make(map[string]string)
+ for _, l := range updated {
+ result[l.Key] = l.Value
+ }
+ return c.JSON(http.StatusOK, result)
+ }
+}
+
+// DeleteNodeLabelEndpoint removes a single label from a node.
+func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ nodeID := c.Param("id")
+ key := c.Param("key")
+ if key == "" {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required"))
+ }
+ if err := registry.RemoveNodeLabel(ctx, nodeID, key); err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label"))
+ }
+ return c.NoContent(http.StatusNoContent)
+ }
+}
+
+// ListSchedulingEndpoint returns all model scheduling configs.
+func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ configs, err := registry.ListModelSchedulings(ctx)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs"))
+ }
+ return c.JSON(http.StatusOK, configs)
+ }
+}
+
+// GetSchedulingEndpoint returns the scheduling config for a specific model.
+func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ modelName := c.Param("model")
+ config, err := registry.GetModelScheduling(ctx, modelName)
+ if err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config"))
+ }
+ if config == nil {
+ return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model"))
+ }
+ return c.JSON(http.StatusOK, config)
+ }
+}
+
+// SetSchedulingRequest is the request body for creating/updating a scheduling config.
+type SetSchedulingRequest struct {
+ ModelName string `json:"model_name"`
+ NodeSelector map[string]string `json:"node_selector,omitempty"`
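+ // MinReplicas of 0 means no minimum; MaxReplicas of 0 means unlimited
+ // (matching how the reconciler and the UI interpret zero values).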
+ MinReplicas int `json:"min_replicas"`
+ MaxReplicas int `json:"max_replicas"`
+}
+
+// SetSchedulingEndpoint creates or updates a model scheduling config.
+func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ var req SetSchedulingRequest
+ if err := c.Bind(&req); err != nil {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
+ }
+ if req.ModelName == "" {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
+ }
+ if req.MinReplicas < 0 {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
+ }
+ if req.MaxReplicas < 0 {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
+ }
+ if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
+ }
+
+ // Serialize node selector to JSON
+ var selectorJSON string
+ if len(req.NodeSelector) > 0 {
+ b, err := json.Marshal(req.NodeSelector)
+ if err != nil {
+ return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector"))
+ }
+ selectorJSON = string(b)
+ }
+
+ config := &nodes.ModelSchedulingConfig{
+ ModelName: req.ModelName,
+ NodeSelector: selectorJSON,
+ MinReplicas: req.MinReplicas,
+ MaxReplicas: req.MaxReplicas,
+ }
+ if err := registry.SetModelScheduling(ctx, config); err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
+ }
+ return c.JSON(http.StatusOK, config)
+ }
+}
+
+// DeleteSchedulingEndpoint removes a model scheduling config.
+func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ ctx := c.Request().Context()
+ modelName := c.Param("model")
+ if err := registry.DeleteModelScheduling(ctx, modelName); err != nil {
+ return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config"))
+ }
+ return c.NoContent(http.StatusNoContent)
+ }
+}
+
// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth.
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) {
reqURL := fmt.Sprintf("http://%s%s", httpAddress, path)
diff --git a/core/http/react-ui/src/pages/Nodes.jsx b/core/http/react-ui/src/pages/Nodes.jsx
index 24812d6a9..0ec97a7c8 100644
--- a/core/http/react-ui/src/pages/Nodes.jsx
+++ b/core/http/react-ui/src/pages/Nodes.jsx
@@ -159,6 +159,62 @@ function WorkerHintCard({ addToast, activeTab, hasWorkers }) {
)
}
+function SchedulingForm({ onSave, onCancel }) {
+ const [modelName, setModelName] = useState('')
+ const [selectorText, setSelectorText] = useState('')
+ const [minReplicas, setMinReplicas] = useState(0)
+ const [maxReplicas, setMaxReplicas] = useState(0)
+
+ const handleSubmit = () => {
+ let nodeSelector = null
+ if (selectorText.trim()) {
+ const pairs = {}
+ selectorText.split(',').forEach(p => {
+ const [k, v] = p.split('=').map(s => s.trim())
+ if (k) pairs[k] = v || ''
+ })
+ nodeSelector = pairs
+ }
+ onSave({
+ model_name: modelName,
+ node_selector: nodeSelector || undefined,
+ min_replicas: minReplicas,
+ max_replicas: maxReplicas,
+ })
+ }
+
+ return (
+
+ )
+}
+
export default function Nodes() {
const { addToast } = useOutletContext()
const navigate = useNavigate()
@@ -170,7 +226,9 @@ export default function Nodes() {
const [nodeBackends, setNodeBackends] = useState({})
const [confirmDelete, setConfirmDelete] = useState(null)
const [showTips, setShowTips] = useState(false)
- const [activeTab, setActiveTab] = useState('backend') // 'backend' or 'agent'
+ const [activeTab, setActiveTab] = useState('backend') // 'backend', 'agent', or 'scheduling'
+ const [schedulingConfigs, setSchedulingConfigs] = useState([])
+ const [showSchedulingForm, setShowSchedulingForm] = useState(false)
const fetchNodes = useCallback(async () => {
try {
@@ -186,11 +244,19 @@ export default function Nodes() {
}
}, [])
+ const fetchScheduling = useCallback(async () => {
+ try {
+ const data = await nodesApi.listScheduling()
+ setSchedulingConfigs(Array.isArray(data) ? data : [])
+ } catch { setSchedulingConfigs([]) }
+ }, [])
+
useEffect(() => {
fetchNodes()
+ fetchScheduling()
const interval = setInterval(fetchNodes, 5000)
return () => clearInterval(interval)
- }, [fetchNodes])
+ }, [fetchNodes, fetchScheduling])
const fetchModels = useCallback(async (nodeId) => {
try {
@@ -254,6 +320,36 @@ export default function Nodes() {
}
}
+ const handleUnloadModel = async (nodeId, modelName) => {
+ try {
+ await nodesApi.unloadModel(nodeId, modelName)
+ addToast(`Model "${modelName}" unloaded`, 'success')
+ fetchModels(nodeId)
+ } catch (err) {
+ addToast(`Failed to unload model: ${err.message}`, 'error')
+ }
+ }
+
+ const handleAddLabel = async (nodeId, key, value) => {
+ try {
+ await nodesApi.mergeLabels(nodeId, { [key]: value })
+ addToast(`Label "${key}=${value}" added`, 'success')
+ fetchNodes()
+ } catch (err) {
+ addToast(`Failed to add label: ${err.message}`, 'error')
+ }
+ }
+
+ const handleDeleteLabel = async (nodeId, key) => {
+ try {
+ await nodesApi.deleteLabel(nodeId, key)
+ addToast(`Label "${key}" removed`, 'success')
+ fetchNodes()
+ } catch (err) {
+ addToast(`Failed to remove label: ${err.message}`, 'error')
+ }
+ }
+
const handleDelete = async (nodeId) => {
try {
await nodesApi.delete(nodeId)
@@ -422,8 +518,23 @@ export default function Nodes() {
Agent Workers ({agentNodes.length})
+
+ {activeTab !== 'scheduling' && <>
{/* Stat cards */}
@@ -433,6 +544,23 @@ export default function Nodes() {
{pending > 0 && (
)}
+ {activeTab === 'backend' && (() => {
+ const clusterTotalVRAM = backendNodes.reduce((sum, n) => sum + (n.total_vram || 0), 0)
+ const clusterUsedVRAM = backendNodes.reduce((sum, n) => {
+ if (n.total_vram && n.available_vram != null) return sum + (n.total_vram - n.available_vram)
+ return sum
+ }, 0)
+ const totalModelsLoaded = backendNodes.reduce((sum, n) => sum + (n.model_count || 0), 0)
+ return (
+ <>
+ {clusterTotalVRAM > 0 && (
+
+ )}
+
+ >
+ )
+ })()}
{/* Worker tips */}
@@ -506,6 +634,22 @@ export default function Nodes() {
{node.address}
+ {node.labels && Object.keys(node.labels).length > 0 && (
+
+ {Object.entries(node.labels).slice(0, 5).map(([k, v]) => (
+ {k}={v}
+ ))}
+ {Object.keys(node.labels).length > 5 && (
+
+ +{Object.keys(node.labels).length - 5} more
+
+ )}
+
+ )}
@@ -593,6 +737,7 @@ export default function Nodes() {
State |
In-Flight |
Logs |
+ Actions |
@@ -628,6 +773,21 @@ export default function Nodes() {
+
+
+ |
)
})}
@@ -689,6 +849,50 @@ export default function Nodes() {
)}
+
+ {/* Labels */}
+
+
+
+ Labels
+
+
+ {node.labels && Object.entries(node.labels).map(([k, v]) => (
+
+ {k}={v}
+
+
+ ))}
+
+ {/* Add label form */}
+
+
+
+
+
+
@@ -700,6 +904,78 @@ export default function Nodes() {
)}
+ >}
+
+ {activeTab === 'scheduling' && (
+
+
+ {showSchedulingForm &&
{
+ try {
+ await nodesApi.setScheduling(config)
+ fetchScheduling()
+ setShowSchedulingForm(false)
+ addToast('Scheduling rule saved', 'success')
+ } catch (err) {
+ addToast(`Failed to save rule: ${err.message}`, 'error')
+ }
+ }} onCancel={() => setShowSchedulingForm(false)} />}
+ {schedulingConfigs.length === 0 && !showSchedulingForm ? (
+
+ No scheduling rules configured. Add a rule to control how models are placed on nodes.
+
+ ) : schedulingConfigs.length > 0 && (
+
+
+
+ | Model |
+ Node Selector |
+ Min Replicas |
+ Max Replicas |
+ Actions |
+
+
+ {schedulingConfigs.map(cfg => (
+
+ | {cfg.model_name} |
+
+ {cfg.node_selector ? (() => {
+ try {
+ const sel = typeof cfg.node_selector === 'string' ? JSON.parse(cfg.node_selector) : cfg.node_selector
+ return Object.entries(sel).map(([k,v]) => (
+ {k}={v}
+ ))
+ } catch { return {cfg.node_selector} }
+ })() : Any node}
+ |
+ {cfg.min_replicas || '-'} |
+ {cfg.max_replicas || 'unlimited'} |
+
+
+ |
+
+ ))}
+
+
+
+ )}
+
+ )}
update('max_active_backends', parseInt(e.target.value) || 0)} placeholder="0" />
-
- update('parallel_backend_requests', v)} />
-
diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js
index dc25f051d..ee53bc377 100644
--- a/core/http/react-ui/src/utils/api.js
+++ b/core/http/react-ui/src/utils/api.js
@@ -423,6 +423,13 @@ export const nodesApi = {
deleteBackend: (id, backend) => postJSON(API_CONFIG.endpoints.nodeBackendsDelete(id), { backend }),
getBackendLogs: (id) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogs(id)),
getBackendLogLines: (id, modelId) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogsModel(id, modelId)),
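+ // Node label and model scheduling helpers (matching routes are registered in core/http/routes/nodes.go)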
+ unloadModel: (id, modelName) => postJSON(API_CONFIG.endpoints.nodeModelsUnload(id), { model_name: modelName }),
+ getLabels: (id) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id)),
+ mergeLabels: (id, labels) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id), { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(labels) }),
+ deleteLabel: (id, key) => fetchJSON(API_CONFIG.endpoints.nodeLabelKey(id, key), { method: 'DELETE' }),
+ listScheduling: () => fetchJSON(API_CONFIG.endpoints.nodesScheduling),
+ setScheduling: (config) => postJSON(API_CONFIG.endpoints.nodesScheduling, config),
+ deleteScheduling: (model) => fetchJSON(API_CONFIG.endpoints.nodesSchedulingModel(model), { method: 'DELETE' }),
}
// File to base64 helper
diff --git a/core/http/react-ui/src/utils/config.js b/core/http/react-ui/src/utils/config.js
index 76776bfce..06c84d509 100644
--- a/core/http/react-ui/src/utils/config.js
+++ b/core/http/react-ui/src/utils/config.js
@@ -107,5 +107,10 @@ export const API_CONFIG = {
nodeBackendsDelete: (id) => `/api/nodes/${id}/backends/delete`,
nodeBackendLogs: (id) => `/api/nodes/${id}/backend-logs`,
nodeBackendLogsModel: (id, modelId) => `/api/nodes/${id}/backend-logs/${encodeURIComponent(modelId)}`,
+ nodeModelsUnload: (id) => `/api/nodes/${id}/models/unload`,
+ nodeLabels: (id) => `/api/nodes/${id}/labels`,
+ nodeLabelKey: (id, key) => `/api/nodes/${id}/labels/${encodeURIComponent(key)}`,
+ nodesScheduling: '/api/nodes/scheduling',
+ nodesSchedulingModel: (model) => `/api/nodes/scheduling/${encodeURIComponent(model)}`,
},
}
diff --git a/core/http/routes/nodes.go b/core/http/routes/nodes.go
index 5ebd87ce6..2f23e3036 100644
--- a/core/http/routes/nodes.go
+++ b/core/http/routes/nodes.go
@@ -61,6 +61,13 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin := e.Group("/api/nodes", readyMw, adminMw)
admin.GET("", localai.ListNodesEndpoint(registry))
+
+ // Model scheduling (registered before /:id to avoid route conflicts)
+ admin.GET("/scheduling", localai.ListSchedulingEndpoint(registry))
+ admin.GET("/scheduling/:model", localai.GetSchedulingEndpoint(registry))
+ admin.POST("/scheduling", localai.SetSchedulingEndpoint(registry))
+ admin.DELETE("/scheduling/:model", localai.DeleteSchedulingEndpoint(registry))
+
admin.GET("/:id", localai.GetNodeEndpoint(registry))
admin.GET("/:id/models", localai.GetNodeModelsEndpoint(registry))
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
@@ -80,6 +87,12 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin.GET("/:id/backend-logs", localai.NodeBackendLogsListEndpoint(registry, registrationToken))
admin.GET("/:id/backend-logs/:modelId", localai.NodeBackendLogsLinesEndpoint(registry, registrationToken))
+ // Label management
+ admin.GET("/:id/labels", localai.GetNodeLabelsEndpoint(registry))
+ admin.PUT("/:id/labels", localai.SetNodeLabelsEndpoint(registry))
+ admin.PATCH("/:id/labels", localai.MergeNodeLabelsEndpoint(registry))
+ admin.DELETE("/:id/labels/:key", localai.DeleteNodeLabelEndpoint(registry))
+
// WebSocket proxy for real-time log streaming from workers
e.GET("/ws/nodes/:id/backend-logs/:modelId", localai.NodeBackendLogsWSEndpoint(registry, registrationToken), readyMw, adminMw)
}
diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go
index 3b3cbf137..53d78ccb3 100644
--- a/core/services/nodes/interfaces.go
+++ b/core/services/nodes/interfaces.go
@@ -21,6 +21,12 @@ type ModelRouter interface {
FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (*NodeModel, error)
FindLRUModel(ctx context.Context, nodeID string) (*NodeModel, error)
Get(ctx context.Context, nodeID string) (*BackendNode, error)
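+ // Scheduling-config and label lookups used for selector-based model placement.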
+ GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error)
+ FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error)
+ FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error)
+ FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
+ FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
+ GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
}
// NodeHealthStore is used by HealthMonitor for node status management.
diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go
index 5c98fb0de..ce0165e60 100644
--- a/core/services/nodes/model_router_test.go
+++ b/core/services/nodes/model_router_test.go
@@ -72,6 +72,24 @@ func (f *fakeModelRouterForSmartRouter) Get(_ context.Context, nodeID string) (*
}
return nil, nil
}
+func (f *fakeModelRouterForSmartRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
+ return nil, nil
+}
+func (f *fakeModelRouterForSmartRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
+ return nil, nil
+}
+func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
+ return nil, nil
+}
+func (f *fakeModelRouterForSmartRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
+ return nil, nil
+}
+func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
+ return nil, nil
+}
+func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
+ return nil, nil
+}
// Compile-time check
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
diff --git a/core/services/nodes/reconciler.go b/core/services/nodes/reconciler.go
new file mode 100644
index 000000000..92ba76edc
--- /dev/null
+++ b/core/services/nodes/reconciler.go
@@ -0,0 +1,236 @@
+package nodes
+
+import (
+ "context"
+ "encoding/json"
+ "time"
+
+ "github.com/mudler/LocalAI/core/services/advisorylock"
+ "github.com/mudler/xlog"
+ "gorm.io/gorm"
+)
+
+// ReplicaReconciler periodically ensures model replica counts match their
+// scheduling configs. It scales up replicas when below MinReplicas or when
+// all replicas are busy (up to MaxReplicas), and scales down idle replicas
+// above MinReplicas.
+//
+// Only processes models with auto-scaling enabled (MinReplicas > 0 or MaxReplicas > 0).
+type ReplicaReconciler struct {
+ registry *NodeRegistry
+ scheduler ModelScheduler // interface for scheduling new models
+ unloader NodeCommandSender
+ db *gorm.DB
+ interval time.Duration
+ scaleDownDelay time.Duration
+}
+
+// ModelScheduler abstracts the scheduling logic needed by the reconciler.
+// SmartRouter implements this interface.
+type ModelScheduler interface {
+ // ScheduleAndLoadModel picks a node (optionally from candidateNodeIDs),
+ // installs the backend, and loads the model. Returns the node used.
+ ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error)
+}
+
+// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler.
+type ReplicaReconcilerOptions struct {
+ Registry *NodeRegistry
+ Scheduler ModelScheduler
+ Unloader NodeCommandSender
+ DB *gorm.DB
+ Interval time.Duration // default 30s
+ ScaleDownDelay time.Duration // default 5m
+}
+
+// NewReplicaReconciler creates a new ReplicaReconciler.
+func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
+ interval := opts.Interval
+ if interval == 0 {
+ interval = 30 * time.Second
+ }
+ scaleDownDelay := opts.ScaleDownDelay
+ if scaleDownDelay == 0 {
+ scaleDownDelay = 5 * time.Minute
+ }
+ return &ReplicaReconciler{
+ registry: opts.Registry,
+ scheduler: opts.Scheduler,
+ unloader: opts.Unloader,
+ db: opts.DB,
+ interval: interval,
+ scaleDownDelay: scaleDownDelay,
+ }
+}
+
+// Run starts the reconciliation loop. It blocks until ctx is cancelled.
+func (rc *ReplicaReconciler) Run(ctx context.Context) {
+ ticker := time.NewTicker(rc.interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ rc.reconcileOnce(ctx)
+ }
+ }
+}
+
+// reconcileOnce performs a single reconciliation pass.
+// Uses an advisory lock so only one frontend instance reconciles at a time.
+func (rc *ReplicaReconciler) reconcileOnce(ctx context.Context) {
+ if rc.db != nil {
+ lockKey := advisorylock.KeyFromString("replica-reconciler")
+ _ = advisorylock.WithLockCtx(ctx, rc.db, lockKey, func() error {
+ rc.reconcile(ctx)
+ return nil
+ })
+ } else {
+ rc.reconcile(ctx)
+ }
+}
+
+func (rc *ReplicaReconciler) reconcile(ctx context.Context) {
+ configs, err := rc.registry.ListAutoScalingConfigs(ctx)
+ if err != nil {
+ xlog.Warn("Reconciler: failed to list auto-scaling configs", "error", err)
+ return
+ }
+
+ for _, cfg := range configs {
+ if err := ctx.Err(); err != nil {
+ return // context cancelled
+ }
+ rc.reconcileModel(ctx, cfg)
+ }
+}
+
+func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) {
+ current, err := rc.registry.CountLoadedReplicas(ctx, cfg.ModelName)
+ if err != nil {
+ xlog.Warn("Reconciler: failed to count replicas", "model", cfg.ModelName, "error", err)
+ return
+ }
+
+ // 1. Ensure minimum replicas
+ if cfg.MinReplicas > 0 && int(current) < cfg.MinReplicas {
+ needed := cfg.MinReplicas - int(current)
+ xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
+ "current", current, "min", cfg.MinReplicas, "adding", needed)
+ rc.scaleUp(ctx, cfg, needed)
+ return
+ }
+
+ // 2. Auto-scale up if all replicas are busy
+ if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
+ if rc.allReplicasBusy(ctx, cfg.ModelName) {
+ xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
+ "current", current)
+ rc.scaleUp(ctx, cfg, 1)
+ }
+ }
+
+ // 3. Scale down idle replicas above minimum
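+ // Keep a floor of at least one replica even when MinReplicas is 0 (max-only configs).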
+ floor := cfg.MinReplicas
+ if floor < 1 {
+ floor = 1
+ }
+ if int(current) > floor {
+ rc.scaleDownIdle(ctx, cfg, int(current), floor)
+ }
+}
+
+// scaleUp schedules additional replicas of the model.
+func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
+ if rc.scheduler == nil {
+ xlog.Warn("Reconciler: no scheduler available, cannot scale up")
+ return
+ }
+
+ // Determine candidate nodes from selector
+ var candidateNodeIDs []string
+ if cfg.NodeSelector != "" {
+ selector := parseSelector(cfg.NodeSelector)
+ if len(selector) > 0 {
+ candidates, err := rc.registry.FindNodesBySelector(ctx, selector)
+ if err != nil || len(candidates) == 0 {
+ xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName,
+ "selector", cfg.NodeSelector)
+ return
+ }
+ candidateNodeIDs = make([]string, len(candidates))
+ for i, n := range candidates {
+ candidateNodeIDs[i] = n.ID
+ }
+ }
+ }
+
+ for i := 0; i < count; i++ {
+ node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs)
+ if err != nil {
+ xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
+ "attempt", i+1, "error", err)
+ return // stop trying on first failure
+ }
+ xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
+ }
+}
+
+// scaleDownIdle removes idle replicas above the floor.
+func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedulingConfig, current, floor int) {
+ if rc.unloader == nil {
+ return
+ }
+
+ // Find idle replicas that have been unused for longer than scaleDownDelay
+ cutoff := time.Now().Add(-rc.scaleDownDelay)
+ var idleModels []NodeModel
+ rc.registry.db.WithContext(ctx).
+ Where("model_name = ? AND state = ? AND in_flight = 0 AND last_used < ?",
+ cfg.ModelName, "loaded", cutoff).
+ Order("last_used ASC").
+ Find(&idleModels)
+
+ toRemove := current - floor
+ removed := 0
+ for _, nm := range idleModels {
+ if removed >= toRemove {
+ break
+ }
+ // Remove the registry record first so the replica stops being selected for routing.
+ if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName); err != nil {
+ xlog.Warn("Reconciler: failed to remove model record", "error", err)
+ continue
+ }
+ // Unload from worker
+ if err := rc.unloader.UnloadModelOnNode(nm.NodeID, nm.ModelName); err != nil {
+ xlog.Warn("Reconciler: unload failed (model already removed from registry)", "error", err)
+ }
+ xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID)
+ removed++
+ }
+}
+
+// allReplicasBusy returns true if all loaded replicas of a model have in-flight requests.
+func (rc *ReplicaReconciler) allReplicasBusy(ctx context.Context, modelName string) bool {
+ var idleCount int64
+ rc.registry.db.WithContext(ctx).Model(&NodeModel{}).
+ Where("model_name = ? AND state = ? AND in_flight = 0", modelName, "loaded").
+ Count(&idleCount)
+ return idleCount == 0
+}
+
+// parseSelector decodes a JSON node selector string into a map.
+func parseSelector(selectorJSON string) map[string]string {
+ if selectorJSON == "" {
+ return nil
+ }
+ var selector map[string]string
+ if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil {
+ xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
+ return nil
+ }
+ return selector
+}
diff --git a/core/services/nodes/reconciler_test.go b/core/services/nodes/reconciler_test.go
new file mode 100644
index 000000000..e95f8bcea
--- /dev/null
+++ b/core/services/nodes/reconciler_test.go
@@ -0,0 +1,241 @@
+package nodes
+
+import (
+ "context"
+ "runtime"
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+
+ "github.com/mudler/LocalAI/core/services/testutil"
+ "gorm.io/gorm"
+)
+
+// ---------------------------------------------------------------------------
+// Fake ModelScheduler
+// ---------------------------------------------------------------------------
+
+type fakeScheduler struct {
+ scheduleNode *BackendNode
+ scheduleErr error
+ scheduleCalls []scheduleCall
+}
+
+type scheduleCall struct {
+ modelName string
+ candidateIDs []string
+}
+
+func (f *fakeScheduler) ScheduleAndLoadModel(_ context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
+ f.scheduleCalls = append(f.scheduleCalls, scheduleCall{modelName, candidateNodeIDs})
+ return f.scheduleNode, f.scheduleErr
+}
+
+var _ = Describe("ReplicaReconciler", func() {
+ var (
+ db *gorm.DB
+ registry *NodeRegistry
+ )
+
+ BeforeEach(func() {
+ if runtime.GOOS == "darwin" {
+ Skip("testcontainers requires Docker, not available on macOS CI")
+ }
+ db = testutil.SetupTestDB()
+ var err error
+ registry, err = NewNodeRegistry(db)
+ Expect(err).ToNot(HaveOccurred())
+ })
+
+ // Helper to register a healthy node.
+ registerNode := func(name, address string) *BackendNode {
+ node := &BackendNode{
+ Name: name,
+ NodeType: NodeTypeBackend,
+ Address: address,
+ }
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+ return node
+ }
+
+ // Helper to set up a scheduling config.
+ setSchedulingConfig := func(modelName string, minReplicas, maxReplicas int, nodeSelector string) {
+ cfg := &ModelSchedulingConfig{
+ ModelName: modelName,
+ MinReplicas: minReplicas,
+ MaxReplicas: maxReplicas,
+ NodeSelector: nodeSelector,
+ }
+ Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
+ }
+
+ Context("model below min_replicas", func() {
+ It("scales up to min_replicas", func() {
+ node := registerNode("node-1", "10.0.0.1:50051")
+ setSchedulingConfig("model-a", 2, 4, "")
+
+ scheduler := &fakeScheduler{
+ scheduleNode: node,
+ }
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ })
+
+ // No replicas loaded — should schedule 2
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(2))
+ Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-a"))
+ Expect(scheduler.scheduleCalls[1].modelName).To(Equal("model-a"))
+ })
+ })
+
+ Context("all replicas busy and below max_replicas", func() {
+ It("scales up by 1", func() {
+ node := registerNode("node-busy", "10.0.0.2:50051")
+ setSchedulingConfig("model-b", 1, 4, "")
+
+ // Load 2 replicas, both busy (in_flight > 0)
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", "loaded", "addr1", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b")).To(Succeed())
+
+ node2 := registerNode("node-busy-2", "10.0.0.3:50051")
+ Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", "loaded", "addr2", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b")).To(Succeed())
+
+ scheduler := &fakeScheduler{
+ scheduleNode: node,
+ }
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(1))
+ Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-b"))
+ })
+ })
+
+ Context("all replicas busy and at max_replicas", func() {
+ It("does not scale up", func() {
+ node := registerNode("node-max", "10.0.0.4:50051")
+ setSchedulingConfig("model-c", 1, 2, "")
+
+ // Load 2 replicas (at max), both busy
+ Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", "loaded", "addr1", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c")).To(Succeed())
+
+ node2 := registerNode("node-max-2", "10.0.0.5:50051")
+ Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", "loaded", "addr2", 0)).To(Succeed())
+ Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c")).To(Succeed())
+
+ scheduler := &fakeScheduler{
+ scheduleNode: node,
+ }
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(BeEmpty())
+ })
+ })
+
+ Context("idle replicas above min_replicas", func() {
+ It("scales down after idle delay", func() {
+ node1 := registerNode("node-idle-1", "10.0.0.6:50051")
+ node2 := registerNode("node-idle-2", "10.0.0.7:50051")
+ node3 := registerNode("node-idle-3", "10.0.0.8:50051")
+ setSchedulingConfig("model-d", 1, 4, "")
+
+ // Load 3 replicas, all idle with last_used in the past
+ pastTime := time.Now().Add(-10 * time.Minute)
+ for _, n := range []*BackendNode{node1, node2, node3} {
+ Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", "loaded", "", 0)).To(Succeed())
+ // Set last_used to past time to trigger scale-down
+ db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-d").
+ Update("last_used", pastTime)
+ }
+
+ unloader := &fakeUnloader{}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Unloader: unloader,
+ DB: db,
+ ScaleDownDelay: 1 * time.Minute, // short delay for test
+ })
+
+ reconciler.reconcile(context.Background())
+
+ // Should scale down 2 replicas (3 - floor of 1)
+ Expect(unloader.unloadCalls).To(HaveLen(2))
+ })
+ })
+
+ Context("idle replicas at min_replicas", func() {
+ It("does not scale down", func() {
+ node1 := registerNode("node-keep-1", "10.0.0.9:50051")
+ node2 := registerNode("node-keep-2", "10.0.0.10:50051")
+ setSchedulingConfig("model-e", 2, 4, "")
+
+ // Load exactly 2 replicas (at min), both idle with past last_used
+ pastTime := time.Now().Add(-10 * time.Minute)
+ for _, n := range []*BackendNode{node1, node2} {
+ Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", "loaded", "", 0)).To(Succeed())
+ db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-e").
+ Update("last_used", pastTime)
+ }
+
+ unloader := &fakeUnloader{}
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Unloader: unloader,
+ DB: db,
+ ScaleDownDelay: 1 * time.Minute,
+ })
+
+ reconciler.reconcile(context.Background())
+
+ Expect(unloader.unloadCalls).To(BeEmpty())
+ })
+ })
+
+ Context("model with node_selector", func() {
+ It("passes candidate node IDs to scheduler", func() {
+ node1 := registerNode("gpu-node", "10.0.0.11:50051")
+ node2 := registerNode("cpu-node", "10.0.0.12:50051")
+
+ // Add labels — only node1 matches the selector
+ Expect(registry.SetNodeLabel(context.Background(), node1.ID, "gpu.vendor", "nvidia")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), node2.ID, "gpu.vendor", "none")).To(Succeed())
+
+ setSchedulingConfig("model-f", 1, 2, `{"gpu.vendor":"nvidia"}`)
+
+ scheduler := &fakeScheduler{
+ scheduleNode: node1,
+ }
+ reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
+ Registry: registry,
+ Scheduler: scheduler,
+ DB: db,
+ })
+
+ // No replicas loaded — should schedule 1 with candidate node IDs
+ reconciler.reconcile(context.Background())
+
+ Expect(scheduler.scheduleCalls).To(HaveLen(1))
+ Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-f"))
+ Expect(scheduler.scheduleCalls[0].candidateIDs).To(ContainElement(node1.ID))
+ Expect(scheduler.scheduleCalls[0].candidateIDs).ToNot(ContainElement(node2.ID))
+ })
+ })
+})
diff --git a/core/services/nodes/registry.go b/core/services/nodes/registry.go
index c7e6ccd3d..5d7404f5f 100644
--- a/core/services/nodes/registry.go
+++ b/core/services/nodes/registry.go
@@ -68,6 +68,39 @@ type NodeModel struct {
UpdatedAt time.Time `json:"updated_at"`
}
+// NodeLabel is a key-value label on a node (like K8s labels).
+type NodeLabel struct {
+ ID string `gorm:"primaryKey;size:36" json:"id"`
+ NodeID string `gorm:"uniqueIndex:idx_node_label;size:36" json:"node_id"`
+ Key string `gorm:"uniqueIndex:idx_node_label;size:128" json:"key"`
+ Value string `gorm:"size:255" json:"value"`
+}
+
+// ModelSchedulingConfig defines how a model should be scheduled across the cluster.
+// All fields are optional:
+// - NodeSelector only → constrain nodes, single replica
+// - MinReplicas/MaxReplicas only → auto-scale on any node
+// - Both → auto-scale on matching nodes
+// - Neither → no-op (default behavior)
+//
+// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0.
+type ModelSchedulingConfig struct {
+ ID string `gorm:"primaryKey;size:36" json:"id"`
+ ModelName string `gorm:"uniqueIndex;size:255" json:"model_name"`
+ NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
+ MinReplicas int `gorm:"default:0" json:"min_replicas"`
+ MaxReplicas int `gorm:"default:0" json:"max_replicas"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+// NodeWithExtras extends BackendNode with computed fields for list views.
+type NodeWithExtras struct {
+ BackendNode
+ ModelCount int `json:"model_count"`
+ Labels map[string]string `json:"labels,omitempty"`
+}
+
// NodeRegistry manages backend node registration and lookup in PostgreSQL.
type NodeRegistry struct {
db *gorm.DB
@@ -78,7 +111,7 @@ type NodeRegistry struct {
// when multiple instances (frontend + workers) start at the same time.
func NewNodeRegistry(db *gorm.DB) (*NodeRegistry, error) {
if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error {
- return db.AutoMigrate(&BackendNode{}, &NodeModel{})
+ return db.AutoMigrate(&BackendNode{}, &NodeModel{}, &NodeLabel{}, &ModelSchedulingConfig{})
}); err != nil {
return nil, fmt.Errorf("migrating node tables: %w", err)
}
@@ -207,18 +240,19 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
Group("node_id")
- // Try idle nodes with enough VRAM first
+ // Try idle nodes with enough VRAM first, prefer the one with most free VRAM
var node BackendNode
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
+ Order("available_vram DESC").
First(&node).Error
if err == nil {
return &node, nil
}
- // Fall back to least-loaded nodes with enough VRAM
+ // Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes).
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
- Order("COALESCE(load.total_inflight, 0) ASC").
+ Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err)
@@ -407,9 +441,12 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
var node BackendNode
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ // Order by in_flight ASC (least busy replica), then by available_vram DESC
+ // (prefer nodes with more free VRAM to spread load across the cluster).
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
- Where("model_name = ? AND state = ?", modelName, "loaded").
- Order("in_flight ASC").
+ Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
+ Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded").
+ Order("node_models.in_flight ASC, backend_nodes.available_vram DESC").
First(&nm).Error; err != nil {
return err
}
@@ -463,7 +500,7 @@ func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, e
Group("node_id")
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
- Order("COALESCE(load.total_inflight, 0) ASC").
+ Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("finding least loaded node: %w", err)
@@ -482,6 +519,7 @@ func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) {
Where("state = ?", "loaded").
Group("node_id")
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, loadedModels).
+ Order("available_vram DESC").
First(&node).Error
if err != nil {
return nil, err
@@ -578,3 +616,287 @@ func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (
}
return &nm, nil
}
+
+// --- NodeLabel operations ---
+
+// SetNodeLabel upserts a single label on a node.
+func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error {
+ label := NodeLabel{
+ ID: uuid.New().String(),
+ NodeID: nodeID,
+ Key: key,
+ Value: value,
+ }
+ return r.db.WithContext(ctx).
+ Clauses(clause.OnConflict{
+ Columns: []clause.Column{{Name: "node_id"}, {Name: "key"}},
+ DoUpdates: clause.AssignmentColumns([]string{"value"}),
+ }).
+ Create(&label).Error
+}
+
+// SetNodeLabels replaces all labels for a node with the given map.
+func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error {
+ return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil {
+ return err
+ }
+ for k, v := range labels {
+ label := NodeLabel{ID: uuid.New().String(), NodeID: nodeID, Key: k, Value: v}
+ if err := tx.Create(&label).Error; err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+}
+
+// RemoveNodeLabel removes a single label from a node.
+func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error {
+ return r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error
+}
+
+// GetNodeLabels returns all labels for a node.
+func (r *NodeRegistry) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) {
+ var labels []NodeLabel
+ err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Find(&labels).Error
+ return labels, err
+}
+
+// GetAllNodeLabelsMap returns all labels grouped by node ID.
+func (r *NodeRegistry) GetAllNodeLabelsMap(ctx context.Context) (map[string]map[string]string, error) {
+ var labels []NodeLabel
+ if err := r.db.WithContext(ctx).Find(&labels).Error; err != nil {
+ return nil, err
+ }
+ result := make(map[string]map[string]string)
+ for _, l := range labels {
+ if result[l.NodeID] == nil {
+ result[l.NodeID] = make(map[string]string)
+ }
+ result[l.NodeID][l.Key] = l.Value
+ }
+ return result, nil
+}
+
+// --- Selector-based queries ---
+
+// FindNodesBySelector returns healthy backend nodes matching ALL key-value pairs in the selector.
+func (r *NodeRegistry) FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) {
+ if len(selector) == 0 {
+ // Empty selector matches all healthy backend nodes
+ var nodes []BackendNode
+ err := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend).Find(&nodes).Error
+ return nodes, err
+ }
+
+ db := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
+ for k, v := range selector {
+ db = db.Where("EXISTS (SELECT 1 FROM node_labels WHERE node_labels.node_id = backend_nodes.id AND node_labels.key = ? AND node_labels.value = ?)", k, v)
+ }
+
+ var nodes []BackendNode
+ err := db.Find(&nodes).Error
+ return nodes, err
+}
+
+// FindNodeWithVRAMFromSet is like FindNodeWithVRAM but restricted to the given node IDs.
+func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) {
+ db := r.db.WithContext(ctx)
+
+ loadedModels := db.Model(&NodeModel{}).
+ Select("node_id").
+ Where("state = ?", "loaded").
+ Group("node_id")
+
+ subquery := db.Model(&NodeModel{}).
+ Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
+ Group("node_id")
+
+ // Try idle nodes with enough VRAM first, prefer the one with most free VRAM
+ var node BackendNode
+ err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
+ Order("available_vram DESC").
+ First(&node).Error
+ if err == nil {
+ return &node, nil
+ }
+
+ // Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
+ err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
+ Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
+ Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
+ First(&node).Error
+ if err != nil {
+ return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err)
+ }
+ return &node, nil
+}
+
+// FindIdleNodeFromSet is like FindIdleNode but restricted to the given node IDs.
+func (r *NodeRegistry) FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
+ db := r.db.WithContext(ctx)
+
+ var node BackendNode
+ loadedModels := db.Model(&NodeModel{}).
+ Select("node_id").
+ Where("state = ?", "loaded").
+ Group("node_id")
+ err := db.Where("status = ? AND node_type = ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, loadedModels, nodeIDs).
+ Order("available_vram DESC").
+ First(&node).Error
+ if err != nil {
+ return nil, err
+ }
+ return &node, nil
+}
+
+// FindLeastLoadedNodeFromSet is like FindLeastLoadedNode but restricted to the given node IDs.
+func (r *NodeRegistry) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
+ db := r.db.WithContext(ctx)
+
+ var node BackendNode
+ query := db.Where("status = ? AND node_type = ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, nodeIDs)
+ // Order by total in-flight across all models on the node
+ subquery := db.Model(&NodeModel{}).
+ Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
+ Group("node_id")
+
+ err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
+ Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
+ First(&node).Error
+ if err != nil {
+ return nil, fmt.Errorf("finding least loaded node in set: %w", err)
+ }
+ return &node, nil
+}
+
+// --- ModelSchedulingConfig operations ---
+
+// SetModelScheduling creates or updates a scheduling config for a model.
+func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSchedulingConfig) error {
+ if config.ID == "" {
+ config.ID = uuid.New().String()
+ }
+ return r.db.WithContext(ctx).
+ Clauses(clause.OnConflict{
+ Columns: []clause.Column{{Name: "model_name"}},
+ DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
+ }).
+ Create(config).Error
+}
+
+// GetModelScheduling returns the scheduling config for a model, or nil if none exists.
+func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) {
+ var config ModelSchedulingConfig
+ err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&config).Error
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ return &config, nil
+}
+
+// ListModelSchedulings returns all scheduling configs.
+func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedulingConfig, error) {
+ var configs []ModelSchedulingConfig
+ err := r.db.WithContext(ctx).Find(&configs).Error
+ return configs, err
+}
+
+// ListAutoScalingConfigs returns scheduling configs where auto-scaling is enabled.
+func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) {
+ var configs []ModelSchedulingConfig
+ err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0").Find(&configs).Error
+ return configs, err
+}
+
+// DeleteModelScheduling removes a scheduling config by model name.
+func (r *NodeRegistry) DeleteModelScheduling(ctx context.Context, modelName string) error {
+ return r.db.WithContext(ctx).Where("model_name = ?", modelName).Delete(&ModelSchedulingConfig{}).Error
+}
+
+// CountLoadedReplicas returns the number of loaded replicas for a model.
+func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string) (int64, error) {
+ var count int64
+ err := r.db.WithContext(ctx).Model(&NodeModel{}).Where("model_name = ? AND state = ?", modelName, "loaded").Count(&count).Error
+ return count, err
+}
+
+// --- Composite queries ---
+
+// ListWithExtras returns all nodes with model counts and labels.
+func (r *NodeRegistry) ListWithExtras(ctx context.Context) ([]NodeWithExtras, error) {
+ // Get all nodes
+ var nodes []BackendNode
+ if err := r.db.WithContext(ctx).Find(&nodes).Error; err != nil {
+ return nil, err
+ }
+
+ // Get model counts per node
+ type modelCount struct {
+ NodeID string
+ Count int
+ }
+ var counts []modelCount
+ if err := r.db.WithContext(ctx).Model(&NodeModel{}).
+ Select("node_id, COUNT(*) as count").
+ Where("state = ?", "loaded").
+ Group("node_id").
+ Find(&counts).Error; err != nil {
+ xlog.Warn("ListWithExtras: failed to get model counts", "error", err)
+ }
+
+ countMap := make(map[string]int)
+ for _, c := range counts {
+ countMap[c.NodeID] = c.Count
+ }
+
+ // Get all labels
+ labelsMap, err := r.GetAllNodeLabelsMap(ctx)
+ if err != nil {
+ xlog.Warn("ListWithExtras: failed to get labels", "error", err)
+ }
+
+ // Build result
+ result := make([]NodeWithExtras, len(nodes))
+ for i, n := range nodes {
+ result[i] = NodeWithExtras{
+ BackendNode: n,
+ ModelCount: countMap[n.ID],
+ Labels: labelsMap[n.ID],
+ }
+ }
+ return result, nil
+}
+
+// ApplyAutoLabels sets automatic labels based on node hardware info.
+func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node *BackendNode) {
+ if node.GPUVendor != "" {
+ _ = r.SetNodeLabel(ctx, nodeID, "gpu.vendor", node.GPUVendor)
+ }
+ if node.TotalVRAM > 0 {
+ gb := node.TotalVRAM / (1024 * 1024 * 1024)
+ var bucket string
+ switch {
+ case gb >= 80:
+ bucket = "80GB+"
+ case gb >= 48:
+ bucket = "48GB"
+ case gb >= 24:
+ bucket = "24GB"
+ case gb >= 16:
+ bucket = "16GB"
+ case gb >= 8:
+ bucket = "8GB"
+ default:
+ bucket = fmt.Sprintf("%dGB", gb)
+ }
+ _ = r.SetNodeLabel(ctx, nodeID, "gpu.vram", bucket)
+ }
+ if node.Name != "" {
+ _ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name)
+ }
+}
diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go
index a7dbd6a54..7b0a747b1 100644
--- a/core/services/nodes/registry_test.go
+++ b/core/services/nodes/registry_test.go
@@ -305,6 +305,252 @@ var _ = Describe("NodeRegistry", func() {
})
})
+ Describe("NodeLabel CRUD", func() {
+ It("sets and retrieves labels for a node", func() {
+ node := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())
+
+ labels, err := registry.GetNodeLabels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(labels).To(HaveLen(2))
+
+ labelMap := make(map[string]string)
+ for _, l := range labels {
+ labelMap[l.Key] = l.Value
+ }
+ Expect(labelMap["env"]).To(Equal("prod"))
+ Expect(labelMap["region"]).To(Equal("us-east"))
+ })
+
+ It("overwrites existing label with same key", func() {
+ node := makeNode("label-overwrite", "10.0.0.71:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "dev")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
+
+ labels, err := registry.GetNodeLabels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(labels).To(HaveLen(1))
+ Expect(labels[0].Value).To(Equal("prod"))
+ })
+
+ It("removes a single label by key", func() {
+ node := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())
+
+ Expect(registry.RemoveNodeLabel(context.Background(), node.ID, "env")).To(Succeed())
+
+ labels, err := registry.GetNodeLabels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(labels).To(HaveLen(1))
+ Expect(labels[0].Key).To(Equal("region"))
+ })
+
+ It("SetNodeLabels replaces all labels", func() {
+ node := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), node, true)).To(Succeed())
+
+ Expect(registry.SetNodeLabel(context.Background(), node.ID, "old-key", "old-val")).To(Succeed())
+
+ newLabels := map[string]string{"new-a": "val-a", "new-b": "val-b"}
+ Expect(registry.SetNodeLabels(context.Background(), node.ID, newLabels)).To(Succeed())
+
+ labels, err := registry.GetNodeLabels(context.Background(), node.ID)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(labels).To(HaveLen(2))
+
+ labelMap := make(map[string]string)
+ for _, l := range labels {
+ labelMap[l.Key] = l.Value
+ }
+ Expect(labelMap).To(Equal(newLabels))
+ })
+ })
+
+ Describe("FindNodesBySelector", func() {
+ It("returns nodes matching all labels in selector", func() {
+ n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000)
+ n2 := makeNode("sel-nomatch", "10.0.0.81:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
+
+ Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed())
+
+ nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(nodes).To(HaveLen(1))
+ Expect(nodes[0].Name).To(Equal("sel-match"))
+ })
+
+ It("returns empty when no nodes match", func() {
+ n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n, true)).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed())
+
+ nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(nodes).To(BeEmpty())
+ })
+
+ It("ignores unhealthy nodes", func() {
+ n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n, true)).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
+ Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed())
+
+ nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(nodes).To(BeEmpty())
+ })
+
+ It("matches nodes with more labels than selector requires", func() {
+ n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n, true)).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed())
+ Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed())
+
+ nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(nodes).To(HaveLen(1))
+ Expect(nodes[0].Name).To(Equal("sel-superset"))
+ })
+
+ It("returns all healthy nodes for empty selector", func() {
+ n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000)
+ n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
+
+ nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{})
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(nodes)).To(BeNumerically(">=", 2))
+ })
+ })
+
+ Describe("ModelSchedulingConfig CRUD", func() {
+ It("creates and retrieves a scheduling config", func() {
+ config := &ModelSchedulingConfig{
+ ModelName: "llama-7b",
+ NodeSelector: `{"gpu.vendor":"nvidia"}`,
+ MinReplicas: 1,
+ MaxReplicas: 3,
+ }
+ Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())
+ Expect(config.ID).ToNot(BeEmpty())
+
+ fetched, err := registry.GetModelScheduling(context.Background(), "llama-7b")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(fetched).ToNot(BeNil())
+ Expect(fetched.ModelName).To(Equal("llama-7b"))
+ Expect(fetched.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`))
+ Expect(fetched.MinReplicas).To(Equal(1))
+ Expect(fetched.MaxReplicas).To(Equal(3))
+ })
+
+ It("updates existing config via SetModelScheduling", func() {
+ config := &ModelSchedulingConfig{
+ ModelName: "update-model",
+ MinReplicas: 1,
+ MaxReplicas: 2,
+ }
+ Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())
+
+ config2 := &ModelSchedulingConfig{
+ ModelName: "update-model",
+ MinReplicas: 2,
+ MaxReplicas: 5,
+ }
+ Expect(registry.SetModelScheduling(context.Background(), config2)).To(Succeed())
+
+ fetched, err := registry.GetModelScheduling(context.Background(), "update-model")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(fetched.MinReplicas).To(Equal(2))
+ Expect(fetched.MaxReplicas).To(Equal(5))
+ })
+
+ It("lists all configs", func() {
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())
+
+ configs, err := registry.ListModelSchedulings(context.Background())
+ Expect(err).ToNot(HaveOccurred())
+ Expect(len(configs)).To(BeNumerically(">=", 2))
+ })
+
+ It("lists only auto-scaling configs", func() {
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed())
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed())
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed())
+
+ configs, err := registry.ListAutoScalingConfigs(context.Background())
+ Expect(err).ToNot(HaveOccurred())
+
+ names := make([]string, len(configs))
+ for i, c := range configs {
+ names[i] = c.ModelName
+ }
+ Expect(names).To(ContainElement("auto-a"))
+ Expect(names).To(ContainElement("auto-b"))
+ Expect(names).ToNot(ContainElement("no-auto"))
+ })
+
+ It("deletes a config", func() {
+ Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed())
+
+ Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed())
+
+ fetched, err := registry.GetModelScheduling(context.Background(), "delete-me")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(fetched).To(BeNil())
+ })
+
+ It("returns nil for non-existent model", func() {
+ fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(fetched).To(BeNil())
+ })
+ })
+
+ Describe("CountLoadedReplicas", func() {
+ It("returns correct count of loaded replicas", func() {
+ n1 := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000)
+ n2 := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
+
+ Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", "loaded", "", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", "loaded", "", 0)).To(Succeed())
+
+ count, err := registry.CountLoadedReplicas(context.Background(), "counted-model")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(count).To(Equal(int64(2)))
+ })
+
+ It("excludes non-loaded states", func() {
+ n1 := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000)
+ n2 := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000)
+ Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
+ Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
+
+ Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", "loaded", "", 0)).To(Succeed())
+ Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", "loading", "", 0)).To(Succeed())
+
+ count, err := registry.CountLoadedReplicas(context.Background(), "state-model")
+ Expect(err).ToNot(HaveOccurred())
+ Expect(count).To(Equal(int64(1)))
+ })
+ })
+
Describe("DecrementInFlight", func() {
It("does not go below zero", func() {
node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000)
diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go
index c61a11d91..7d77b62bc 100644
--- a/core/services/nodes/router.go
+++ b/core/services/nodes/router.go
@@ -2,6 +2,7 @@ package nodes
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -69,6 +70,18 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
// Unloader returns the remote unloader adapter for external use.
func (r *SmartRouter) Unloader() NodeCommandSender { return r.unloader }
+// ScheduleAndLoadModel implements ModelScheduler for the reconciler.
+// It schedules a model on a suitable node (optionally from candidates) and loads it.
+func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
+ // Use scheduleNewModel with an empty backend type and nil model options; the
+ // reconciler doesn't know the backend type, which is determined by the model config.
+ // candidateNodeIDs is not forwarded here: scheduleNewModel re-derives the candidate
+ // set from the model's scheduling config (node selector), so the constraint still applies.
+ node, _, err := r.scheduleNewModel(ctx, "", modelName, nil)
+ if err != nil {
+ return nil, err
+ }
+ return node, nil
+}
+
// RouteResult contains the routing decision.
type RouteResult struct {
Node *BackendNode
@@ -110,18 +123,26 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
xlog.Warn("Backend not reachable for cached model, falling through to reload",
"node", node.Name, "model", modelName)
} else {
- // Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
- // and Release decrements it. No InFlightTrackingClient to avoid double-counting.
- r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
- grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
- return &RouteResult{
- Node: node,
- Client: grpcClient,
- Release: func() {
- r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
- closeClient(grpcClient)
- },
- }, nil
+ // Verify node still matches scheduling constraints
+ if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
+ r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
+ xlog.Info("Cached model on node that no longer matches selector, falling through",
+ "node", node.Name, "model", trackingKey)
+ // Fall through to step 2 (scheduleNewModel)
+ } else {
+ // Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
+ // and Release decrements it. No InFlightTrackingClient to avoid double-counting.
+ r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
+ grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
+ return &RouteResult{
+ Node: node,
+ Client: grpcClient,
+ Release: func() {
+ r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
+ closeClient(grpcClient)
+ },
+ }, nil
+ }
}
}
@@ -143,17 +164,25 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
xlog.Warn("Backend not reachable for cached model inside lock, proceeding to load",
"node", node.Name, "model", modelName)
} else {
- // Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
- r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
- grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
- return &RouteResult{
- Node: node,
- Client: grpcClient,
- Release: func() {
- r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
- closeClient(grpcClient)
- },
- }, nil
+ // Verify node still matches scheduling constraints
+ if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
+ r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
+ xlog.Info("Cached model on node that no longer matches selector, falling through",
+ "node", node.Name, "model", trackingKey)
+ // Fall through to scheduling below
+ } else {
+ // Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
+ r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
+ grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
+ return &RouteResult{
+ Node: node,
+ Client: grpcClient,
+ Release: func() {
+ r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
+ closeClient(grpcClient)
+ },
+ }, nil
+ }
}
}
@@ -223,6 +252,59 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
return loadModel()
}
+// parseSelectorJSON decodes a JSON node selector string into a map.
+func parseSelectorJSON(selectorJSON string) map[string]string {
+ if selectorJSON == "" {
+ return nil
+ }
+ var selector map[string]string
+ if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil {
+ xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
+ return nil
+ }
+ return selector
+}
+
+func extractNodeIDs(nodes []BackendNode) []string {
+ ids := make([]string, len(nodes))
+ for i, n := range nodes {
+ ids[i] = n.ID
+ }
+ return ids
+}
+
+// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
+// Returns true if no constraints exist or the node matches all selector labels.
+func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
+ sched, err := r.registry.GetModelScheduling(ctx, modelName)
+ if err != nil || sched == nil || sched.NodeSelector == "" {
+ return true // no constraints
+ }
+
+ selector := parseSelectorJSON(sched.NodeSelector)
+ if len(selector) == 0 {
+ return true
+ }
+
+ labels, err := r.registry.GetNodeLabels(ctx, node.ID)
+ if err != nil {
+ xlog.Warn("Failed to get node labels for selector check", "node", node.ID, "error", err)
+ return true // fail open
+ }
+
+ labelMap := make(map[string]string)
+ for _, l := range labels {
+ labelMap[l.Key] = l.Value
+ }
+
+ for k, v := range selector {
+ if labelMap[k] != v {
+ return false
+ }
+ }
+ return true
+}
+
// scheduleNewModel picks the best node for loading a new model.
// Strategy: VRAM-aware → idle-first → least-loaded.
// Sends backend.install via NATS so the chosen node has the right backend running.
@@ -233,12 +315,30 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
estimatedVRAM = r.estimateModelVRAM(ctx, modelOpts)
}
+ // Check for scheduling constraints (node selector)
+ sched, _ := r.registry.GetModelScheduling(ctx, modelID)
+ var candidateNodeIDs []string // nil = all nodes eligible
+
+ if sched != nil && sched.NodeSelector != "" {
+ selector := parseSelectorJSON(sched.NodeSelector)
+ if len(selector) > 0 {
+ candidates, err := r.registry.FindNodesBySelector(ctx, selector)
+ if err != nil || len(candidates) == 0 {
+ return nil, "", fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector)
+ }
+ candidateNodeIDs = extractNodeIDs(candidates)
+ }
+ }
+
var node *BackendNode
var err error
if estimatedVRAM > 0 {
- // 1. Prefer nodes with enough VRAM (idle-first, then least-loaded)
- node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
+ if candidateNodeIDs != nil {
+ node, err = r.registry.FindNodeWithVRAMFromSet(ctx, estimatedVRAM, candidateNodeIDs)
+ } else {
+ node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
+ }
if err != nil {
xlog.Warn("No nodes with enough VRAM, falling back to standard scheduling",
"required_vram", vram.FormatBytes(estimatedVRAM), "error", err)
@@ -246,11 +346,16 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
}
if node == nil {
- // 2. Prefer truly idle nodes (no loaded models, no in-flight)
- node, err = r.registry.FindIdleNode(ctx)
- if err != nil {
- // 3. Fall back to least-loaded node (can run an additional backend process)
- node, err = r.registry.FindLeastLoadedNode(ctx)
+ if candidateNodeIDs != nil {
+ node, err = r.registry.FindIdleNodeFromSet(ctx, candidateNodeIDs)
+ if err != nil {
+ node, err = r.registry.FindLeastLoadedNodeFromSet(ctx, candidateNodeIDs)
+ }
+ } else {
+ node, err = r.registry.FindIdleNode(ctx)
+ if err != nil {
+ node, err = r.registry.FindLeastLoadedNode(ctx)
+ }
}
}
@@ -610,7 +715,12 @@ func (r *SmartRouter) evictLRUAndFreeNode(ctx context.Context) (*BackendNode, er
// Lock the row so no other frontend can evict the same model
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
- Where("node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?", "loaded", StatusHealthy).
+ Where(`node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?
+ AND (
+ NOT EXISTS (SELECT 1 FROM model_scheduling_configs sc WHERE sc.model_name = node_models.model_name AND (sc.min_replicas > 0 OR sc.max_replicas > 0))
+ OR (SELECT COUNT(*) FROM node_models nm2 WHERE nm2.model_name = node_models.model_name AND nm2.state = 'loaded')
+ > COALESCE((SELECT sc2.min_replicas FROM model_scheduling_configs sc2 WHERE sc2.model_name = node_models.model_name), 1)
+ )`, "loaded", StatusHealthy).
Order("node_models.last_used ASC").
First(&lru).Error; err != nil {
return err
diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go
index 9f807b34b..8c53531fe 100644
--- a/core/services/nodes/router_test.go
+++ b/core/services/nodes/router_test.go
@@ -86,6 +86,26 @@ type fakeModelRouter struct {
getNode *BackendNode
getErr error
+ // GetModelScheduling returns
+ getModelScheduling *ModelSchedulingConfig
+ getModelSchedErr error
+
+ // FindNodesBySelector returns
+ findBySelectorNodes []BackendNode
+ findBySelectorErr error
+
+ // *FromSet variants
+ findVRAMFromSetNode *BackendNode
+ findVRAMFromSetErr error
+ findIdleFromSetNode *BackendNode
+ findIdleFromSetErr error
+ findLeastLoadedFromSetNode *BackendNode
+ findLeastLoadedFromSetErr error
+
+ // GetNodeLabels returns
+ getNodeLabels []NodeLabel
+ getNodeLabelsErr error
+
// Track calls for assertions
decrementCalls []string // "nodeID:modelName"
incrementCalls []string
@@ -146,6 +166,30 @@ func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error)
return f.getNode, f.getErr
}
+func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
+ return f.getModelScheduling, f.getModelSchedErr
+}
+
+func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
+ return f.findBySelectorNodes, f.findBySelectorErr
+}
+
+func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
+ return f.findVRAMFromSetNode, f.findVRAMFromSetErr
+}
+
+func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
+ return f.findIdleFromSetNode, f.findIdleFromSetErr
+}
+
+func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
+ return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
+}
+
+func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
+ return f.getNodeLabels, f.getNodeLabelsErr
+}
+
// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------
@@ -478,6 +522,135 @@ var _ = Describe("SmartRouter", func() {
})
})
+ Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
+ var (
+ reg *fakeModelRouter
+ backend *stubBackend
+ factory *stubClientFactory
+ unloader *fakeUnloader
+ )
+
+ BeforeEach(func() {
+ reg = &fakeModelRouter{
+ findAndLockErr: errors.New("not found"),
+ }
+ backend = &stubBackend{
+ loadResult: &pb.Result{Success: true},
+ }
+ factory = &stubClientFactory{client: backend}
+ unloader = &fakeUnloader{
+ installReply: &messaging.BackendInstallReply{
+ Success: true,
+ Address: "10.0.0.1:9001",
+ },
+ }
+ })
+
+ It("uses *FromSet methods when model has a node selector", func() {
+ gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
+ reg.getModelScheduling = &ModelSchedulingConfig{
+ ModelName: "selector-model",
+ NodeSelector: `{"gpu.vendor":"nvidia"}`,
+ }
+ reg.findBySelectorNodes = []BackendNode{*gpuNode}
+ reg.findIdleFromSetNode = gpuNode
+
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ })
+
+ result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(result).ToNot(BeNil())
+ Expect(result.Node.ID).To(Equal("gpu-1"))
+ })
+
+ It("returns error when no nodes match selector", func() {
+ reg.getModelScheduling = &ModelSchedulingConfig{
+ ModelName: "no-match-model",
+ NodeSelector: `{"gpu.vendor":"tpu"}`,
+ }
+ reg.findBySelectorNodes = nil
+ reg.findBySelectorErr = nil
+
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ })
+
+ _, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
+ Expect(err).To(HaveOccurred())
+ Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
+ })
+
+ It("uses regular methods when model has no scheduling config", func() {
+ reg.getModelScheduling = nil
+ idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
+ reg.findIdleNode = idleNode
+
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ })
+
+ result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(result).ToNot(BeNil())
+ Expect(result.Node.ID).To(Equal("regular-1"))
+ })
+ })
+
+ Describe("Route with selector validation on cached model (mock-based)", func() {
+ It("falls through when cached node no longer matches selector", func() {
+ cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
+ newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}
+
+ backend := &stubBackend{
+ healthResult: true,
+ loadResult: &pb.Result{Success: true},
+ }
+ factory := &stubClientFactory{client: backend}
+ unloader := &fakeUnloader{
+ installReply: &messaging.BackendInstallReply{
+ Success: true,
+ Address: "10.0.0.71:9001",
+ },
+ }
+
+ reg := &fakeModelRouter{
+ // Step 1: cached model found on old node
+ findAndLockNode: cachedNode,
+ findAndLockNM: &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
+ // Scheduling config with selector that old node does NOT match
+ getModelScheduling: &ModelSchedulingConfig{
+ ModelName: "sel-model",
+ NodeSelector: `{"gpu.vendor":"nvidia"}`,
+ },
+ // Old node has no labels matching the selector
+ getNodeLabels: []NodeLabel{
+ {NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
+ },
+ // For scheduling fallthrough: selector matches new node
+ findBySelectorNodes: []BackendNode{*newNode},
+ findIdleFromSetNode: newNode,
+ }
+
+ router := NewSmartRouter(reg, SmartRouterOptions{
+ Unloader: unloader,
+ ClientFactory: factory,
+ })
+
+ result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(result).ToNot(BeNil())
+ // Should have fallen through to the new node
+ Expect(result.Node.ID).To(Equal("n-new"))
+ // Old node should have had its in-flight decremented
+ Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
+ })
+ })
+
// -----------------------------------------------------------------------
// Integration tests using real PostgreSQL (existing)
// -----------------------------------------------------------------------
diff --git a/docs/content/features/distributed-mode.md b/docs/content/features/distributed-mode.md
index 0433de3ee..b942c244b 100644
--- a/docs/content/features/distributed-mode.md
+++ b/docs/content/features/distributed-mode.md
@@ -134,6 +134,47 @@ local-ai worker \
**HTTP file transfer:** Each worker also runs a small HTTP server for file transfer (model files, configs). By default it listens on the gRPC base port - 1 (e.g., if gRPC base is 50051, HTTP is on 50050). gRPC ports grow upward from the base port as additional models are loaded. Set `--advertise-http-addr` if the auto-detected address is not routable from the frontend.
{{% /notice %}}
+### Worker Address Configuration
+
+The simplest way to configure a worker's network address is with a single variable:
+
+| Variable | Description |
+|----------|-------------|
+| `LOCALAI_ADDR` | Reachable address of this worker (`host:port`). The port is used as the base port for gRPC backend processes, and `port - 1` is used for the HTTP file transfer server. |
+
+**Example:**
+```yaml
+environment:
+ LOCALAI_ADDR: "192.168.1.100:50051"
+ LOCALAI_NATS_URL: "nats://frontend:4222"
+ LOCALAI_REGISTER_TO: "http://frontend:8080"
+ LOCALAI_REGISTRATION_TOKEN: "my-secret"
+```
+
+For advanced networking scenarios (NAT, load balancers, or separate gRPC/HTTP ports), the following override variables are available (see the example after the table):
+
+| Variable | Description | Default |
+|----------|-------------|---------|
+| `LOCALAI_SERVE_ADDR` | gRPC base port bind address | `0.0.0.0:50051` |
+| `LOCALAI_HTTP_ADDR` | HTTP file transfer bind address | `0.0.0.0:{gRPC port - 1}` |
+| `LOCALAI_ADVERTISE_ADDR` | Public gRPC address (if different from `LOCALAI_ADDR`) | Derived from `LOCALAI_ADDR` |
+| `LOCALAI_ADVERTISE_HTTP_ADDR` | Public HTTP address (if different from gRPC host) | Derived from advertise host + HTTP port |
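+
+For example, a worker behind NAT might bind on all interfaces but advertise a public address to the frontend. A minimal sketch, with placeholder addresses from the documentation IP range:
+
+```bash
+# Bind gRPC and HTTP on all interfaces (internal ports)
+export LOCALAI_SERVE_ADDR="0.0.0.0:50051"
+export LOCALAI_HTTP_ADDR="0.0.0.0:50050"
+# Advertise the externally reachable addresses to the frontend
+export LOCALAI_ADVERTISE_ADDR="203.0.113.10:50051"
+export LOCALAI_ADVERTISE_HTTP_ADDR="203.0.113.10:50050"
+```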
+
+### Node Labels
+
+Workers can declare labels at startup for scheduling constraints:
+
+| Variable | Description | Example |
+|----------|-------------|---------|
+| `LOCALAI_NODE_LABELS` | Comma-separated `key=value` labels | `tier=premium,gpu=a100,zone=us-east` |
+
+Labels can also be managed via the admin API (see [Label Management API](#label-management-api) below).
+
+The system automatically applies hardware-detected labels on registration:
+- `gpu.vendor` -- GPU vendor (nvidia, amd, intel, vulkan)
+- `gpu.vram` -- GPU VRAM bucket (8GB, 16GB, 24GB, 48GB, 80GB+)
+- `node.name` -- The node's registered name
+
### How Workers Operate
Workers start as generic processes with no backend installed. When the SmartRouter needs to load a model on a worker, it sends a NATS `backend.install` event with the backend name and model ID. The worker:
@@ -262,6 +303,75 @@ local-ai worker \
**Multiple frontend replicas:** Run multiple LocalAI frontends behind a load balancer. Since all state is in PostgreSQL and coordination is via NATS, frontends are fully stateless and interchangeable.
+## Model Scheduling
+
+Model scheduling controls where models are placed and how many replicas are maintained. It combines two optional features:
+
+### Node Selectors
+
+Pin models to nodes with specific labels. Only nodes matching **all** selector labels are eligible:
+
+```bash
+# Only schedule on NVIDIA nodes in the us-east zone
+curl -X POST http://frontend:8080/api/nodes/scheduling \
+ -H "Content-Type: application/json" \
+ -d '{"model_name": "llama3", "node_selector": {"gpu.vendor": "nvidia", "zone": "us-east"}}'
+```
+
+Without a node selector, models can be scheduled on any healthy node (the default behavior).
+
+### Replica Auto-Scaling
+
+Control the number of model replicas across the cluster:
+
+| Field | Description |
+|-------|-------------|
+| `min_replicas` | Minimum replicas to maintain (0 = no minimum; the default single-replica behavior applies) |
+| `max_replicas` | Maximum replicas allowed (0 = unlimited) |
+
+Auto-scaling is **only active** when `min_replicas > 0` or `max_replicas > 0`.
+
+```bash
+# Scale llama3 between 2 and 4 replicas on NVIDIA nodes
+curl -X POST http://frontend:8080/api/nodes/scheduling \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model_name": "llama3",
+ "node_selector": {"gpu.vendor": "nvidia"},
+ "min_replicas": 2,
+ "max_replicas": 4
+ }'
+```
+
+The **Replica Reconciler** runs as a background process on the frontend:
+- **Scale up**: Adds replicas when all existing replicas are busy (have in-flight requests)
+- **Scale down**: Removes idle replicas after 5 minutes of inactivity
+- **Maintain minimum**: Ensures `min_replicas` are always loaded (recovers from node failures)
+- **Eviction protection**: Models with auto-scaling enabled are never evicted below `min_replicas`
+
+All fields are optional and composable:
+- Node selector only: pin model to matching nodes, single replica
+- Replicas only: auto-scale across all nodes (see the example below)
+- Both: auto-scale on matching nodes only
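+
+For instance, a replicas-only configuration (no node selector) could be created like this; the model name is illustrative:
+
+```bash
+# Auto-scale "mistral" between 1 and 3 replicas on any healthy node
+curl -X POST http://frontend:8080/api/nodes/scheduling \
+  -H "Content-Type: application/json" \
+  -d '{"model_name": "mistral", "min_replicas": 1, "max_replicas": 3}'
+```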
+
+## Label Management API
+
+| Method | Path | Description |
+|--------|------|-------------|
+| `GET` | `/api/nodes/:id/labels` | Get labels for a node |
+| `PUT` | `/api/nodes/:id/labels` | Replace all labels (JSON object) |
+| `PATCH` | `/api/nodes/:id/labels` | Merge labels (add/update) |
+| `DELETE` | `/api/nodes/:id/labels/:key` | Remove a single label |
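+
+As an illustration, merging a label onto a node and later removing it might look like the following sketch (the node ID `worker-1` is a placeholder, and the flat key/value JSON body is assumed to match the object format used by `PUT`):
+
+```bash
+# Add or update the "tier" label on node "worker-1"
+curl -X PATCH http://frontend:8080/api/nodes/worker-1/labels \
+  -H "Content-Type: application/json" \
+  -d '{"tier": "premium"}'
+
+# Remove the "tier" label again
+curl -X DELETE http://frontend:8080/api/nodes/worker-1/labels/tier
+```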
+
+## Scheduling API
+
+| Method | Path | Description |
+|--------|------|-------------|
+| `GET` | `/api/nodes/scheduling` | List all scheduling configs |
+| `GET` | `/api/nodes/scheduling/:model` | Get config for a model |
+| `POST` | `/api/nodes/scheduling` | Create/update config |
+| `DELETE` | `/api/nodes/scheduling/:model` | Remove config |
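+
+A quick sketch of inspecting and removing a scheduling config (the model name is illustrative):
+
+```bash
+# List all scheduling configs
+curl http://frontend:8080/api/nodes/scheduling
+
+# Inspect, then remove, the config for "llama3"
+curl http://frontend:8080/api/nodes/scheduling/llama3
+curl -X DELETE http://frontend:8080/api/nodes/scheduling/llama3
+```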
+
## Comparison with P2P
| | P2P / Federation | Distributed Mode |
diff --git a/docs/content/features/runtime-settings.md b/docs/content/features/runtime-settings.md
index 79d4ca7c5..1110f1cf7 100644
--- a/docs/content/features/runtime-settings.md
+++ b/docs/content/features/runtime-settings.md
@@ -28,7 +28,6 @@ Changes to watchdog settings are applied immediately by restarting the watchdog
### Backend Configuration
- **Max Active Backends**: Maximum number of active backends (loaded models). When exceeded, the least recently used model is automatically evicted. Set to `0` for unlimited, `1` for single-backend mode
-- **Parallel Backend Requests**: Enable backends to handle multiple requests in parallel if supported
- **Force Eviction When Busy**: Allow evicting models even when they have active API calls (default: disabled for safety). **Warning:** Enabling this can interrupt active requests
- **LRU Eviction Max Retries**: Maximum number of retries when waiting for busy models to become idle before eviction (default: 30)
- **LRU Eviction Retry Interval**: Interval between retries when waiting for busy models (default: `1s`)
@@ -123,7 +122,6 @@ The `runtime_settings.json` file follows this structure:
"watchdog_idle_timeout": "15m",
"watchdog_busy_timeout": "5m",
"max_active_backends": 0,
- "parallel_backend_requests": true,
"force_eviction_when_busy": false,
"lru_eviction_max_retries": 30,
"lru_eviction_retry_interval": "1s",
diff --git a/docs/content/reference/cli-reference.md b/docs/content/reference/cli-reference.md
index 9743c8682..556cf5a99 100644
--- a/docs/content/reference/cli-reference.md
+++ b/docs/content/reference/cli-reference.md
@@ -39,7 +39,6 @@ Complete reference for all LocalAI command-line interface (CLI) parameters and e
| `--external-grpc-backends` | | A list of external gRPC backends (format: `BACKEND_NAME:URI`) | `$LOCALAI_EXTERNAL_GRPC_BACKENDS`, `$EXTERNAL_GRPC_BACKENDS` |
| `--backend-galleries` | | JSON list of backend galleries | `$LOCALAI_BACKEND_GALLERIES`, `$BACKEND_GALLERIES` |
| `--autoload-backend-galleries` | `true` | Automatically load backend galleries on startup | `$LOCALAI_AUTOLOAD_BACKEND_GALLERIES`, `$AUTOLOAD_BACKEND_GALLERIES` |
-| `--parallel-requests` | `false` | Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm) | `$LOCALAI_PARALLEL_REQUESTS`, `$PARALLEL_REQUESTS` |
| `--max-active-backends` | `0` | Maximum number of active backends (loaded models). When exceeded, the least recently used model is evicted. Set to `0` for unlimited, `1` for single-backend mode | `$LOCALAI_MAX_ACTIVE_BACKENDS`, `$MAX_ACTIVE_BACKENDS` |
| `--single-active-backend` | `false` | **DEPRECATED** - Use `--max-active-backends=1` instead. Allow only one backend to be run at a time | `$LOCALAI_SINGLE_ACTIVE_BACKEND`, `$SINGLE_ACTIVE_BACKEND` |
| `--preload-backend-only` | `false` | Do not launch the API services, only the preloaded models/backends are started (useful for multi-node setups) | `$LOCALAI_PRELOAD_BACKEND_ONLY`, `$PRELOAD_BACKEND_ONLY` |