feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler (#9186)

* always enable parallel requests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore: move tests to ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(smart router): order by available vram

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-31 08:28:56 +02:00
committed by GitHub
parent 80699a3f70
commit 8862e3ce60
30 changed files with 2337 additions and 365 deletions

View File

@@ -199,7 +199,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold
envThreads := appConfig.Threads == startupAppConfig.Threads
@@ -271,9 +270,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
appConfig.MaxActiveBackends = 0
}
}
if settings.ParallelBackendRequests != nil && !envParallelRequests {
appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
}
if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled {
appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
if appConfig.MemoryReclaimerEnabled {

View File

@@ -7,6 +7,7 @@ import (
"io"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
@@ -28,6 +29,7 @@ type DistributedServices struct {
Registry *nodes.NodeRegistry
Router *nodes.SmartRouter
Health *nodes.HealthMonitor
Reconciler *nodes.ReplicaReconciler
JobStore *jobs.JobStore
Dispatcher *jobs.Dispatcher
AgentStore *agents.AgentStore
@@ -240,6 +242,16 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
DB: authDB,
})
// Create ReplicaReconciler for auto-scaling model replicas
reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{
Registry: registry,
Scheduler: router,
Unloader: remoteUnloader,
DB: authDB,
Interval: 30 * time.Second,
ScaleDownDelay: 5 * time.Minute,
})
// Create ModelRouterAdapter to wire into ModelLoader
modelAdapter := nodes.NewModelRouterAdapter(router)
@@ -250,6 +262,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
Registry: registry,
Router: router,
Health: healthMon,
Reconciler: reconciler,
JobStore: jobStore,
Dispatcher: dispatcher,
AgentStore: agentStore,

View File

@@ -158,6 +158,10 @@ func New(opts ...config.AppOption) (*Application, error) {
application.modelLoader.SetModelStore(distStore)
// Start health monitor
distSvc.Health.Start(options.Context)
// Start replica reconciler for auto-scaling model replicas
if distSvc.Reconciler != nil {
go distSvc.Reconciler.Run(options.Context)
}
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
@@ -439,11 +443,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
}
}
}
if settings.ParallelBackendRequests != nil {
if !options.ParallelBackendRequests {
options.ParallelBackendRequests = *settings.ParallelBackendRequests
}
}
if settings.MemoryReclaimerEnabled != nil {
// Only apply if current value is default (false), suggesting it wasn't set from env var
if !options.MemoryReclaimerEnabled {

View File

@@ -59,9 +59,7 @@ func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...mo
grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath)
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
if so.ParallelBackendRequests {
defOpts = append(defOpts, model.EnableParallelRequests)
}
defOpts = append(defOpts, model.EnableParallelRequests)
if c.GRPC.Attempts != 0 {
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))

View File

@@ -4,211 +4,172 @@ import (
"encoding/json"
"os"
"path/filepath"
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAGI/core/state"
)
func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
// Create a temporary agent config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
var _ = Describe("AgentRunCMD", func() {
Describe("loadAgentConfig", func() {
It("loads agent config from file", func() {
tmpDir := GinkgoT().TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
cfg := state.AgentConfig{
Name: "test-agent",
Model: "llama3",
SystemPrompt: "You are a helpful assistant",
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configFile, data, 0644); err != nil {
t.Fatal(err)
}
cfg := state.AgentConfig{
Name: "test-agent",
Model: "llama3",
SystemPrompt: "You are a helpful assistant",
}
data, err := json.MarshalIndent(cfg, "", " ")
Expect(err).ToNot(HaveOccurred())
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
loaded, err := cmd.loadAgentConfig()
if err != nil {
t.Fatalf("loadAgentConfig() error: %v", err)
}
if loaded.Name != "test-agent" {
t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
}
if loaded.Model != "llama3" {
t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
}
}
loaded, err := cmd.loadAgentConfig()
Expect(err).ToNot(HaveOccurred())
Expect(loaded.Name).To(Equal("test-agent"))
Expect(loaded.Model).To(Equal("llama3"))
})
func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
tmpDir := t.TempDir()
It("loads agent config from pool", func() {
tmpDir := GinkgoT().TempDir()
pool := map[string]state.AgentConfig{
"my-agent": {
Model: "gpt-4",
Description: "A test agent",
SystemPrompt: "Hello",
},
"other-agent": {
Model: "llama3",
},
}
data, err := json.MarshalIndent(pool, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
t.Fatal(err)
}
pool := map[string]state.AgentConfig{
"my-agent": {
Model: "gpt-4",
Description: "A test agent",
SystemPrompt: "Hello",
},
"other-agent": {
Model: "llama3",
},
}
data, err := json.MarshalIndent(pool, "", " ")
Expect(err).ToNot(HaveOccurred())
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
cmd := &AgentRunCMD{
Name: "my-agent",
StateDir: tmpDir,
}
cmd := &AgentRunCMD{
Name: "my-agent",
StateDir: tmpDir,
}
loaded, err := cmd.loadAgentConfig()
if err != nil {
t.Fatalf("loadAgentConfig() error: %v", err)
}
if loaded.Name != "my-agent" {
t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
}
if loaded.Model != "gpt-4" {
t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
}
}
loaded, err := cmd.loadAgentConfig()
Expect(err).ToNot(HaveOccurred())
Expect(loaded.Name).To(Equal("my-agent"))
Expect(loaded.Model).To(Equal("gpt-4"))
})
func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
tmpDir := t.TempDir()
It("returns error for missing agent in pool", func() {
tmpDir := GinkgoT().TempDir()
pool := map[string]state.AgentConfig{
"existing-agent": {Model: "llama3"},
}
data, err := json.MarshalIndent(pool, "", " ")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
t.Fatal(err)
}
pool := map[string]state.AgentConfig{
"existing-agent": {Model: "llama3"},
}
data, err := json.MarshalIndent(pool, "", " ")
Expect(err).ToNot(HaveOccurred())
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
cmd := &AgentRunCMD{
Name: "nonexistent",
StateDir: tmpDir,
}
cmd := &AgentRunCMD{
Name: "nonexistent",
StateDir: tmpDir,
}
_, err = cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error for missing agent, got nil")
}
}
_, err = cmd.loadAgentConfig()
Expect(err).To(HaveOccurred())
})
func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
cmd := &AgentRunCMD{
StateDir: t.TempDir(),
}
It("returns error when no pool.json exists", func() {
cmd := &AgentRunCMD{
StateDir: GinkgoT().TempDir(),
}
_, err := cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error when no pool.json exists, got nil")
}
}
_, err := cmd.loadAgentConfig()
Expect(err).To(HaveOccurred())
})
func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
cfg := &state.AgentConfig{
Name: "test",
}
It("returns error for config with no name", func() {
tmpDir := GinkgoT().TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
cmd := &AgentRunCMD{
APIURL: "http://localhost:9090",
APIKey: "secret",
DefaultModel: "my-model",
}
cfg := state.AgentConfig{
Model: "llama3",
}
data, _ := json.MarshalIndent(cfg, "", " ")
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
cmd.applyOverrides(cfg)
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
if cfg.APIURL != "http://localhost:9090" {
t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
}
if cfg.APIKey != "secret" {
t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
}
if cfg.Model != "my-model" {
t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
}
}
_, err := cmd.loadAgentConfig()
Expect(err).To(HaveOccurred())
})
})
func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
cfg := &state.AgentConfig{
Name: "test",
Model: "existing-model",
}
Describe("applyOverrides", func() {
It("applies overrides to empty fields", func() {
cfg := &state.AgentConfig{
Name: "test",
}
cmd := &AgentRunCMD{
DefaultModel: "override-model",
}
cmd := &AgentRunCMD{
APIURL: "http://localhost:9090",
APIKey: "secret",
DefaultModel: "my-model",
}
cmd.applyOverrides(cfg)
cmd.applyOverrides(cfg)
if cfg.Model != "existing-model" {
t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
}
}
Expect(cfg.APIURL).To(Equal("http://localhost:9090"))
Expect(cfg.APIKey).To(Equal("secret"))
Expect(cfg.Model).To(Equal("my-model"))
})
func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "agent.json")
It("does not overwrite existing model", func() {
cfg := &state.AgentConfig{
Name: "test",
Model: "existing-model",
}
// Agent config with no name
cfg := state.AgentConfig{
Model: "llama3",
}
data, _ := json.MarshalIndent(cfg, "", " ")
os.WriteFile(configFile, data, 0644)
cmd := &AgentRunCMD{
DefaultModel: "override-model",
}
cmd := &AgentRunCMD{
Config: configFile,
StateDir: tmpDir,
}
cmd.applyOverrides(cfg)
_, err := cmd.loadAgentConfig()
if err == nil {
t.Fatal("expected error for config with no name, got nil")
}
}
Expect(cfg.Model).To(Equal("existing-model"))
})
})
})
func TestAgentListCMD_NoPoolFile(t *testing.T) {
cmd := &AgentListCMD{
StateDir: t.TempDir(),
}
var _ = Describe("AgentListCMD", func() {
It("runs without error when no pool file exists", func() {
cmd := &AgentListCMD{
StateDir: GinkgoT().TempDir(),
}
Expect(cmd.Run(nil)).To(Succeed())
})
// Should not error, just print "no agents found"
err := cmd.Run(nil)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
}
It("runs without error with agents in pool", func() {
tmpDir := GinkgoT().TempDir()
func TestAgentListCMD_WithAgents(t *testing.T) {
tmpDir := t.TempDir()
pool := map[string]state.AgentConfig{
"agent-a": {Model: "llama3", Description: "First agent"},
"agent-b": {Model: "gpt-4"},
}
data, _ := json.MarshalIndent(pool, "", " ")
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
pool := map[string]state.AgentConfig{
"agent-a": {Model: "llama3", Description: "First agent"},
"agent-b": {Model: "gpt-4"},
}
data, _ := json.MarshalIndent(pool, "", " ")
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
cmd := &AgentListCMD{
StateDir: tmpDir,
}
err := cmd.Run(nil)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
}
cmd := &AgentListCMD{
StateDir: tmpDir,
}
Expect(cmd.Run(nil)).To(Succeed())
})
})

View File

@@ -0,0 +1,13 @@
package cli
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestCLI(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "CLI Suite")
}

View File

@@ -1,10 +1,9 @@
package cli
import (
"strings"
"testing"
"github.com/alecthomas/kong"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func getTestApp() *kong.Application {
@@ -21,76 +20,55 @@ func getTestApp() *kong.Application {
return k.Model
}
func TestGenerateBashCompletion(t *testing.T) {
app := getTestApp()
script := generateBashCompletion(app)
var _ = Describe("Shell completions", func() {
var app *kong.Application
if !strings.Contains(script, "complete -F _local_ai_completions local-ai") {
t.Error("bash completion missing complete command registration")
}
if !strings.Contains(script, "run") {
t.Error("bash completion missing 'run' command")
}
if !strings.Contains(script, "models") {
t.Error("bash completion missing 'models' command")
}
if !strings.Contains(script, "completion") {
t.Error("bash completion missing 'completion' command")
}
}
BeforeEach(func() {
app = getTestApp()
})
func TestGenerateZshCompletion(t *testing.T) {
app := getTestApp()
script := generateZshCompletion(app)
Describe("generateBashCompletion", func() {
It("generates valid bash completion script", func() {
script := generateBashCompletion(app)
Expect(script).To(ContainSubstring("complete -F _local_ai_completions local-ai"))
Expect(script).To(ContainSubstring("run"))
Expect(script).To(ContainSubstring("models"))
Expect(script).To(ContainSubstring("completion"))
})
})
if !strings.Contains(script, "#compdef local-ai") {
t.Error("zsh completion missing compdef header")
}
if !strings.Contains(script, "run") {
t.Error("zsh completion missing 'run' command")
}
if !strings.Contains(script, "models") {
t.Error("zsh completion missing 'models' command")
}
}
Describe("generateZshCompletion", func() {
It("generates valid zsh completion script", func() {
script := generateZshCompletion(app)
Expect(script).To(ContainSubstring("#compdef local-ai"))
Expect(script).To(ContainSubstring("run"))
Expect(script).To(ContainSubstring("models"))
})
})
func TestGenerateFishCompletion(t *testing.T) {
app := getTestApp()
script := generateFishCompletion(app)
Describe("generateFishCompletion", func() {
It("generates valid fish completion script", func() {
script := generateFishCompletion(app)
Expect(script).To(ContainSubstring("complete -c local-ai"))
Expect(script).To(ContainSubstring("__fish_use_subcommand"))
Expect(script).To(ContainSubstring("run"))
Expect(script).To(ContainSubstring("models"))
})
})
if !strings.Contains(script, "complete -c local-ai") {
t.Error("fish completion missing complete command")
}
if !strings.Contains(script, "__fish_use_subcommand") {
t.Error("fish completion missing subcommand detection")
}
if !strings.Contains(script, "run") {
t.Error("fish completion missing 'run' command")
}
if !strings.Contains(script, "models") {
t.Error("fish completion missing 'models' command")
}
}
Describe("collectCommands", func() {
It("collects all commands and subcommands", func() {
cmds := collectCommands(app.Node, "")
func TestCollectCommands(t *testing.T) {
app := getTestApp()
cmds := collectCommands(app.Node, "")
names := make(map[string]bool)
for _, cmd := range cmds {
names[cmd.fullName] = true
}
names := make(map[string]bool)
for _, cmd := range cmds {
names[cmd.fullName] = true
}
if !names["run"] {
t.Error("missing 'run' command")
}
if !names["models"] {
t.Error("missing 'models' command")
}
if !names["models list"] {
t.Error("missing 'models list' subcommand")
}
if !names["models install"] {
t.Error("missing 'models install' subcommand")
}
}
Expect(names).To(HaveKey("run"))
Expect(names).To(HaveKey("models"))
Expect(names).To(HaveKey("models list"))
Expect(names).To(HaveKey("models install"))
})
})
})

View File

@@ -74,7 +74,6 @@ type RunCMD struct {
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"`
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"`
MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). Least recently used backends are evicted when limit is reached" group:"backends"`
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
@@ -438,10 +437,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold))
}
if r.ParallelRequests {
opts = append(opts, config.EnableParallelBackendRequests)
}
// Handle max active backends (LRU eviction)
// MaxActiveBackends takes precedence over SingleActiveBackend
if r.MaxActiveBackends > 0 {

View File

@@ -67,22 +67,28 @@ func isPathAllowed(path string, allowedDirs []string) bool {
//
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
type WorkerCMD struct {
Addr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
// Primary address — the reachable address of this worker.
// Host is used for advertise, port is the base for gRPC backends.
// HTTP file transfer runs on port-1.
Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." group:"server"`
ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
// HTTP file transfer
HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`
HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server" hidden:""`
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
// Registration (required)
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
// NATS (required)
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
@@ -96,7 +102,7 @@ type WorkerCMD struct {
}
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
xlog.Info("Starting worker", "addr", cmd.Addr)
xlog.Info("Starting worker", "advertise", cmd.advertiseAddr(), "basePort", cmd.effectiveBasePort())
systemState, err := system.GetSystemState(
system.WithModelPath(cmd.ModelsPath),
@@ -181,15 +187,7 @@ func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
}()
// Process supervisor — manages multiple backend gRPC processes on different ports
basePort := 50051
if cmd.Addr != "" {
// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
if p, err := strconv.Atoi(portStr); err == nil {
basePort = p
}
}
}
basePort := cmd.effectiveBasePort()
// Buffered so NATS stop handler can send without blocking
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
@@ -667,10 +665,10 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
// Return the gRPC address so the router knows which port to use
advertiseAddr := addr
if s.cmd.AdvertiseAddr != "" {
// Replace 0.0.0.0 with the advertised host but keep the dynamic port
advAddr := s.cmd.advertiseAddr()
if advAddr != addr { // only remap if advertise differs from bind
_, port, _ := net.SplitHostPort(addr)
advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
advertiseHost, _, _ := net.SplitHostPort(advAddr)
advertiseAddr = net.JoinHostPort(advertiseHost, port)
}
resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
@@ -816,18 +814,31 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
})
}
// effectiveBasePort returns the port used as base for gRPC backend processes.
// Priority: Addr port → ServeAddr port → 50051
func (cmd *WorkerCMD) effectiveBasePort() int {
for _, addr := range []string{cmd.Addr, cmd.ServeAddr} {
if addr != "" {
if _, portStr, ok := strings.Cut(addr, ":"); ok {
if p, _ := strconv.Atoi(portStr); p > 0 {
return p
}
}
}
}
return 50051
}
// advertiseAddr returns the address the frontend should use to reach this node.
func (cmd *WorkerCMD) advertiseAddr() string {
if cmd.AdvertiseAddr != "" {
return cmd.AdvertiseAddr
}
host, port, ok := strings.Cut(cmd.Addr, ":")
if ok && (host == "0.0.0.0" || host == "") {
if hostname, err := os.Hostname(); err == nil {
return hostname + ":" + port
}
if cmd.Addr != "" {
return cmd.Addr
}
return cmd.Addr
hostname, _ := os.Hostname()
return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cmd.effectiveBasePort())
}
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
@@ -837,12 +848,7 @@ func (cmd *WorkerCMD) resolveHTTPAddr() string {
if cmd.HTTPAddr != "" {
return cmd.HTTPAddr
}
host, port, ok := strings.Cut(cmd.Addr, ":")
if !ok {
return "0.0.0.0:50050"
}
portNum, _ := strconv.Atoi(port)
return fmt.Sprintf("%s:%d", host, portNum-1)
return fmt.Sprintf("0.0.0.0:%d", cmd.effectiveBasePort()-1)
}
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
@@ -851,14 +857,9 @@ func (cmd *WorkerCMD) advertiseHTTPAddr() string {
if cmd.AdvertiseHTTPAddr != "" {
return cmd.AdvertiseHTTPAddr
}
httpAddr := cmd.resolveHTTPAddr()
host, port, ok := strings.Cut(httpAddr, ":")
if ok && (host == "0.0.0.0" || host == "") {
if hostname, err := os.Hostname(); err == nil {
return hostname + ":" + port
}
}
return httpAddr
advHost, _, _ := strings.Cut(cmd.advertiseAddr(), ":")
httpPort := cmd.effectiveBasePort() - 1
return fmt.Sprintf("%s:%d", advHost, httpPort)
}
// registrationBody builds the JSON body for node registration.
@@ -896,6 +897,21 @@ func (cmd *WorkerCMD) registrationBody() map[string]any {
if cmd.RegistrationToken != "" {
body["token"] = cmd.RegistrationToken
}
// Parse and add static node labels
if cmd.NodeLabels != "" {
labels := make(map[string]string)
for _, pair := range strings.Split(cmd.NodeLabels, ",") {
pair = strings.TrimSpace(pair)
if k, v, ok := strings.Cut(pair, "="); ok {
labels[strings.TrimSpace(k)] = strings.TrimSpace(v)
}
}
if len(labels) > 0 {
body["labels"] = labels
}
}
return body
}

View File

@@ -0,0 +1,83 @@
package cli
import (
"os"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("WorkerCMD address resolution", func() {
Describe("effectiveBasePort", func() {
DescribeTable("returns the correct port",
func(addr, serve string, want int) {
cmd := &WorkerCMD{Addr: addr, ServeAddr: serve}
Expect(cmd.effectiveBasePort()).To(Equal(want))
},
Entry("Addr takes priority", "worker1.example.com:60000", "0.0.0.0:50051", 60000),
Entry("falls back to ServeAddr", "", "0.0.0.0:50051", 50051),
Entry("returns 50051 when neither set", "", "", 50051),
Entry("Addr with custom port", "10.0.0.5:7000", "", 7000),
Entry("invalid port in Addr falls through to ServeAddr", "host:notanumber", "0.0.0.0:9999", 9999),
)
})
Describe("advertiseAddr", func() {
It("returns AdvertiseAddr when set", func() {
cmd := &WorkerCMD{
AdvertiseAddr: "public.example.com:50051",
Addr: "10.0.0.5:60000",
}
Expect(cmd.advertiseAddr()).To(Equal("public.example.com:50051"))
})
It("returns Addr when set", func() {
cmd := &WorkerCMD{Addr: "worker1.example.com:60000"}
Expect(cmd.advertiseAddr()).To(Equal("worker1.example.com:60000"))
})
It("falls back to hostname:basePort", func() {
cmd := &WorkerCMD{ServeAddr: "0.0.0.0:50051"}
got := cmd.advertiseAddr()
_, port, _ := strings.Cut(got, ":")
Expect(port).To(Equal("50051"))
hostname, _ := os.Hostname()
if hostname != "" {
host, _, _ := strings.Cut(got, ":")
Expect(host).To(Equal(hostname))
}
})
})
Describe("resolveHTTPAddr", func() {
DescribeTable("returns the correct address",
func(httpAddr, addr, serve, want string) {
cmd := &WorkerCMD{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve}
Expect(cmd.resolveHTTPAddr()).To(Equal(want))
},
Entry("HTTPAddr takes priority", "0.0.0.0:8080", "", "", "0.0.0.0:8080"),
Entry("derives from Addr port minus 1", "", "worker1:60000", "0.0.0.0:50051", "0.0.0.0:59999"),
Entry("derives from ServeAddr port minus 1", "", "", "0.0.0.0:50051", "0.0.0.0:50050"),
Entry("default when nothing set", "", "", "", "0.0.0.0:50050"),
)
})
Describe("advertiseHTTPAddr", func() {
DescribeTable("returns the correct address",
func(advertiseHTTP, advertise, addr, serve, want string) {
cmd := &WorkerCMD{
AdvertiseHTTPAddr: advertiseHTTP,
AdvertiseAddr: advertise,
Addr: addr,
ServeAddr: serve,
}
Expect(cmd.advertiseHTTPAddr()).To(Equal(want))
},
Entry("AdvertiseHTTPAddr takes priority", "public.example.com:8080", "", "", "", "public.example.com:8080"),
Entry("derives from advertiseAddr host + basePort-1", "", "", "worker1.example.com:60000", "", "worker1.example.com:59999"),
Entry("uses AdvertiseAddr host with basePort-1", "", "public.example.com:60000", "10.0.0.5:60000", "", "public.example.com:59999"),
)
})
})

View File

@@ -59,8 +59,6 @@ type ApplicationConfig struct {
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
ParallelBackendRequests bool
WatchDogIdle bool
WatchDogBusy bool
WatchDog bool
@@ -379,10 +377,6 @@ func WithLRUEvictionRetryInterval(interval time.Duration) AppOption {
}
}
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
o.ParallelBackendRequests = true
}
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
o.AutoloadGalleries = true
}
@@ -842,7 +836,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
watchdogBusy := o.WatchDogBusy
singleBackend := o.SingleBackend
maxActiveBackends := o.MaxActiveBackends
parallelBackendRequests := o.ParallelBackendRequests
memoryReclaimerEnabled := o.MemoryReclaimerEnabled
memoryReclaimerThreshold := o.MemoryReclaimerThreshold
forceEvictionWhenBusy := o.ForceEvictionWhenBusy
@@ -915,7 +908,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
WatchdogInterval: &watchdogInterval,
SingleBackend: &singleBackend,
MaxActiveBackends: &maxActiveBackends,
ParallelBackendRequests: &parallelBackendRequests,
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
@@ -1008,9 +1000,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
}
requireRestart = true
}
if settings.ParallelBackendRequests != nil {
o.ParallelBackendRequests = *settings.ParallelBackendRequests
}
if settings.MemoryReclaimerEnabled != nil {
o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
if *settings.MemoryReclaimerEnabled {

View File

@@ -18,7 +18,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
WatchDogBusyTimeout: 10 * time.Minute,
SingleBackend: false,
MaxActiveBackends: 5,
ParallelBackendRequests: true,
MemoryReclaimerEnabled: true,
MemoryReclaimerThreshold: 0.85,
Threads: 8,
@@ -62,9 +61,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
Expect(rs.MaxActiveBackends).ToNot(BeNil())
Expect(*rs.MaxActiveBackends).To(Equal(5))
Expect(rs.ParallelBackendRequests).ToNot(BeNil())
Expect(*rs.ParallelBackendRequests).To(BeTrue())
Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil())
Expect(*rs.MemoryReclaimerEnabled).To(BeTrue())
@@ -455,7 +451,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
WatchDogBusyTimeout: 12 * time.Minute,
SingleBackend: false,
MaxActiveBackends: 3,
ParallelBackendRequests: true,
MemoryReclaimerEnabled: true,
MemoryReclaimerThreshold: 0.92,
Threads: 12,
@@ -487,7 +482,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout))
Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout))
Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends))
Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests))
Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled))
Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold))
Expect(target.Threads).To(Equal(original.Threads))

View File

@@ -20,8 +20,6 @@ type RuntimeSettings struct {
// Backend management
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"`
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -37,7 +38,7 @@ func nodeError(code int, message string) schema.ErrorResponse {
func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
nodeList, err := registry.List(ctx)
nodeList, err := registry.ListWithExtras(ctx)
if err != nil {
xlog.Error("Failed to list nodes", "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes"))
@@ -70,7 +71,8 @@ type RegisterNodeRequest struct {
AvailableVRAM uint64 `json:"available_vram,omitempty"`
TotalRAM uint64 `json:"total_ram,omitempty"`
AvailableRAM uint64 `json:"available_ram,omitempty"`
GPUVendor string `json:"gpu_vendor,omitempty"`
GPUVendor string `json:"gpu_vendor,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
}
// RegisterNodeEndpoint registers a new backend node.
@@ -148,6 +150,14 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node"))
}
// Store static labels from worker and apply auto-labels
if len(req.Labels) > 0 {
if err := registry.SetNodeLabels(ctx, node.ID, req.Labels); err != nil {
xlog.Warn("Failed to set node labels", "node", node.ID, "error", err)
}
}
registry.ApplyAutoLabels(ctx, node.ID, node)
response := map[string]any{
"id": node.ID,
"name": node.Name,
@@ -618,6 +628,178 @@ func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken s
}
}
// GetNodeLabelsEndpoint returns the labels attached to a node as a flat
// key/value JSON object.
func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		reqCtx := c.Request().Context()
		id := c.Param("id")
		nodeLabels, err := registry.GetNodeLabels(reqCtx, id)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
		}
		// Flatten the label records into a map for a cleaner API response.
		out := make(map[string]string, len(nodeLabels))
		for _, label := range nodeLabels {
			out[label.Key] = label.Value
		}
		return c.JSON(http.StatusOK, out)
	}
}
// SetNodeLabelsEndpoint replaces all labels for a node.
func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		reqCtx := c.Request().Context()
		id := c.Param("id")
		// The node must exist before labels can be attached to it.
		if _, err := registry.Get(reqCtx, id); err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}
		var body map[string]string
		if err := c.Bind(&body); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}
		if err := registry.SetNodeLabels(reqCtx, id, body); err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels"))
		}
		// Echo the stored label set back to the caller.
		return c.JSON(http.StatusOK, body)
	}
}
// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones.
// On success it responds with the node's full, updated label set.
func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		nodeID := c.Param("id")
		if _, err := registry.Get(ctx, nodeID); err != nil {
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
		}
		var labels map[string]string
		if err := c.Bind(&labels); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}
		for k, v := range labels {
			if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil {
				return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels"))
			}
		}
		// Return updated labels. Previously the read error was silently
		// ignored, which could report an empty label set even though the
		// merge itself succeeded; surface it instead.
		updated, err := registry.GetNodeLabels(ctx, nodeID)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
		}
		result := make(map[string]string, len(updated))
		for _, l := range updated {
			result[l.Key] = l.Value
		}
		return c.JSON(http.StatusOK, result)
	}
}
// DeleteNodeLabelEndpoint removes a single label from a node.
func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		labelKey := c.Param("key")
		if labelKey == "" {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required"))
		}
		reqCtx := c.Request().Context()
		if err := registry.RemoveNodeLabel(reqCtx, c.Param("id"), labelKey); err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label"))
		}
		return c.NoContent(http.StatusNoContent)
	}
}
// ListSchedulingEndpoint returns all model scheduling configs.
func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		cfgs, err := registry.ListModelSchedulings(c.Request().Context())
		if err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs"))
		}
		return c.JSON(http.StatusOK, cfgs)
	}
}
// GetSchedulingEndpoint returns the scheduling config for a specific model.
func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		cfg, err := registry.GetModelScheduling(c.Request().Context(), c.Param("model"))
		switch {
		case err != nil:
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config"))
		case cfg == nil:
			// A nil config with no error means the model has no rule.
			return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model"))
		}
		return c.JSON(http.StatusOK, cfg)
	}
}
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
type SetSchedulingRequest struct {
	ModelName    string            `json:"model_name"`              // required
	NodeSelector map[string]string `json:"node_selector,omitempty"` // label key/value pairs a node must match; empty = any node
	MinReplicas  int               `json:"min_replicas"`            // 0 = no minimum
	MaxReplicas  int               `json:"max_replicas"`            // 0 = unlimited
}
// SetSchedulingEndpoint creates or updates a model scheduling config.
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		ctx := c.Request().Context()
		var req SetSchedulingRequest
		if err := c.Bind(&req); err != nil {
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
		}
		// Validate the request; each failure maps to a dedicated message.
		switch {
		case req.ModelName == "":
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
		case req.MinReplicas < 0:
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
		case req.MaxReplicas < 0:
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
		case req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas:
			// MaxReplicas == 0 means unlimited, so MinReplicas is unconstrained then.
			return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
		}
		// Persist the selector as JSON; an empty string means "any node".
		selectorJSON := ""
		if len(req.NodeSelector) > 0 {
			encoded, err := json.Marshal(req.NodeSelector)
			if err != nil {
				return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector"))
			}
			selectorJSON = string(encoded)
		}
		cfg := &nodes.ModelSchedulingConfig{
			ModelName:    req.ModelName,
			NodeSelector: selectorJSON,
			MinReplicas:  req.MinReplicas,
			MaxReplicas:  req.MaxReplicas,
		}
		if err := registry.SetModelScheduling(ctx, cfg); err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
		}
		return c.JSON(http.StatusOK, cfg)
	}
}
// DeleteSchedulingEndpoint removes a model scheduling config.
func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
	return func(c echo.Context) error {
		model := c.Param("model")
		if err := registry.DeleteModelScheduling(c.Request().Context(), model); err != nil {
			return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config"))
		}
		return c.NoContent(http.StatusNoContent)
	}
}
// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth.
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) {
reqURL := fmt.Sprintf("http://%s%s", httpAddress, path)

View File

@@ -159,6 +159,62 @@ function WorkerHintCard({ addToast, activeTab, hasWorkers }) {
)
}
function SchedulingForm({ onSave, onCancel }) {
const [modelName, setModelName] = useState('')
const [selectorText, setSelectorText] = useState('')
const [minReplicas, setMinReplicas] = useState(0)
const [maxReplicas, setMaxReplicas] = useState(0)
const handleSubmit = () => {
let nodeSelector = null
if (selectorText.trim()) {
const pairs = {}
selectorText.split(',').forEach(p => {
const [k, v] = p.split('=').map(s => s.trim())
if (k) pairs[k] = v || ''
})
nodeSelector = pairs
}
onSave({
model_name: modelName,
node_selector: nodeSelector ? JSON.stringify(nodeSelector) : '',
min_replicas: minReplicas,
max_replicas: maxReplicas,
})
}
return (
<div className="card" style={{ padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)' }}>
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 'var(--spacing-sm)' }}>
<div>
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Model Name</label>
<input type="text" value={modelName} onChange={e => setModelName(e.target.value)}
placeholder="e.g. llama3" style={{ width: '100%' }} />
</div>
<div>
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Node Selector (key=value, comma-separated)</label>
<input type="text" value={selectorText} onChange={e => setSelectorText(e.target.value)}
placeholder="e.g. gpu.vendor=nvidia,tier=fast" style={{ width: '100%' }} />
</div>
<div>
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Min Replicas (0 = no minimum)</label>
<input type="number" min={0} value={minReplicas} onChange={e => setMinReplicas(parseInt(e.target.value) || 0)}
style={{ width: '100%' }} />
</div>
<div>
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Max Replicas (0 = unlimited)</label>
<input type="number" min={0} value={maxReplicas} onChange={e => setMaxReplicas(parseInt(e.target.value) || 0)}
style={{ width: '100%' }} />
</div>
</div>
<div style={{ display: 'flex', gap: 'var(--spacing-sm)', marginTop: 'var(--spacing-sm)', justifyContent: 'flex-end' }}>
<button className="btn btn-secondary btn-sm" onClick={onCancel}>Cancel</button>
<button className="btn btn-primary btn-sm" onClick={handleSubmit} disabled={!modelName}>Save</button>
</div>
</div>
)
}
export default function Nodes() {
const { addToast } = useOutletContext()
const navigate = useNavigate()
@@ -170,7 +226,9 @@ export default function Nodes() {
const [nodeBackends, setNodeBackends] = useState({})
const [confirmDelete, setConfirmDelete] = useState(null)
const [showTips, setShowTips] = useState(false)
const [activeTab, setActiveTab] = useState('backend') // 'backend' or 'agent'
const [activeTab, setActiveTab] = useState('backend') // 'backend', 'agent', or 'scheduling'
const [schedulingConfigs, setSchedulingConfigs] = useState([])
const [showSchedulingForm, setShowSchedulingForm] = useState(false)
const fetchNodes = useCallback(async () => {
try {
@@ -186,11 +244,19 @@ export default function Nodes() {
}
}, [])
const fetchScheduling = useCallback(async () => {
try {
const data = await nodesApi.listScheduling()
setSchedulingConfigs(Array.isArray(data) ? data : [])
} catch { setSchedulingConfigs([]) }
}, [])
useEffect(() => {
fetchNodes()
fetchScheduling()
const interval = setInterval(fetchNodes, 5000)
return () => clearInterval(interval)
}, [fetchNodes])
}, [fetchNodes, fetchScheduling])
const fetchModels = useCallback(async (nodeId) => {
try {
@@ -254,6 +320,36 @@ export default function Nodes() {
}
}
const handleUnloadModel = async (nodeId, modelName) => {
try {
await nodesApi.unloadModel(nodeId, modelName)
addToast(`Model "${modelName}" unloaded`, 'success')
fetchModels(nodeId)
} catch (err) {
addToast(`Failed to unload model: ${err.message}`, 'error')
}
}
const handleAddLabel = async (nodeId, key, value) => {
try {
await nodesApi.mergeLabels(nodeId, { [key]: value })
addToast(`Label "${key}=${value}" added`, 'success')
fetchNodes()
} catch (err) {
addToast(`Failed to add label: ${err.message}`, 'error')
}
}
const handleDeleteLabel = async (nodeId, key) => {
try {
await nodesApi.deleteLabel(nodeId, key)
addToast(`Label "${key}" removed`, 'success')
fetchNodes()
} catch (err) {
addToast(`Failed to remove label: ${err.message}`, 'error')
}
}
const handleDelete = async (nodeId) => {
try {
await nodesApi.delete(nodeId)
@@ -422,8 +518,23 @@ export default function Nodes() {
<i className="fas fa-robot" style={{ marginRight: 6 }} />
Agent Workers ({agentNodes.length})
</button>
<button
onClick={() => setActiveTab('scheduling')}
style={{
padding: 'var(--spacing-sm) var(--spacing-lg)',
border: 'none', cursor: 'pointer', fontWeight: 600, fontSize: '0.875rem',
background: 'none',
color: activeTab === 'scheduling' ? 'var(--color-primary)' : 'var(--color-text-muted)',
borderBottom: activeTab === 'scheduling' ? '2px solid var(--color-primary)' : '2px solid transparent',
marginBottom: '-2px',
}}
>
<i className="fas fa-calendar-alt" style={{ marginRight: 6 }} />
Scheduling ({schedulingConfigs.length})
</button>
</div>
{activeTab !== 'scheduling' && <>
{/* Stat cards */}
<div style={{ display: 'flex', gap: 'var(--spacing-md)', marginBottom: 'var(--spacing-xl)', flexWrap: 'wrap' }}>
<StatCard icon={activeTab === 'agent' ? 'fas fa-robot' : 'fas fa-server'} label={`Total ${activeTab === 'agent' ? 'Agent' : 'Backend'} Workers`} value={total} />
@@ -433,6 +544,23 @@ export default function Nodes() {
{pending > 0 && (
<StatCard icon="fas fa-clock" label="Pending" value={pending} color="var(--color-warning)" />
)}
{activeTab === 'backend' && (() => {
const clusterTotalVRAM = backendNodes.reduce((sum, n) => sum + (n.total_vram || 0), 0)
const clusterUsedVRAM = backendNodes.reduce((sum, n) => {
if (n.total_vram && n.available_vram != null) return sum + (n.total_vram - n.available_vram)
return sum
}, 0)
const totalModelsLoaded = backendNodes.reduce((sum, n) => sum + (n.model_count || 0), 0)
return (
<>
{clusterTotalVRAM > 0 && (
<StatCard icon="fas fa-microchip" label="Cluster VRAM"
value={`${formatVRAM(clusterUsedVRAM) || '0'} / ${formatVRAM(clusterTotalVRAM)}`} />
)}
<StatCard icon="fas fa-cube" label="Models Loaded" value={totalModelsLoaded} />
</>
)
})()}
</div>
{/* Worker tips */}
@@ -506,6 +634,22 @@ export default function Nodes() {
<div style={{ fontSize: '0.75rem', fontFamily: "'JetBrains Mono', monospace", color: 'var(--color-text-muted)' }}>
{node.address}
</div>
{node.labels && Object.keys(node.labels).length > 0 && (
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 3, marginTop: 3 }}>
{Object.entries(node.labels).slice(0, 5).map(([k, v]) => (
<span key={k} style={{
fontSize: '0.625rem', padding: '1px 5px', borderRadius: 3,
background: 'var(--color-bg-tertiary)', color: 'var(--color-text-muted)',
fontFamily: "'JetBrains Mono', monospace", border: '1px solid var(--color-border-subtle)',
}}>{k}={v}</span>
))}
{Object.keys(node.labels).length > 5 && (
<span style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>
+{Object.keys(node.labels).length - 5} more
</span>
)}
</div>
)}
</div>
</div>
</td>
@@ -593,6 +737,7 @@ export default function Nodes() {
<th>State</th>
<th>In-Flight</th>
<th style={{ width: 40 }}>Logs</th>
<th style={{ textAlign: 'right' }}>Actions</th>
</tr>
</thead>
<tbody>
@@ -628,6 +773,21 @@ export default function Nodes() {
<i className="fas fa-terminal" />
</a>
</td>
<td style={{ textAlign: 'right' }}>
<button
className="btn btn-danger btn-sm"
disabled={m.in_flight > 0}
title={m.in_flight > 0 ? 'Cannot unload while serving requests' : 'Unload model'}
onClick={(e) => {
e.stopPropagation()
if (confirm(`Unload "${m.model_name}" from ${node.name}?`)) {
handleUnloadModel(node.id, m.model_name)
}
}}
>
<i className="fas fa-stop" />
</button>
</td>
</tr>
)
})}
@@ -689,6 +849,50 @@ export default function Nodes() {
</tbody>
</table>
)}
{/* Labels */}
<div style={{ marginTop: 'var(--spacing-md)' }}>
<h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)', color: 'var(--color-text-secondary)' }}>
<i className="fas fa-tags" style={{ marginRight: 6 }} />
Labels
</h4>
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 'var(--spacing-xs)', marginBottom: 'var(--spacing-sm)' }}>
{node.labels && Object.entries(node.labels).map(([k, v]) => (
<span key={k} style={{
display: 'inline-flex', alignItems: 'center', gap: 4,
fontSize: '0.75rem', padding: '2px 8px', borderRadius: 4,
background: 'var(--color-bg-tertiary)', border: '1px solid var(--color-border-subtle)',
fontFamily: "'JetBrains Mono', monospace",
}}>
{k}={v}
<button
onClick={(e) => { e.stopPropagation(); handleDeleteLabel(node.id, k) }}
style={{ background: 'none', border: 'none', cursor: 'pointer', color: 'var(--color-text-muted)', fontSize: '0.625rem', padding: 0 }}
title="Remove label"
>
<i className="fas fa-times" />
</button>
</span>
))}
</div>
{/* Add label form */}
<div style={{ display: 'flex', gap: 'var(--spacing-xs)', alignItems: 'center' }}>
<input
type="text" placeholder="key" style={{ width: 100, fontSize: '0.75rem' }}
id={`label-key-${node.id}`}
/>
<input
type="text" placeholder="value" style={{ width: 100, fontSize: '0.75rem' }}
id={`label-value-${node.id}`}
/>
<button className="btn btn-secondary btn-sm" onClick={(e) => {
e.stopPropagation()
const key = document.getElementById(`label-key-${node.id}`).value.trim()
const val = document.getElementById(`label-value-${node.id}`).value.trim()
if (key) handleAddLabel(node.id, key, val)
}}>Add</button>
</div>
</div>
</div>
</td>
</tr>
@@ -700,6 +904,78 @@ export default function Nodes() {
</table>
</div>
)}
</>}
{activeTab === 'scheduling' && (
<div>
<button className="btn btn-primary btn-sm" style={{ marginBottom: 'var(--spacing-md)' }}
onClick={() => setShowSchedulingForm(f => !f)}>
<i className="fas fa-plus" style={{ marginRight: 6 }} />
Add Scheduling Rule
</button>
{showSchedulingForm && <SchedulingForm onSave={async (config) => {
try {
await nodesApi.setScheduling(config)
fetchScheduling()
setShowSchedulingForm(false)
addToast('Scheduling rule saved', 'success')
} catch (err) {
addToast(`Failed to save rule: ${err.message}`, 'error')
}
}} onCancel={() => setShowSchedulingForm(false)} />}
{schedulingConfigs.length === 0 && !showSchedulingForm ? (
<p style={{ fontSize: '0.875rem', color: 'var(--color-text-muted)', textAlign: 'center', padding: 'var(--spacing-xl) 0' }}>
No scheduling rules configured. Add a rule to control how models are placed on nodes.
</p>
) : schedulingConfigs.length > 0 && (
<div className="table-container">
<table className="table">
<thead><tr>
<th>Model</th>
<th>Node Selector</th>
<th>Min Replicas</th>
<th>Max Replicas</th>
<th style={{ textAlign: 'right' }}>Actions</th>
</tr></thead>
<tbody>
{schedulingConfigs.map(cfg => (
<tr key={cfg.id || cfg.model_name}>
<td style={{ fontWeight: 600, fontSize: '0.875rem' }}>{cfg.model_name}</td>
<td>
{cfg.node_selector ? (() => {
try {
const sel = typeof cfg.node_selector === 'string' ? JSON.parse(cfg.node_selector) : cfg.node_selector
return Object.entries(sel).map(([k,v]) => (
<span key={k} style={{
display: 'inline-block', fontSize: '0.75rem', padding: '2px 6px', borderRadius: 3,
background: 'var(--color-bg-tertiary)', border: '1px solid var(--color-border-subtle)',
fontFamily: "'JetBrains Mono', monospace", marginRight: 4,
}}>{k}={v}</span>
))
} catch { return <span style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}>{cfg.node_selector}</span> }
})() : <span style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}>Any node</span>}
</td>
<td style={{ fontFamily: "'JetBrains Mono', monospace" }}>{cfg.min_replicas || '-'}</td>
<td style={{ fontFamily: "'JetBrains Mono', monospace" }}>{cfg.max_replicas || 'unlimited'}</td>
<td style={{ textAlign: 'right' }}>
<button className="btn btn-danger btn-sm" onClick={async () => {
try {
await nodesApi.deleteScheduling(cfg.model_name)
fetchScheduling()
addToast('Rule deleted', 'success')
} catch (err) {
addToast(`Failed to delete rule: ${err.message}`, 'error')
}
}}><i className="fas fa-trash" /></button>
</td>
</tr>
))}
</tbody>
</table>
</div>
)}
</div>
)}
<ConfirmDialog
open={!!confirmDelete}

View File

@@ -266,9 +266,6 @@ export default function Settings() {
<SettingRow label="Max Active Backends" description="Maximum models to keep loaded simultaneously (0 = unlimited)">
<input className="input" type="number" style={{ width: 120 }} value={settings.max_active_backends ?? ''} onChange={(e) => update('max_active_backends', parseInt(e.target.value) || 0)} placeholder="0" />
</SettingRow>
<SettingRow label="Parallel Backend Requests" description="Enable parallel request handling per backend">
<Toggle checked={settings.parallel_backend_requests} onChange={(v) => update('parallel_backend_requests', v)} />
</SettingRow>
</div>
</div>

View File

@@ -423,6 +423,13 @@ export const nodesApi = {
deleteBackend: (id, backend) => postJSON(API_CONFIG.endpoints.nodeBackendsDelete(id), { backend }),
getBackendLogs: (id) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogs(id)),
getBackendLogLines: (id, modelId) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogsModel(id, modelId)),
unloadModel: (id, modelName) => postJSON(API_CONFIG.endpoints.nodeModelsUnload(id), { model_name: modelName }),
getLabels: (id) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id)),
mergeLabels: (id, labels) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id), { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(labels) }),
deleteLabel: (id, key) => fetchJSON(API_CONFIG.endpoints.nodeLabelKey(id, key), { method: 'DELETE' }),
listScheduling: () => fetchJSON(API_CONFIG.endpoints.nodesScheduling),
setScheduling: (config) => postJSON(API_CONFIG.endpoints.nodesScheduling, config),
deleteScheduling: (model) => fetchJSON(API_CONFIG.endpoints.nodesSchedulingModel(model), { method: 'DELETE' }),
}
// File to base64 helper

View File

@@ -107,5 +107,10 @@ export const API_CONFIG = {
nodeBackendsDelete: (id) => `/api/nodes/${id}/backends/delete`,
nodeBackendLogs: (id) => `/api/nodes/${id}/backend-logs`,
nodeBackendLogsModel: (id, modelId) => `/api/nodes/${id}/backend-logs/${encodeURIComponent(modelId)}`,
nodeModelsUnload: (id) => `/api/nodes/${id}/models/unload`,
nodeLabels: (id) => `/api/nodes/${id}/labels`,
nodeLabelKey: (id, key) => `/api/nodes/${id}/labels/${key}`,
nodesScheduling: '/api/nodes/scheduling',
nodesSchedulingModel: (model) => `/api/nodes/scheduling/${encodeURIComponent(model)}`,
},
}

View File

@@ -61,6 +61,13 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin := e.Group("/api/nodes", readyMw, adminMw)
admin.GET("", localai.ListNodesEndpoint(registry))
// Model scheduling (registered before /:id to avoid route conflicts)
admin.GET("/scheduling", localai.ListSchedulingEndpoint(registry))
admin.GET("/scheduling/:model", localai.GetSchedulingEndpoint(registry))
admin.POST("/scheduling", localai.SetSchedulingEndpoint(registry))
admin.DELETE("/scheduling/:model", localai.DeleteSchedulingEndpoint(registry))
admin.GET("/:id", localai.GetNodeEndpoint(registry))
admin.GET("/:id/models", localai.GetNodeModelsEndpoint(registry))
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
@@ -80,6 +87,12 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin.GET("/:id/backend-logs", localai.NodeBackendLogsListEndpoint(registry, registrationToken))
admin.GET("/:id/backend-logs/:modelId", localai.NodeBackendLogsLinesEndpoint(registry, registrationToken))
// Label management
admin.GET("/:id/labels", localai.GetNodeLabelsEndpoint(registry))
admin.PUT("/:id/labels", localai.SetNodeLabelsEndpoint(registry))
admin.PATCH("/:id/labels", localai.MergeNodeLabelsEndpoint(registry))
admin.DELETE("/:id/labels/:key", localai.DeleteNodeLabelEndpoint(registry))
// WebSocket proxy for real-time log streaming from workers
e.GET("/ws/nodes/:id/backend-logs/:modelId", localai.NodeBackendLogsWSEndpoint(registry, registrationToken), readyMw, adminMw)
}

View File

@@ -21,6 +21,12 @@ type ModelRouter interface {
FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (*NodeModel, error)
FindLRUModel(ctx context.Context, nodeID string) (*NodeModel, error)
Get(ctx context.Context, nodeID string) (*BackendNode, error)
GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error)
FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error)
FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error)
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
}
// NodeHealthStore is used by HealthMonitor for node status management.

View File

@@ -72,6 +72,24 @@ func (f *fakeModelRouterForSmartRouter) Get(_ context.Context, nodeID string) (*
}
return nil, nil
}
// The following stubs satisfy the scheduling/label additions to the
// ModelRouter interface; they return zero values because these tests do
// not exercise selector-based placement or labels.
func (f *fakeModelRouterForSmartRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
	return nil, nil
}

func (f *fakeModelRouterForSmartRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
	return nil, nil
}

func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
	return nil, nil
}

func (f *fakeModelRouterForSmartRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return nil, nil
}

func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return nil, nil
}

func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
	return nil, nil
}
// Compile-time check
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)

View File

@@ -0,0 +1,236 @@
package nodes
import (
"context"
"encoding/json"
"time"
"github.com/mudler/LocalAI/core/services/advisorylock"
"github.com/mudler/xlog"
"gorm.io/gorm"
)
// ReplicaReconciler periodically ensures model replica counts match their
// scheduling configs. It scales up replicas when below MinReplicas or when
// all replicas are busy (up to MaxReplicas), and scales down idle replicas
// above MinReplicas.
//
// Only processes models with auto-scaling enabled (MinReplicas > 0 or MaxReplicas > 0).
type ReplicaReconciler struct {
	registry       *NodeRegistry     // node/model registry backing all queries
	scheduler      ModelScheduler    // interface for scheduling new models; scale-up is skipped when nil
	unloader       NodeCommandSender // sends unload commands to workers; scale-down is skipped when nil
	db             *gorm.DB          // optional; when set, an advisory lock serializes passes across instances
	interval       time.Duration     // time between reconciliation passes (default 30s)
	scaleDownDelay time.Duration     // minimum idle time before a replica may be removed (default 5m)
}
// ModelScheduler abstracts the scheduling logic needed by the reconciler.
// SmartRouter implements this interface.
type ModelScheduler interface {
	// ScheduleAndLoadModel picks a node (optionally from candidateNodeIDs;
	// the reconciler passes an empty slice when no node selector is set),
	// installs the backend, and loads the model. Returns the node used.
	ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error)
}
// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler.
type ReplicaReconcilerOptions struct {
	Registry       *NodeRegistry     // registry backing replica counts and node lookups
	Scheduler      ModelScheduler    // places new replicas; scale-up is skipped when nil
	Unloader       NodeCommandSender // unloads replicas from workers; scale-down is skipped when nil
	DB             *gorm.DB          // optional; enables the cross-instance advisory lock
	Interval       time.Duration     // default 30s
	ScaleDownDelay time.Duration     // default 5m
}
// NewReplicaReconciler creates a new ReplicaReconciler, applying the
// default interval (30s) and scale-down delay (5m) for unset options.
func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
	rc := &ReplicaReconciler{
		registry:       opts.Registry,
		scheduler:      opts.Scheduler,
		unloader:       opts.Unloader,
		db:             opts.DB,
		interval:       opts.Interval,
		scaleDownDelay: opts.ScaleDownDelay,
	}
	if rc.interval == 0 {
		rc.interval = 30 * time.Second
	}
	if rc.scaleDownDelay == 0 {
		rc.scaleDownDelay = 5 * time.Minute
	}
	return rc
}
// Run starts the reconciliation loop. It blocks until ctx is cancelled.
func (rc *ReplicaReconciler) Run(ctx context.Context) {
	t := time.NewTicker(rc.interval)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			rc.reconcileOnce(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// reconcileOnce performs a single reconciliation pass.
// Uses an advisory lock so only one frontend instance reconciles at a time.
func (rc *ReplicaReconciler) reconcileOnce(ctx context.Context) {
	if rc.db == nil {
		// No database means no coordination is possible; reconcile directly.
		rc.reconcile(ctx)
		return
	}
	key := advisorylock.KeyFromString("replica-reconciler")
	// The lock error is deliberately discarded: if another instance holds
	// the lock, this pass is simply skipped.
	_ = advisorylock.WithLockCtx(ctx, rc.db, key, func() error {
		rc.reconcile(ctx)
		return nil
	})
}
// reconcile walks every auto-scaling config and reconciles each model in
// turn, bailing out as soon as the context is cancelled.
func (rc *ReplicaReconciler) reconcile(ctx context.Context) {
	autoscaled, err := rc.registry.ListAutoScalingConfigs(ctx)
	if err != nil {
		xlog.Warn("Reconciler: failed to list auto-scaling configs", "error", err)
		return
	}
	for i := range autoscaled {
		if ctx.Err() != nil {
			// Context cancelled; stop mid-pass.
			return
		}
		rc.reconcileModel(ctx, autoscaled[i])
	}
}
// reconcileModel reconciles a single model's replica count against its
// scheduling config. The order of the three steps matters:
//
//  1. Below MinReplicas: scale up to the minimum and return; busy/idle
//     state is re-evaluated on the next pass with the new replicas.
//  2. All replicas busy: add one replica, bounded by MaxReplicas
//     (0 = unlimited).
//  3. Remove idle replicas above the floor. The floor is MinReplicas but
//     at least 1, so the reconciler never unloads a model entirely.
func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) {
	current, err := rc.registry.CountLoadedReplicas(ctx, cfg.ModelName)
	if err != nil {
		xlog.Warn("Reconciler: failed to count replicas", "model", cfg.ModelName, "error", err)
		return
	}
	// 1. Ensure minimum replicas
	if cfg.MinReplicas > 0 && int(current) < cfg.MinReplicas {
		needed := cfg.MinReplicas - int(current)
		xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
			"current", current, "min", cfg.MinReplicas, "adding", needed)
		rc.scaleUp(ctx, cfg, needed)
		return
	}
	// 2. Auto-scale up if all replicas are busy
	if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
		if rc.allReplicasBusy(ctx, cfg.ModelName) {
			xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
				"current", current)
			rc.scaleUp(ctx, cfg, 1)
		}
	}
	// 3. Scale down idle replicas above minimum
	// (safe to run even after a step-2 scale-up: scaleDownIdle only removes
	// replicas with in_flight == 0, and those were all busy)
	floor := cfg.MinReplicas
	if floor < 1 {
		floor = 1
	}
	if int(current) > floor {
		rc.scaleDownIdle(ctx, cfg, int(current), floor)
	}
}
// scaleUp schedules additional replicas of the model, honoring the
// config's node selector when one is set.
func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
	if rc.scheduler == nil {
		xlog.Warn("Reconciler: no scheduler available, cannot scale up")
		return
	}
	// Restrict placement to nodes matching the selector, when one is set.
	var candidateIDs []string
	if cfg.NodeSelector != "" {
		if selector := parseSelector(cfg.NodeSelector); len(selector) > 0 {
			matched, err := rc.registry.FindNodesBySelector(ctx, selector)
			if err != nil || len(matched) == 0 {
				xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName,
					"selector", cfg.NodeSelector)
				return
			}
			candidateIDs = make([]string, 0, len(matched))
			for _, n := range matched {
				candidateIDs = append(candidateIDs, n.ID)
			}
		}
	}
	for attempt := 1; attempt <= count; attempt++ {
		node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateIDs)
		if err != nil {
			xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
				"attempt", attempt, "error", err)
			return // stop trying on first failure; the next pass retries
		}
		xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
	}
}
// scaleDownIdle removes idle replicas above the floor.
//
// A replica qualifies for removal only when it is loaded, has no in-flight
// requests, and has been unused for at least scaleDownDelay; the least
// recently used replicas are removed first, at most (current - floor) of them.
func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedulingConfig, current, floor int) {
	if rc.unloader == nil {
		return
	}
	// Find idle replicas that have been unused for longer than scaleDownDelay
	cutoff := time.Now().Add(-rc.scaleDownDelay)
	var idleModels []NodeModel
	// NOTE(review): the query error is ignored — Find leaves idleModels
	// empty on failure, making this pass a no-op. Confirm that best-effort
	// behavior is intended.
	rc.registry.db.WithContext(ctx).
		Where("model_name = ? AND state = ? AND in_flight = 0 AND last_used < ?",
			cfg.ModelName, "loaded", cutoff).
		Order("last_used ASC").
		Find(&idleModels)
	toRemove := current - floor
	removed := 0
	for _, nm := range idleModels {
		if removed >= toRemove {
			break
		}
		// Remove from registry
		if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName); err != nil {
			xlog.Warn("Reconciler: failed to remove model record", "error", err)
			continue
		}
		// Unload from worker; the registry record is already gone, so a
		// failure here leaves the model loaded only on the worker side.
		if err := rc.unloader.UnloadModelOnNode(nm.NodeID, nm.ModelName); err != nil {
			xlog.Warn("Reconciler: unload failed (model already removed from registry)", "error", err)
		}
		xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID)
		removed++
	}
}
// allReplicasBusy returns true if all loaded replicas of a model have
// in-flight requests (i.e. no loaded replica with in_flight == 0 exists).
func (rc *ReplicaReconciler) allReplicasBusy(ctx context.Context, modelName string) bool {
	var idleCount int64
	// Fix: the Count error was previously ignored. On a failed query
	// idleCount stayed 0 and the function reported "all busy", which would
	// trigger a spurious scale-up. Be conservative and report not-busy.
	if err := rc.registry.db.WithContext(ctx).Model(&NodeModel{}).
		Where("model_name = ? AND state = ? AND in_flight = 0", modelName, "loaded").
		Count(&idleCount).Error; err != nil {
		xlog.Warn("Reconciler: failed to count idle replicas", "model", modelName, "error", err)
		return false
	}
	return idleCount == 0
}
// parseSelector decodes a JSON node selector string into a map.
// It returns nil for an empty string or malformed JSON (the parse failure
// is logged, not propagated).
func parseSelector(selectorJSON string) map[string]string {
	if len(selectorJSON) == 0 {
		return nil
	}
	var parsed map[string]string
	err := json.Unmarshal([]byte(selectorJSON), &parsed)
	if err != nil {
		xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
		return nil
	}
	return parsed
}

View File

@@ -0,0 +1,241 @@
package nodes
import (
"context"
"runtime"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/testutil"
"gorm.io/gorm"
)
// ---------------------------------------------------------------------------
// Fake ModelScheduler
// ---------------------------------------------------------------------------
// fakeScheduler is a test double for the ModelScheduler interface. It records
// every ScheduleAndLoadModel call and returns the canned node/error pair.
type fakeScheduler struct {
	scheduleNode  *BackendNode   // node returned by every call
	scheduleErr   error          // error returned by every call
	scheduleCalls []scheduleCall // recorded invocations, in order
}
// scheduleCall captures the arguments of one ScheduleAndLoadModel invocation.
type scheduleCall struct {
	modelName    string
	candidateIDs []string
}
// ScheduleAndLoadModel records the call and returns the preconfigured result.
func (f *fakeScheduler) ScheduleAndLoadModel(_ context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
	f.scheduleCalls = append(f.scheduleCalls, scheduleCall{modelName, candidateNodeIDs})
	return f.scheduleNode, f.scheduleErr
}
// ReplicaReconciler behavior tests. Each spec drives reconcile() directly
// against a real database (testutil.SetupTestDB) with fake scheduler/unloader
// doubles, covering: min-replica scale-up, busy-based auto-scale, max-replica
// capping, idle scale-down above the floor, and selector-constrained
// scheduling.
var _ = Describe("ReplicaReconciler", func() {
	var (
		db       *gorm.DB
		registry *NodeRegistry
	)
	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()
		var err error
		registry, err = NewNodeRegistry(db)
		Expect(err).ToNot(HaveOccurred())
	})
	// Helper to register a healthy node.
	registerNode := func(name, address string) *BackendNode {
		node := &BackendNode{
			Name:     name,
			NodeType: NodeTypeBackend,
			Address:  address,
		}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		return node
	}
	// Helper to set up a scheduling config.
	setSchedulingConfig := func(modelName string, minReplicas, maxReplicas int, nodeSelector string) {
		cfg := &ModelSchedulingConfig{
			ModelName:    modelName,
			MinReplicas:  minReplicas,
			MaxReplicas:  maxReplicas,
			NodeSelector: nodeSelector,
		}
		Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
	}
	Context("model below min_replicas", func() {
		It("scales up to min_replicas", func() {
			node := registerNode("node-1", "10.0.0.1:50051")
			setSchedulingConfig("model-a", 2, 4, "")
			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})
			// No replicas loaded — should schedule 2
			reconciler.reconcile(context.Background())
			Expect(scheduler.scheduleCalls).To(HaveLen(2))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-a"))
			Expect(scheduler.scheduleCalls[1].modelName).To(Equal("model-a"))
		})
	})
	Context("all replicas busy and below max_replicas", func() {
		It("scales up by 1", func() {
			node := registerNode("node-busy", "10.0.0.2:50051")
			setSchedulingConfig("model-b", 1, 4, "")
			// Load 2 replicas, both busy (in_flight > 0)
			Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", "loaded", "addr1", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b")).To(Succeed())
			node2 := registerNode("node-busy-2", "10.0.0.3:50051")
			Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", "loaded", "addr2", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b")).To(Succeed())
			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})
			reconciler.reconcile(context.Background())
			Expect(scheduler.scheduleCalls).To(HaveLen(1))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-b"))
		})
	})
	Context("all replicas busy and at max_replicas", func() {
		It("does not scale up", func() {
			node := registerNode("node-max", "10.0.0.4:50051")
			setSchedulingConfig("model-c", 1, 2, "")
			// Load 2 replicas (at max), both busy
			Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", "loaded", "addr1", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c")).To(Succeed())
			node2 := registerNode("node-max-2", "10.0.0.5:50051")
			Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", "loaded", "addr2", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c")).To(Succeed())
			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})
			reconciler.reconcile(context.Background())
			Expect(scheduler.scheduleCalls).To(BeEmpty())
		})
	})
	Context("idle replicas above min_replicas", func() {
		It("scales down after idle delay", func() {
			node1 := registerNode("node-idle-1", "10.0.0.6:50051")
			node2 := registerNode("node-idle-2", "10.0.0.7:50051")
			node3 := registerNode("node-idle-3", "10.0.0.8:50051")
			setSchedulingConfig("model-d", 1, 4, "")
			// Load 3 replicas, all idle with last_used in the past
			pastTime := time.Now().Add(-10 * time.Minute)
			for _, n := range []*BackendNode{node1, node2, node3} {
				Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", "loaded", "", 0)).To(Succeed())
				// Set last_used to past time to trigger scale-down
				db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-d").
					Update("last_used", pastTime)
			}
			unloader := &fakeUnloader{}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:       registry,
				Unloader:       unloader,
				DB:             db,
				ScaleDownDelay: 1 * time.Minute, // short delay for test
			})
			reconciler.reconcile(context.Background())
			// Should scale down 2 replicas (3 - floor of 1)
			Expect(unloader.unloadCalls).To(HaveLen(2))
		})
	})
	Context("idle replicas at min_replicas", func() {
		It("does not scale down", func() {
			node1 := registerNode("node-keep-1", "10.0.0.9:50051")
			node2 := registerNode("node-keep-2", "10.0.0.10:50051")
			setSchedulingConfig("model-e", 2, 4, "")
			// Load exactly 2 replicas (at min), both idle with past last_used
			pastTime := time.Now().Add(-10 * time.Minute)
			for _, n := range []*BackendNode{node1, node2} {
				Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", "loaded", "", 0)).To(Succeed())
				db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-e").
					Update("last_used", pastTime)
			}
			unloader := &fakeUnloader{}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:       registry,
				Unloader:       unloader,
				DB:             db,
				ScaleDownDelay: 1 * time.Minute,
			})
			reconciler.reconcile(context.Background())
			Expect(unloader.unloadCalls).To(BeEmpty())
		})
	})
	Context("model with node_selector", func() {
		It("passes candidate node IDs to scheduler", func() {
			node1 := registerNode("gpu-node", "10.0.0.11:50051")
			node2 := registerNode("cpu-node", "10.0.0.12:50051")
			// Add labels — only node1 matches the selector
			Expect(registry.SetNodeLabel(context.Background(), node1.ID, "gpu.vendor", "nvidia")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node2.ID, "gpu.vendor", "none")).To(Succeed())
			setSchedulingConfig("model-f", 1, 2, `{"gpu.vendor":"nvidia"}`)
			scheduler := &fakeScheduler{
				scheduleNode: node1,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})
			// No replicas loaded — should schedule 1 with candidate node IDs
			reconciler.reconcile(context.Background())
			Expect(scheduler.scheduleCalls).To(HaveLen(1))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-f"))
			Expect(scheduler.scheduleCalls[0].candidateIDs).To(ContainElement(node1.ID))
			Expect(scheduler.scheduleCalls[0].candidateIDs).ToNot(ContainElement(node2.ID))
		})
	})
})

View File

@@ -68,6 +68,39 @@ type NodeModel struct {
UpdatedAt time.Time `json:"updated_at"`
}
// NodeLabel is a key-value label on a node (like K8s labels).
// The composite unique index (node_id, key) guarantees at most one value per
// key per node; SetNodeLabel's ON CONFLICT upsert relies on it.
type NodeLabel struct {
	ID     string `gorm:"primaryKey;size:36" json:"id"`
	NodeID string `gorm:"uniqueIndex:idx_node_label;size:36" json:"node_id"`
	Key    string `gorm:"uniqueIndex:idx_node_label;size:128" json:"key"`
	Value  string `gorm:"size:255" json:"value"`
}
// ModelSchedulingConfig defines how a model should be scheduled across the cluster.
// All fields are optional:
//   - NodeSelector only → constrain nodes, single replica
//   - MinReplicas/MaxReplicas only → auto-scale on any node
//   - Both → auto-scale on matching nodes
//   - Neither → no-op (default behavior)
//
// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0
// (see ListAutoScalingConfigs). A unique index on ModelName means there is at
// most one config per model; SetModelScheduling upserts against it.
type ModelSchedulingConfig struct {
	ID           string    `gorm:"primaryKey;size:36" json:"id"`
	ModelName    string    `gorm:"uniqueIndex;size:255" json:"model_name"`
	NodeSelector string    `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
	MinReplicas  int       `gorm:"default:0" json:"min_replicas"`
	MaxReplicas  int       `gorm:"default:0" json:"max_replicas"`
	CreatedAt    time.Time `json:"created_at"`
	UpdatedAt    time.Time `json:"updated_at"`
}
// NodeWithExtras extends BackendNode with computed fields for list views
// (see ListWithExtras): the number of loaded models and the node's labels.
type NodeWithExtras struct {
	BackendNode
	ModelCount int               `json:"model_count"`
	Labels     map[string]string `json:"labels,omitempty"`
}
// NodeRegistry manages backend node registration and lookup in PostgreSQL.
type NodeRegistry struct {
db *gorm.DB
@@ -78,7 +111,7 @@ type NodeRegistry struct {
// when multiple instances (frontend + workers) start at the same time.
func NewNodeRegistry(db *gorm.DB) (*NodeRegistry, error) {
if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error {
return db.AutoMigrate(&BackendNode{}, &NodeModel{})
return db.AutoMigrate(&BackendNode{}, &NodeModel{}, &NodeLabel{}, &ModelSchedulingConfig{})
}); err != nil {
return nil, fmt.Errorf("migrating node tables: %w", err)
}
@@ -207,18 +240,19 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
Group("node_id")
// Try idle nodes with enough VRAM first
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
var node BackendNode
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
Order("available_vram DESC").
First(&node).Error
if err == nil {
return &node, nil
}
// Fall back to least-loaded nodes with enough VRAM
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes).
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
Order("COALESCE(load.total_inflight, 0) ASC").
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err)
@@ -407,9 +441,12 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
var node BackendNode
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Order by in_flight ASC (least busy replica), then by available_vram DESC
// (prefer nodes with more free VRAM to spread load across the cluster).
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("model_name = ? AND state = ?", modelName, "loaded").
Order("in_flight ASC").
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded").
Order("node_models.in_flight ASC, backend_nodes.available_vram DESC").
First(&nm).Error; err != nil {
return err
}
@@ -463,7 +500,7 @@ func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, e
Group("node_id")
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
Order("COALESCE(load.total_inflight, 0) ASC").
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("finding least loaded node: %w", err)
@@ -482,6 +519,7 @@ func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) {
Where("state = ?", "loaded").
Group("node_id")
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, loadedModels).
Order("available_vram DESC").
First(&node).Error
if err != nil {
return nil, err
@@ -578,3 +616,287 @@ func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (
}
return &nm, nil
}
// --- NodeLabel operations ---
// SetNodeLabel upserts a single label on a node.
// SetNodeLabel upserts a single label on a node.
// A fresh UUID is generated for the insert path; on conflict with the
// (node_id, key) unique index only the value column is updated, so the
// existing row keeps its original ID.
func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error {
	label := NodeLabel{
		ID:     uuid.New().String(),
		NodeID: nodeID,
		Key:    key,
		Value:  value,
	}
	return r.db.WithContext(ctx).
		Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "node_id"}, {Name: "key"}},
			DoUpdates: clause.AssignmentColumns([]string{"value"}),
		}).
		Create(&label).Error
}
// SetNodeLabels replaces all labels for a node with the given map.
// Delete + insert run in one transaction so readers never observe a
// partially-replaced label set. An empty (or nil) map simply clears all labels.
func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil {
			return err
		}
		if len(labels) == 0 {
			// Nothing to insert; gorm's Create errors on an empty slice.
			return nil
		}
		// Batch insert: one statement instead of one INSERT per label.
		rows := make([]NodeLabel, 0, len(labels))
		for k, v := range labels {
			rows = append(rows, NodeLabel{ID: uuid.New().String(), NodeID: nodeID, Key: k, Value: v})
		}
		return tx.Create(&rows).Error
	})
}
// RemoveNodeLabel removes a single label from a node.
// Deleting a key that does not exist is not an error.
func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error {
	db := r.db.WithContext(ctx)
	return db.
		Where("node_id = ? AND key = ?", nodeID, key).
		Delete(&NodeLabel{}).
		Error
}
// GetNodeLabels returns all labels for a node (empty slice if it has none).
func (r *NodeRegistry) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) {
	var out []NodeLabel
	err := r.db.WithContext(ctx).
		Where("node_id = ?", nodeID).
		Find(&out).
		Error
	return out, err
}
// GetAllNodeLabelsMap returns all labels grouped by node ID.
// Nodes without labels simply have no entry in the result map.
func (r *NodeRegistry) GetAllNodeLabelsMap(ctx context.Context) (map[string]map[string]string, error) {
	var rows []NodeLabel
	if err := r.db.WithContext(ctx).Find(&rows).Error; err != nil {
		return nil, err
	}
	grouped := make(map[string]map[string]string)
	for _, row := range rows {
		perNode, ok := grouped[row.NodeID]
		if !ok {
			perNode = make(map[string]string)
			grouped[row.NodeID] = perNode
		}
		perNode[row.Key] = row.Value
	}
	return grouped, nil
}
// --- Selector-based queries ---
// FindNodesBySelector returns healthy backend nodes matching ALL key-value
// pairs in the selector (K8s-style AND semantics). An empty selector matches
// every healthy backend node — the loop below then adds no constraints, so
// the previous special-cased early branch was redundant and has been folded in.
func (r *NodeRegistry) FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) {
	db := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
	// One EXISTS subquery per pair: a node qualifies only if it carries every label.
	for k, v := range selector {
		db = db.Where("EXISTS (SELECT 1 FROM node_labels WHERE node_labels.node_id = backend_nodes.id AND node_labels.key = ? AND node_labels.value = ?)", k, v)
	}
	var nodes []BackendNode
	err := db.Find(&nodes).Error
	return nodes, err
}
// FindNodeWithVRAMFromSet is like FindNodeWithVRAM but restricted to the given node IDs.
//
// Two-stage selection:
//  1. Prefer a completely idle node (no loaded models at all) with enough
//     VRAM, picking the one with the most available VRAM.
//  2. Otherwise fall back to any node in the set with enough VRAM, ordered by
//     total in-flight requests ascending, with available VRAM descending as
//     the tiebreaker.
//
// NOTE: an empty nodeIDs slice matches no nodes ("id IN ()"), so callers must
// pass a non-empty set.
func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) {
	db := r.db.WithContext(ctx)
	// Subquery: node IDs that currently have at least one loaded model.
	loadedModels := db.Model(&NodeModel{}).
		Select("node_id").
		Where("state = ?", "loaded").
		Group("node_id")
	// Subquery: total in-flight requests per node, for load-based ordering.
	subquery := db.Model(&NodeModel{}).
		Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
		Group("node_id")
	// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
	var node BackendNode
	err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
		Order("available_vram DESC").
		First(&node).Error
	if err == nil {
		return &node, nil
	}
	// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
	err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
		Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
		Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
		First(&node).Error
	if err != nil {
		return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err)
	}
	return &node, nil
}
// FindIdleNodeFromSet is like FindIdleNode but restricted to the given node IDs.
// "Idle" means the node has no loaded models at all; among idle candidates the
// one with the most available VRAM wins. The raw gorm error (typically
// ErrRecordNotFound) is returned unwrapped, matching FindIdleNode.
func (r *NodeRegistry) FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
	db := r.db.WithContext(ctx)
	var node BackendNode
	// Subquery: node IDs that currently have at least one loaded model.
	loadedModels := db.Model(&NodeModel{}).
		Select("node_id").
		Where("state = ?", "loaded").
		Group("node_id")
	err := db.Where("status = ? AND node_type = ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, loadedModels, nodeIDs).
		Order("available_vram DESC").
		First(&node).Error
	if err != nil {
		return nil, err
	}
	return &node, nil
}
// FindLeastLoadedNodeFromSet is like FindLeastLoadedNode but restricted to the
// given node IDs. Nodes are ordered by total in-flight requests across all
// their models (ascending), with available VRAM descending as the tiebreaker;
// nodes with no NodeModel rows count as zero in-flight via the LEFT JOIN.
func (r *NodeRegistry) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
	db := r.db.WithContext(ctx)
	var node BackendNode
	query := db.Where("status = ? AND node_type = ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, nodeIDs)
	// Order by total in-flight across all models on the node
	subquery := db.Model(&NodeModel{}).
		Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
		Group("node_id")
	err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
		Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
		First(&node).Error
	if err != nil {
		return nil, fmt.Errorf("finding least loaded node in set: %w", err)
	}
	return &node, nil
}
// --- ModelSchedulingConfig operations ---
// SetModelScheduling creates or updates a scheduling config for a model.
// The upsert is keyed on the model_name unique index; on conflict only the
// selector/replica columns and updated_at are overwritten, so the existing
// row keeps its ID and created_at. A missing ID is filled in for the insert path.
func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSchedulingConfig) error {
	if config.ID == "" {
		config.ID = uuid.New().String()
	}
	return r.db.WithContext(ctx).
		Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "model_name"}},
			DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
		}).
		Create(config).Error
}
// GetModelScheduling returns the scheduling config for a model, or nil (with
// a nil error) if no config exists for that model.
func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) {
	var cfg ModelSchedulingConfig
	switch err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&cfg).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Absence of a config is not an error for callers.
		return nil, nil
	case err != nil:
		return nil, err
	default:
		return &cfg, nil
	}
}
// ListModelSchedulings returns all scheduling configs.
func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedulingConfig, error) {
	var all []ModelSchedulingConfig
	if err := r.db.WithContext(ctx).Find(&all).Error; err != nil {
		return nil, err
	}
	return all, nil
}
// ListAutoScalingConfigs returns scheduling configs where auto-scaling is
// enabled, i.e. MinReplicas > 0 or MaxReplicas > 0.
func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) {
	var autoScaled []ModelSchedulingConfig
	if err := r.db.WithContext(ctx).
		Where("min_replicas > 0 OR max_replicas > 0").
		Find(&autoScaled).Error; err != nil {
		return nil, err
	}
	return autoScaled, nil
}
// DeleteModelScheduling removes a scheduling config by model name.
// Deleting a model that has no config is not an error.
func (r *NodeRegistry) DeleteModelScheduling(ctx context.Context, modelName string) error {
	db := r.db.WithContext(ctx)
	return db.
		Where("model_name = ?", modelName).
		Delete(&ModelSchedulingConfig{}).
		Error
}
// CountLoadedReplicas returns the number of loaded replicas for a model.
func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string) (int64, error) {
	var n int64
	err := r.db.WithContext(ctx).
		Model(&NodeModel{}).
		Where("model_name = ? AND state = ?", modelName, "loaded").
		Count(&n).
		Error
	return n, err
}
// --- Composite queries ---
// ListWithExtras returns all nodes with model counts and labels.
//
// Only the initial node query is fatal; failures fetching model counts or
// labels are deliberately tolerated (logged, result fields left zero/empty)
// so a partial view can still be rendered.
func (r *NodeRegistry) ListWithExtras(ctx context.Context) ([]NodeWithExtras, error) {
	// Get all nodes
	var nodes []BackendNode
	if err := r.db.WithContext(ctx).Find(&nodes).Error; err != nil {
		return nil, err
	}
	// Get model counts per node (loaded models only).
	type modelCount struct {
		NodeID string
		Count  int
	}
	var counts []modelCount
	if err := r.db.WithContext(ctx).Model(&NodeModel{}).
		Select("node_id, COUNT(*) as count").
		Where("state = ?", "loaded").
		Group("node_id").
		Find(&counts).Error; err != nil {
		// Best-effort: counts stay zero for all nodes.
		xlog.Warn("ListWithExtras: failed to get model counts", "error", err)
	}
	countMap := make(map[string]int)
	for _, c := range counts {
		countMap[c.NodeID] = c.Count
	}
	// Get all labels (best-effort: nil map means no labels shown).
	labelsMap, err := r.GetAllNodeLabelsMap(ctx)
	if err != nil {
		xlog.Warn("ListWithExtras: failed to get labels", "error", err)
	}
	// Build result; map lookups yield zero value / nil for missing nodes.
	result := make([]NodeWithExtras, len(nodes))
	for i, n := range nodes {
		result[i] = NodeWithExtras{
			BackendNode: n,
			ModelCount:  countMap[n.ID],
			Labels:      labelsMap[n.ID],
		}
	}
	return result, nil
}
// ApplyAutoLabels sets automatic labels based on node hardware info:
// gpu.vendor, gpu.vram (a bucket string), and node.name.
// Label writes are deliberately best-effort; errors are discarded.
func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node *BackendNode) {
	if node.GPUVendor != "" {
		_ = r.SetNodeLabel(ctx, nodeID, "gpu.vendor", node.GPUVendor)
	}
	if node.TotalVRAM > 0 {
		// Integer GiB, truncated.
		gb := node.TotalVRAM / (1024 * 1024 * 1024)
		// Buckets are lower bounds: e.g. a 64GB card lands in "48GB".
		// NOTE(review): for sub-GiB VRAM (0 < TotalVRAM < 1GiB) the default
		// branch produces the label "0GB" — confirm that is intended.
		var bucket string
		switch {
		case gb >= 80:
			bucket = "80GB+"
		case gb >= 48:
			bucket = "48GB"
		case gb >= 24:
			bucket = "24GB"
		case gb >= 16:
			bucket = "16GB"
		case gb >= 8:
			bucket = "8GB"
		default:
			bucket = fmt.Sprintf("%dGB", gb)
		}
		_ = r.SetNodeLabel(ctx, nodeID, "gpu.vram", bucket)
	}
	if node.Name != "" {
		_ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name)
	}
}

View File

@@ -305,6 +305,252 @@ var _ = Describe("NodeRegistry", func() {
})
})
	// CRUD coverage for NodeLabel: set/get, upsert-overwrite of an existing
	// key, single-key removal, and full replacement via SetNodeLabels.
	Describe("NodeLabel CRUD", func() {
		It("sets and retrieves labels for a node", func() {
			node := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())
			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(2))
			labelMap := make(map[string]string)
			for _, l := range labels {
				labelMap[l.Key] = l.Value
			}
			Expect(labelMap["env"]).To(Equal("prod"))
			Expect(labelMap["region"]).To(Equal("us-east"))
		})
		It("overwrites existing label with same key", func() {
			node := makeNode("label-overwrite", "10.0.0.71:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "dev")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(1))
			Expect(labels[0].Value).To(Equal("prod"))
		})
		It("removes a single label by key", func() {
			node := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())
			Expect(registry.RemoveNodeLabel(context.Background(), node.ID, "env")).To(Succeed())
			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(1))
			Expect(labels[0].Key).To(Equal("region"))
		})
		It("SetNodeLabels replaces all labels", func() {
			node := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "old-key", "old-val")).To(Succeed())
			newLabels := map[string]string{"new-a": "val-a", "new-b": "val-b"}
			Expect(registry.SetNodeLabels(context.Background(), node.ID, newLabels)).To(Succeed())
			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(2))
			labelMap := make(map[string]string)
			for _, l := range labels {
				labelMap[l.Key] = l.Value
			}
			Expect(labelMap).To(Equal(newLabels))
		})
	})
	// Selector matching: AND semantics across all pairs, superset labels
	// allowed, unhealthy nodes excluded, empty selector matches everything.
	Describe("FindNodesBySelector", func() {
		It("returns nodes matching all labels in selector", func() {
			n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000)
			n2 := makeNode("sel-nomatch", "10.0.0.81:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed())
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(HaveLen(1))
			Expect(nodes[0].Name).To(Equal("sel-match"))
		})
		It("returns empty when no nodes match", func() {
			n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed())
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(BeEmpty())
		})
		It("ignores unhealthy nodes", func() {
			n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
			Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed())
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(BeEmpty())
		})
		It("matches nodes with more labels than selector requires", func() {
			n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed())
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(HaveLen(1))
			Expect(nodes[0].Name).To(Equal("sel-superset"))
		})
		It("returns all healthy nodes for empty selector", func() {
			n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000)
			n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{})
			Expect(err).ToNot(HaveOccurred())
			// >= because other specs in the suite may have registered nodes too.
			Expect(len(nodes)).To(BeNumerically(">=", 2))
		})
	})
	// CRUD coverage for ModelSchedulingConfig: create/get, upsert-update,
	// listing (all vs auto-scaling only), delete, and nil-for-missing.
	Describe("ModelSchedulingConfig CRUD", func() {
		It("creates and retrieves a scheduling config", func() {
			config := &ModelSchedulingConfig{
				ModelName:    "llama-7b",
				NodeSelector: `{"gpu.vendor":"nvidia"}`,
				MinReplicas:  1,
				MaxReplicas:  3,
			}
			Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())
			// SetModelScheduling fills in the ID on insert.
			Expect(config.ID).ToNot(BeEmpty())
			fetched, err := registry.GetModelScheduling(context.Background(), "llama-7b")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).ToNot(BeNil())
			Expect(fetched.ModelName).To(Equal("llama-7b"))
			Expect(fetched.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`))
			Expect(fetched.MinReplicas).To(Equal(1))
			Expect(fetched.MaxReplicas).To(Equal(3))
		})
		It("updates existing config via SetModelScheduling", func() {
			config := &ModelSchedulingConfig{
				ModelName:   "update-model",
				MinReplicas: 1,
				MaxReplicas: 2,
			}
			Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())
			config2 := &ModelSchedulingConfig{
				ModelName:   "update-model",
				MinReplicas: 2,
				MaxReplicas: 5,
			}
			Expect(registry.SetModelScheduling(context.Background(), config2)).To(Succeed())
			fetched, err := registry.GetModelScheduling(context.Background(), "update-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched.MinReplicas).To(Equal(2))
			Expect(fetched.MaxReplicas).To(Equal(5))
		})
		It("lists all configs", func() {
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())
			configs, err := registry.ListModelSchedulings(context.Background())
			Expect(err).ToNot(HaveOccurred())
			Expect(len(configs)).To(BeNumerically(">=", 2))
		})
		It("lists only auto-scaling configs", func() {
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed())
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed())
			// Selector-only config: not auto-scaling, must be excluded.
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed())
			configs, err := registry.ListAutoScalingConfigs(context.Background())
			Expect(err).ToNot(HaveOccurred())
			names := make([]string, len(configs))
			for i, c := range configs {
				names[i] = c.ModelName
			}
			Expect(names).To(ContainElement("auto-a"))
			Expect(names).To(ContainElement("auto-b"))
			Expect(names).ToNot(ContainElement("no-auto"))
		})
		It("deletes a config", func() {
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed())
			Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed())
			fetched, err := registry.GetModelScheduling(context.Background(), "delete-me")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).To(BeNil())
		})
		It("returns nil for non-existent model", func() {
			fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).To(BeNil())
		})
	})
	// CountLoadedReplicas: counts only rows in state "loaded" for the model.
	Describe("CountLoadedReplicas", func() {
		It("returns correct count of loaded replicas", func() {
			n1 := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000)
			n2 := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", "loaded", "", 0)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", "loaded", "", 0)).To(Succeed())
			count, err := registry.CountLoadedReplicas(context.Background(), "counted-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(count).To(Equal(int64(2)))
		})
		It("excludes non-loaded states", func() {
			n1 := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000)
			n2 := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", "loaded", "", 0)).To(Succeed())
			// "loading" state must not be counted.
			Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", "loading", "", 0)).To(Succeed())
			count, err := registry.CountLoadedReplicas(context.Background(), "state-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(count).To(Equal(int64(1)))
		})
	})
Describe("DecrementInFlight", func() {
It("does not go below zero", func() {
node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000)

View File

@@ -2,6 +2,7 @@ package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -69,6 +70,18 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
// Unloader returns the remote unloader adapter for external use.
// NOTE(review): per the NodeCommandSender name and its use elsewhere,
// this presumably sends node-level commands (e.g. backend install) — confirm.
func (r *SmartRouter) Unloader() NodeCommandSender { return r.unloader }
// ScheduleAndLoadModel implements ModelScheduler for the reconciler.
// It schedules a model on a suitable node (optionally from candidates) and loads it.
//
// NOTE(review): candidateNodeIDs is accepted but not forwarded here;
// scheduleNewModel re-derives eligible nodes from the model's scheduling
// config (node selector). Confirm the reconciler does not rely on the
// candidate list being honored directly.
func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
// Use scheduleNewModel with empty backend type and nil model options.
// The reconciler doesn't know the backend type — it will be determined by the model config.
node, _, err := r.scheduleNewModel(ctx, "", modelName, nil)
if err != nil {
return nil, err
}
return node, nil
}
// RouteResult contains the routing decision.
type RouteResult struct {
Node *BackendNode
@@ -110,18 +123,26 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
xlog.Warn("Backend not reachable for cached model, falling through to reload",
"node", node.Name, "model", modelName)
} else {
// Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
// and Release decrements it. No InFlightTrackingClient to avoid double-counting.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
return &RouteResult{
Node: node,
Client: grpcClient,
Release: func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
closeClient(grpcClient)
},
}, nil
// Verify node still matches scheduling constraints
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey)
// Fall through to step 2 (scheduleNewModel)
} else {
// Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
// and Release decrements it. No InFlightTrackingClient to avoid double-counting.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
return &RouteResult{
Node: node,
Client: grpcClient,
Release: func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
closeClient(grpcClient)
},
}, nil
}
}
}
@@ -143,17 +164,25 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
xlog.Warn("Backend not reachable for cached model inside lock, proceeding to load",
"node", node.Name, "model", modelName)
} else {
// Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
return &RouteResult{
Node: node,
Client: grpcClient,
Release: func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
closeClient(grpcClient)
},
}, nil
// Verify node still matches scheduling constraints
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey)
// Fall through to scheduling below
} else {
// Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
return &RouteResult{
Node: node,
Client: grpcClient,
Release: func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
closeClient(grpcClient)
},
}, nil
}
}
}
@@ -223,6 +252,59 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
return loadModel()
}
// parseSelectorJSON decodes a JSON node selector string into a map.
// An empty string yields nil; malformed JSON is logged as a warning and
// also yields nil rather than surfacing an error to the caller.
func parseSelectorJSON(selectorJSON string) map[string]string {
	var selector map[string]string
	if selectorJSON == "" {
		return nil
	}
	err := json.Unmarshal([]byte(selectorJSON), &selector)
	if err == nil {
		return selector
	}
	xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
	return nil
}
// extractNodeIDs collects the ID of every node in the slice, preserving order.
func extractNodeIDs(nodes []BackendNode) []string {
	ids := make([]string, 0, len(nodes))
	for _, node := range nodes {
		ids = append(ids, node.ID)
	}
	return ids
}
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints
// for a model. Returns true if no constraints exist or the node matches all
// selector labels. Registry lookup failures are treated as a match (fail open)
// so transient errors do not block routing.
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
	sched, schedErr := r.registry.GetModelScheduling(ctx, modelName)
	if schedErr != nil || sched == nil || sched.NodeSelector == "" {
		// No scheduling record for this model: nothing to enforce.
		return true
	}
	wanted := parseSelectorJSON(sched.NodeSelector)
	if len(wanted) == 0 {
		return true
	}
	labels, labelErr := r.registry.GetNodeLabels(ctx, node.ID)
	if labelErr != nil {
		xlog.Warn("Failed to get node labels for selector check", "node", node.ID, "error", labelErr)
		return true // fail open
	}
	have := make(map[string]string, len(labels))
	for _, label := range labels {
		have[label.Key] = label.Value
	}
	// Every selector pair must be present with the same value.
	for key, want := range wanted {
		if have[key] != want {
			return false
		}
	}
	return true
}
// scheduleNewModel picks the best node for loading a new model.
// Strategy: VRAM-aware → idle-first → least-loaded.
// Sends backend.install via NATS so the chosen node has the right backend running.
@@ -233,12 +315,30 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
estimatedVRAM = r.estimateModelVRAM(ctx, modelOpts)
}
// Check for scheduling constraints (node selector)
sched, _ := r.registry.GetModelScheduling(ctx, modelID)
var candidateNodeIDs []string // nil = all nodes eligible
if sched != nil && sched.NodeSelector != "" {
selector := parseSelectorJSON(sched.NodeSelector)
if len(selector) > 0 {
candidates, err := r.registry.FindNodesBySelector(ctx, selector)
if err != nil || len(candidates) == 0 {
return nil, "", fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector)
}
candidateNodeIDs = extractNodeIDs(candidates)
}
}
var node *BackendNode
var err error
if estimatedVRAM > 0 {
// 1. Prefer nodes with enough VRAM (idle-first, then least-loaded)
node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
if candidateNodeIDs != nil {
node, err = r.registry.FindNodeWithVRAMFromSet(ctx, estimatedVRAM, candidateNodeIDs)
} else {
node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
}
if err != nil {
xlog.Warn("No nodes with enough VRAM, falling back to standard scheduling",
"required_vram", vram.FormatBytes(estimatedVRAM), "error", err)
@@ -246,11 +346,16 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
}
if node == nil {
// 2. Prefer truly idle nodes (no loaded models, no in-flight)
node, err = r.registry.FindIdleNode(ctx)
if err != nil {
// 3. Fall back to least-loaded node (can run an additional backend process)
node, err = r.registry.FindLeastLoadedNode(ctx)
if candidateNodeIDs != nil {
node, err = r.registry.FindIdleNodeFromSet(ctx, candidateNodeIDs)
if err != nil {
node, err = r.registry.FindLeastLoadedNodeFromSet(ctx, candidateNodeIDs)
}
} else {
node, err = r.registry.FindIdleNode(ctx)
if err != nil {
node, err = r.registry.FindLeastLoadedNode(ctx)
}
}
}
@@ -610,7 +715,12 @@ func (r *SmartRouter) evictLRUAndFreeNode(ctx context.Context) (*BackendNode, er
// Lock the row so no other frontend can evict the same model
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
Where("node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?", "loaded", StatusHealthy).
Where(`node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?
AND (
NOT EXISTS (SELECT 1 FROM model_scheduling_configs sc WHERE sc.model_name = node_models.model_name AND (sc.min_replicas > 0 OR sc.max_replicas > 0))
OR (SELECT COUNT(*) FROM node_models nm2 WHERE nm2.model_name = node_models.model_name AND nm2.state = 'loaded')
> COALESCE((SELECT sc2.min_replicas FROM model_scheduling_configs sc2 WHERE sc2.model_name = node_models.model_name), 1)
)`, "loaded", StatusHealthy).
Order("node_models.last_used ASC").
First(&lru).Error; err != nil {
return err

View File

@@ -86,6 +86,26 @@ type fakeModelRouter struct {
getNode *BackendNode
getErr error
// GetModelScheduling returns
getModelScheduling *ModelSchedulingConfig
getModelSchedErr error
// FindNodesBySelector returns
findBySelectorNodes []BackendNode
findBySelectorErr error
// *FromSet variants
findVRAMFromSetNode *BackendNode
findVRAMFromSetErr error
findIdleFromSetNode *BackendNode
findIdleFromSetErr error
findLeastLoadedFromSetNode *BackendNode
findLeastLoadedFromSetErr error
// GetNodeLabels returns
getNodeLabels []NodeLabel
getNodeLabelsErr error
// Track calls for assertions
decrementCalls []string // "nodeID:modelName"
incrementCalls []string
@@ -146,6 +166,30 @@ func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error)
return f.getNode, f.getErr
}
// GetModelScheduling returns the canned config/error configured on the fake.
func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
return f.getModelScheduling, f.getModelSchedErr
}
// FindNodesBySelector returns the canned node list/error, ignoring the selector.
func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
return f.findBySelectorNodes, f.findBySelectorErr
}
// FindNodeWithVRAMFromSet returns the canned node/error, ignoring VRAM and set.
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
return f.findVRAMFromSetNode, f.findVRAMFromSetErr
}
// FindIdleNodeFromSet returns the canned node/error, ignoring the candidate set.
func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
return f.findIdleFromSetNode, f.findIdleFromSetErr
}
// FindLeastLoadedNodeFromSet returns the canned node/error, ignoring the set.
func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
}
// GetNodeLabels returns the canned label list/error, ignoring the node ID.
func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
return f.getNodeLabels, f.getNodeLabelsErr
}
// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------
@@ -478,6 +522,135 @@ var _ = Describe("SmartRouter", func() {
})
})
// Mock-based specs: when a model has a node selector, scheduling must route
// through the *FromSet registry methods; without a scheduling config it must
// use the regular (cluster-wide) lookups; an empty candidate set is an error.
Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
// No cached model anywhere, so Route always falls through to scheduling.
reg = &fakeModelRouter{
findAndLockErr: errors.New("not found"),
}
backend = &stubBackend{
loadResult: &pb.Result{Success: true},
}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
It("uses *FromSet methods when model has a node selector", func() {
gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
reg.getModelScheduling = &ModelSchedulingConfig{
ModelName: "selector-model",
NodeSelector: `{"gpu.vendor":"nvidia"}`,
}
// Selector resolves to the GPU node; the FromSet idle lookup returns it.
reg.findBySelectorNodes = []BackendNode{*gpuNode}
reg.findIdleFromSetNode = gpuNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("gpu-1"))
})
It("returns error when no nodes match selector", func() {
reg.getModelScheduling = &ModelSchedulingConfig{
ModelName: "no-match-model",
NodeSelector: `{"gpu.vendor":"tpu"}`,
}
// Selector lookup succeeds but yields zero nodes: scheduling must fail loudly.
reg.findBySelectorNodes = nil
reg.findBySelectorErr = nil
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
})
It("uses regular methods when model has no scheduling config", func() {
reg.getModelScheduling = nil
idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
// The plain (non-FromSet) idle-node lookup should be consulted.
reg.findIdleNode = idleNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("regular-1"))
})
})
// Mock-based spec: a model cached on a node that no longer satisfies the
// model's selector must be abandoned (in-flight decremented) and rescheduled
// onto a node that does match.
Describe("Route with selector validation on cached model (mock-based)", func() {
It("falls through when cached node no longer matches selector", func() {
cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}
backend := &stubBackend{
healthResult: true,
loadResult: &pb.Result{Success: true},
}
factory := &stubClientFactory{client: backend}
unloader := &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.71:9001",
},
}
reg := &fakeModelRouter{
// Step 1: cached model found on old node
findAndLockNode: cachedNode,
findAndLockNM: &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
// Scheduling config with selector that old node does NOT match
getModelScheduling: &ModelSchedulingConfig{
ModelName: "sel-model",
NodeSelector: `{"gpu.vendor":"nvidia"}`,
},
// Old node has no labels matching the selector
getNodeLabels: []NodeLabel{
{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
},
// For scheduling fallthrough: selector matches new node
findBySelectorNodes: []BackendNode{*newNode},
findIdleFromSetNode: newNode,
}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
// Should have fallen through to the new node
Expect(result.Node.ID).To(Equal("n-new"))
// Old node should have had its in-flight decremented
Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
})
})
// -----------------------------------------------------------------------
// Integration tests using real PostgreSQL (existing)
// -----------------------------------------------------------------------

View File

@@ -134,6 +134,47 @@ local-ai worker \
**HTTP file transfer:** Each worker also runs a small HTTP server for file transfer (model files, configs). By default it listens on the gRPC base port minus 1 (e.g., if the gRPC base is 50051, HTTP is on 50050). gRPC ports grow upward from the base port as additional models are loaded. Set `--advertise-http-addr` if the auto-detected address is not routable from the frontend.
{{% /notice %}}
### Worker Address Configuration
The simplest way to configure a worker's network address is with a single variable:
| Variable | Description |
|----------|-------------|
| `LOCALAI_ADDR` | Reachable address of this worker (`host:port`). The port is used as the base for gRPC backend processes, and `port-1` for the HTTP file transfer server. |
**Example:**
```yaml
environment:
LOCALAI_ADDR: "192.168.1.100:50051"
LOCALAI_NATS_URL: "nats://frontend:4222"
LOCALAI_REGISTER_TO: "http://frontend:8080"
LOCALAI_REGISTRATION_TOKEN: "my-secret"
```
For advanced networking scenarios (NAT, load balancers, separate gRPC/HTTP ports), the following override variables are available:
| Variable | Description | Default |
|----------|-------------|---------|
| `LOCALAI_SERVE_ADDR` | gRPC base port bind address | `0.0.0.0:50051` |
| `LOCALAI_HTTP_ADDR` | HTTP file transfer bind address | `0.0.0.0:{gRPC port - 1}` |
| `LOCALAI_ADVERTISE_ADDR` | Public gRPC address (if different from `LOCALAI_ADDR`) | Derived from `LOCALAI_ADDR` |
| `LOCALAI_ADVERTISE_HTTP_ADDR` | Public HTTP address (if different from gRPC host) | Derived from advertise host + HTTP port |
### Node Labels
Workers can declare labels at startup for scheduling constraints:
| Variable | Description | Example |
|----------|-------------|---------|
| `LOCALAI_NODE_LABELS` | Comma-separated `key=value` labels | `tier=premium,gpu=a100,zone=us-east` |
Labels can also be managed via the admin API (see [Label Management API](#label-management-api) below).
The system automatically applies hardware-detected labels on registration:
- `gpu.vendor` -- GPU vendor (nvidia, amd, intel, vulkan)
- `gpu.vram` -- GPU VRAM bucket (8GB, 16GB, 24GB, 48GB, 80GB+)
- `node.name` -- The node's registered name
### How Workers Operate
Workers start as generic processes with no backend installed. When the SmartRouter needs to load a model on a worker, it sends a NATS `backend.install` event with the backend name and model ID. The worker:
@@ -262,6 +303,75 @@ local-ai worker \
**Multiple frontend replicas:** Run multiple LocalAI frontends behind a load balancer. Since all state is in PostgreSQL and coordination is via NATS, frontends are fully stateless and interchangeable.
## Model Scheduling
Model scheduling controls where models are placed and how many replicas are maintained. It combines two optional features:
### Node Selectors
Pin models to nodes with specific labels. Only nodes matching **all** selector labels are eligible:
```bash
# Only schedule on NVIDIA nodes in the us-east zone
curl -X POST http://frontend:8080/api/nodes/scheduling \
-H "Content-Type: application/json" \
-d '{"model_name": "llama3", "node_selector": {"gpu.vendor": "nvidia", "zone": "us-east"}}'
```
Without a node selector, models can schedule on any healthy node (default behavior).
### Replica Auto-Scaling
Control the number of model replicas across the cluster:
| Field | Description |
|-------|-------------|
| `min_replicas` | Minimum replicas to maintain (0 = no minimum, single replica default) |
| `max_replicas` | Maximum replicas allowed (0 = unlimited) |
Auto-scaling is **only active** when `min_replicas > 0` or `max_replicas > 0`.
```bash
# Scale llama3 between 2 and 4 replicas on NVIDIA nodes
curl -X POST http://frontend:8080/api/nodes/scheduling \
-H "Content-Type: application/json" \
-d '{
"model_name": "llama3",
"node_selector": {"gpu.vendor": "nvidia"},
"min_replicas": 2,
"max_replicas": 4
}'
```
The **Replica Reconciler** runs as a background process on the frontend:
- **Scale up**: Adds replicas when all existing replicas are busy (have in-flight requests)
- **Scale down**: Removes idle replicas after 5 minutes of inactivity
- **Maintain minimum**: Ensures `min_replicas` are always loaded (recovers from node failures)
- **Eviction protection**: Models with auto-scaling enabled are never evicted below `min_replicas`
All fields are optional and composable:
- Node selector only: pin model to matching nodes, single replica
- Replicas only: auto-scale across all nodes
- Both: auto-scale on matching nodes only
## Label Management API
| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/api/nodes/:id/labels` | Get labels for a node |
| `PUT` | `/api/nodes/:id/labels` | Replace all labels (JSON object) |
| `PATCH` | `/api/nodes/:id/labels` | Merge labels (add/update) |
| `DELETE` | `/api/nodes/:id/labels/:key` | Remove a single label |
## Scheduling API
| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/api/nodes/scheduling` | List all scheduling configs |
| `GET` | `/api/nodes/scheduling/:model` | Get config for a model |
| `POST` | `/api/nodes/scheduling` | Create/update config |
| `DELETE` | `/api/nodes/scheduling/:model` | Remove config |
## Comparison with P2P
| | P2P / Federation | Distributed Mode |

View File

@@ -28,7 +28,6 @@ Changes to watchdog settings are applied immediately by restarting the watchdog
### Backend Configuration
- **Max Active Backends**: Maximum number of active backends (loaded models). When exceeded, the least recently used model is automatically evicted. Set to `0` for unlimited, `1` for single-backend mode
- **Parallel Backend Requests**: Enable backends to handle multiple requests in parallel if supported
- **Force Eviction When Busy**: Allow evicting models even when they have active API calls (default: disabled for safety). **Warning:** Enabling this can interrupt active requests
- **LRU Eviction Max Retries**: Maximum number of retries when waiting for busy models to become idle before eviction (default: 30)
- **LRU Eviction Retry Interval**: Interval between retries when waiting for busy models (default: `1s`)
@@ -123,7 +122,6 @@ The `runtime_settings.json` file follows this structure:
"watchdog_idle_timeout": "15m",
"watchdog_busy_timeout": "5m",
"max_active_backends": 0,
"parallel_backend_requests": true,
"force_eviction_when_busy": false,
"lru_eviction_max_retries": 30,
"lru_eviction_retry_interval": "1s",

View File

@@ -39,7 +39,6 @@ Complete reference for all LocalAI command-line interface (CLI) parameters and e
| `--external-grpc-backends` | | A list of external gRPC backends (format: `BACKEND_NAME:URI`) | `$LOCALAI_EXTERNAL_GRPC_BACKENDS`, `$EXTERNAL_GRPC_BACKENDS` |
| `--backend-galleries` | | JSON list of backend galleries | `$LOCALAI_BACKEND_GALLERIES`, `$BACKEND_GALLERIES` |
| `--autoload-backend-galleries` | `true` | Automatically load backend galleries on startup | `$LOCALAI_AUTOLOAD_BACKEND_GALLERIES`, `$AUTOLOAD_BACKEND_GALLERIES` |
| `--parallel-requests` | `false` | Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm) | `$LOCALAI_PARALLEL_REQUESTS`, `$PARALLEL_REQUESTS` |
| `--max-active-backends` | `0` | Maximum number of active backends (loaded models). When exceeded, the least recently used model is evicted. Set to `0` for unlimited, `1` for single-backend mode | `$LOCALAI_MAX_ACTIVE_BACKENDS`, `$MAX_ACTIVE_BACKENDS` |
| `--single-active-backend` | `false` | **DEPRECATED** - Use `--max-active-backends=1` instead. Allow only one backend to be run at a time | `$LOCALAI_SINGLE_ACTIVE_BACKEND`, `$SINGLE_ACTIVE_BACKEND` |
| `--preload-backend-only` | `false` | Do not launch the API services, only the preloaded models/backends are started (useful for multi-node setups) | `$LOCALAI_PRELOAD_BACKEND_ONLY`, `$PRELOAD_BACKEND_ONLY` |