mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler (#9186)
* always enable parallel requests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: add node reconciler, allow to schedule to group of nodes, min/max autoscaler Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore: move tests to ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(smart router): order by available vram Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
80699a3f70
commit
8862e3ce60
@@ -199,7 +199,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
|
||||
envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
|
||||
envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends
|
||||
envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
|
||||
envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled
|
||||
envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold
|
||||
envThreads := appConfig.Threads == startupAppConfig.Threads
|
||||
@@ -271,9 +270,6 @@ func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHand
|
||||
appConfig.MaxActiveBackends = 0
|
||||
}
|
||||
}
|
||||
if settings.ParallelBackendRequests != nil && !envParallelRequests {
|
||||
appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
|
||||
}
|
||||
if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled {
|
||||
appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
|
||||
if appConfig.MemoryReclaimerEnabled {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -28,6 +29,7 @@ type DistributedServices struct {
|
||||
Registry *nodes.NodeRegistry
|
||||
Router *nodes.SmartRouter
|
||||
Health *nodes.HealthMonitor
|
||||
Reconciler *nodes.ReplicaReconciler
|
||||
JobStore *jobs.JobStore
|
||||
Dispatcher *jobs.Dispatcher
|
||||
AgentStore *agents.AgentStore
|
||||
@@ -240,6 +242,16 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
|
||||
DB: authDB,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas
|
||||
reconciler := nodes.NewReplicaReconciler(nodes.ReplicaReconcilerOptions{
|
||||
Registry: registry,
|
||||
Scheduler: router,
|
||||
Unloader: remoteUnloader,
|
||||
DB: authDB,
|
||||
Interval: 30 * time.Second,
|
||||
ScaleDownDelay: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Create ModelRouterAdapter to wire into ModelLoader
|
||||
modelAdapter := nodes.NewModelRouterAdapter(router)
|
||||
|
||||
@@ -250,6 +262,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
|
||||
Registry: registry,
|
||||
Router: router,
|
||||
Health: healthMon,
|
||||
Reconciler: reconciler,
|
||||
JobStore: jobStore,
|
||||
Dispatcher: dispatcher,
|
||||
AgentStore: agentStore,
|
||||
|
||||
@@ -158,6 +158,10 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
application.modelLoader.SetModelStore(distStore)
|
||||
// Start health monitor
|
||||
distSvc.Health.Start(options.Context)
|
||||
// Start replica reconciler for auto-scaling model replicas
|
||||
if distSvc.Reconciler != nil {
|
||||
go distSvc.Reconciler.Run(options.Context)
|
||||
}
|
||||
// In distributed mode, MCP CI jobs are executed by agent workers (not the frontend)
|
||||
// because the frontend can't create MCP sessions (e.g., stdio servers using docker).
|
||||
// The dispatcher still subscribes to jobs.new for persistence (result/progress subs)
|
||||
@@ -439,11 +443,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings.ParallelBackendRequests != nil {
|
||||
if !options.ParallelBackendRequests {
|
||||
options.ParallelBackendRequests = *settings.ParallelBackendRequests
|
||||
}
|
||||
}
|
||||
if settings.MemoryReclaimerEnabled != nil {
|
||||
// Only apply if current value is default (false), suggesting it wasn't set from env var
|
||||
if !options.MemoryReclaimerEnabled {
|
||||
|
||||
@@ -59,9 +59,7 @@ func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...mo
|
||||
grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath)
|
||||
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
|
||||
|
||||
if so.ParallelBackendRequests {
|
||||
defOpts = append(defOpts, model.EnableParallelRequests)
|
||||
}
|
||||
defOpts = append(defOpts, model.EnableParallelRequests)
|
||||
|
||||
if c.GRPC.Attempts != 0 {
|
||||
defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts))
|
||||
|
||||
@@ -4,211 +4,172 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/state"
|
||||
)
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) {
|
||||
// Create a temporary agent config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
var _ = Describe("AgentRunCMD", func() {
|
||||
Describe("loadAgentConfig", func() {
|
||||
It("loads agent config from file", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
cfg := state.AgentConfig{
|
||||
Name: "test-agent",
|
||||
Model: "llama3",
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(configFile, data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg := state.AgentConfig{
|
||||
Name: "test-agent",
|
||||
Model: "llama3",
|
||||
SystemPrompt: "You are a helpful assistant",
|
||||
}
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "test-agent" {
|
||||
t.Errorf("expected name %q, got %q", "test-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "llama3" {
|
||||
t.Errorf("expected model %q, got %q", "llama3", loaded.Model)
|
||||
}
|
||||
}
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(loaded.Name).To(Equal("test-agent"))
|
||||
Expect(loaded.Model).To(Equal("llama3"))
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
It("loads agent config from pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"my-agent": {
|
||||
Model: "gpt-4",
|
||||
Description: "A test agent",
|
||||
SystemPrompt: "Hello",
|
||||
},
|
||||
"other-agent": {
|
||||
Model: "llama3",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pool := map[string]state.AgentConfig{
|
||||
"my-agent": {
|
||||
Model: "gpt-4",
|
||||
Description: "A test agent",
|
||||
SystemPrompt: "Hello",
|
||||
},
|
||||
"other-agent": {
|
||||
Model: "llama3",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "my-agent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "my-agent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadAgentConfig() error: %v", err)
|
||||
}
|
||||
if loaded.Name != "my-agent" {
|
||||
t.Errorf("expected name %q, got %q", "my-agent", loaded.Name)
|
||||
}
|
||||
if loaded.Model != "gpt-4" {
|
||||
t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model)
|
||||
}
|
||||
}
|
||||
loaded, err := cmd.loadAgentConfig()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(loaded.Name).To(Equal("my-agent"))
|
||||
Expect(loaded.Model).To(Equal("gpt-4"))
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
It("returns error for missing agent in pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"existing-agent": {Model: "llama3"},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pool := map[string]state.AgentConfig{
|
||||
"existing-agent": {Model: "llama3"},
|
||||
}
|
||||
data, err := json.MarshalIndent(pool, "", " ")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "nonexistent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
Name: "nonexistent",
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
_, err = cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing agent, got nil")
|
||||
}
|
||||
}
|
||||
_, err = cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) {
|
||||
cmd := &AgentRunCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
It("returns error when no pool.json exists", func() {
|
||||
cmd := &AgentRunCMD{
|
||||
StateDir: GinkgoT().TempDir(),
|
||||
}
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no pool.json exists, got nil")
|
||||
}
|
||||
}
|
||||
_, err := cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_ApplyOverrides(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
}
|
||||
It("returns error for config with no name", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
APIURL: "http://localhost:9090",
|
||||
APIKey: "secret",
|
||||
DefaultModel: "my-model",
|
||||
}
|
||||
cfg := state.AgentConfig{
|
||||
Model: "llama3",
|
||||
}
|
||||
data, _ := json.MarshalIndent(cfg, "", " ")
|
||||
Expect(os.WriteFile(configFile, data, 0644)).To(Succeed())
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
if cfg.APIURL != "http://localhost:9090" {
|
||||
t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL)
|
||||
}
|
||||
if cfg.APIKey != "secret" {
|
||||
t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey)
|
||||
}
|
||||
if cfg.Model != "my-model" {
|
||||
t.Errorf("expected Model %q, got %q", "my-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
_, err := cmd.loadAgentConfig()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
Model: "existing-model",
|
||||
}
|
||||
Describe("applyOverrides", func() {
|
||||
It("applies overrides to empty fields", func() {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
DefaultModel: "override-model",
|
||||
}
|
||||
cmd := &AgentRunCMD{
|
||||
APIURL: "http://localhost:9090",
|
||||
APIKey: "secret",
|
||||
DefaultModel: "my-model",
|
||||
}
|
||||
|
||||
cmd.applyOverrides(cfg)
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
if cfg.Model != "existing-model" {
|
||||
t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model)
|
||||
}
|
||||
}
|
||||
Expect(cfg.APIURL).To(Equal("http://localhost:9090"))
|
||||
Expect(cfg.APIKey).To(Equal("secret"))
|
||||
Expect(cfg.Model).To(Equal("my-model"))
|
||||
})
|
||||
|
||||
func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "agent.json")
|
||||
It("does not overwrite existing model", func() {
|
||||
cfg := &state.AgentConfig{
|
||||
Name: "test",
|
||||
Model: "existing-model",
|
||||
}
|
||||
|
||||
// Agent config with no name
|
||||
cfg := state.AgentConfig{
|
||||
Model: "llama3",
|
||||
}
|
||||
data, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configFile, data, 0644)
|
||||
cmd := &AgentRunCMD{
|
||||
DefaultModel: "override-model",
|
||||
}
|
||||
|
||||
cmd := &AgentRunCMD{
|
||||
Config: configFile,
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
cmd.applyOverrides(cfg)
|
||||
|
||||
_, err := cmd.loadAgentConfig()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for config with no name, got nil")
|
||||
}
|
||||
}
|
||||
Expect(cfg.Model).To(Equal("existing-model"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func TestAgentListCMD_NoPoolFile(t *testing.T) {
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
var _ = Describe("AgentListCMD", func() {
|
||||
It("runs without error when no pool file exists", func() {
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: GinkgoT().TempDir(),
|
||||
}
|
||||
Expect(cmd.Run(nil)).To(Succeed())
|
||||
})
|
||||
|
||||
// Should not error, just print "no agents found"
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
It("runs without error with agents in pool", func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
func TestAgentListCMD_WithAgents(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
pool := map[string]state.AgentConfig{
|
||||
"agent-a": {Model: "llama3", Description: "First agent"},
|
||||
"agent-b": {Model: "gpt-4"},
|
||||
}
|
||||
data, _ := json.MarshalIndent(pool, "", " ")
|
||||
Expect(os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)).To(Succeed())
|
||||
|
||||
pool := map[string]state.AgentConfig{
|
||||
"agent-a": {Model: "llama3", Description: "First agent"},
|
||||
"agent-b": {Model: "gpt-4"},
|
||||
}
|
||||
data, _ := json.MarshalIndent(pool, "", " ")
|
||||
os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644)
|
||||
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
|
||||
err := cmd.Run(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
cmd := &AgentListCMD{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
Expect(cmd.Run(nil)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
13
core/cli/cli_suite_test.go
Normal file
13
core/cli/cli_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestCLI(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "CLI Suite")
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/kong"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func getTestApp() *kong.Application {
|
||||
@@ -21,76 +20,55 @@ func getTestApp() *kong.Application {
|
||||
return k.Model
|
||||
}
|
||||
|
||||
func TestGenerateBashCompletion(t *testing.T) {
|
||||
app := getTestApp()
|
||||
script := generateBashCompletion(app)
|
||||
var _ = Describe("Shell completions", func() {
|
||||
var app *kong.Application
|
||||
|
||||
if !strings.Contains(script, "complete -F _local_ai_completions local-ai") {
|
||||
t.Error("bash completion missing complete command registration")
|
||||
}
|
||||
if !strings.Contains(script, "run") {
|
||||
t.Error("bash completion missing 'run' command")
|
||||
}
|
||||
if !strings.Contains(script, "models") {
|
||||
t.Error("bash completion missing 'models' command")
|
||||
}
|
||||
if !strings.Contains(script, "completion") {
|
||||
t.Error("bash completion missing 'completion' command")
|
||||
}
|
||||
}
|
||||
BeforeEach(func() {
|
||||
app = getTestApp()
|
||||
})
|
||||
|
||||
func TestGenerateZshCompletion(t *testing.T) {
|
||||
app := getTestApp()
|
||||
script := generateZshCompletion(app)
|
||||
Describe("generateBashCompletion", func() {
|
||||
It("generates valid bash completion script", func() {
|
||||
script := generateBashCompletion(app)
|
||||
Expect(script).To(ContainSubstring("complete -F _local_ai_completions local-ai"))
|
||||
Expect(script).To(ContainSubstring("run"))
|
||||
Expect(script).To(ContainSubstring("models"))
|
||||
Expect(script).To(ContainSubstring("completion"))
|
||||
})
|
||||
})
|
||||
|
||||
if !strings.Contains(script, "#compdef local-ai") {
|
||||
t.Error("zsh completion missing compdef header")
|
||||
}
|
||||
if !strings.Contains(script, "run") {
|
||||
t.Error("zsh completion missing 'run' command")
|
||||
}
|
||||
if !strings.Contains(script, "models") {
|
||||
t.Error("zsh completion missing 'models' command")
|
||||
}
|
||||
}
|
||||
Describe("generateZshCompletion", func() {
|
||||
It("generates valid zsh completion script", func() {
|
||||
script := generateZshCompletion(app)
|
||||
Expect(script).To(ContainSubstring("#compdef local-ai"))
|
||||
Expect(script).To(ContainSubstring("run"))
|
||||
Expect(script).To(ContainSubstring("models"))
|
||||
})
|
||||
})
|
||||
|
||||
func TestGenerateFishCompletion(t *testing.T) {
|
||||
app := getTestApp()
|
||||
script := generateFishCompletion(app)
|
||||
Describe("generateFishCompletion", func() {
|
||||
It("generates valid fish completion script", func() {
|
||||
script := generateFishCompletion(app)
|
||||
Expect(script).To(ContainSubstring("complete -c local-ai"))
|
||||
Expect(script).To(ContainSubstring("__fish_use_subcommand"))
|
||||
Expect(script).To(ContainSubstring("run"))
|
||||
Expect(script).To(ContainSubstring("models"))
|
||||
})
|
||||
})
|
||||
|
||||
if !strings.Contains(script, "complete -c local-ai") {
|
||||
t.Error("fish completion missing complete command")
|
||||
}
|
||||
if !strings.Contains(script, "__fish_use_subcommand") {
|
||||
t.Error("fish completion missing subcommand detection")
|
||||
}
|
||||
if !strings.Contains(script, "run") {
|
||||
t.Error("fish completion missing 'run' command")
|
||||
}
|
||||
if !strings.Contains(script, "models") {
|
||||
t.Error("fish completion missing 'models' command")
|
||||
}
|
||||
}
|
||||
Describe("collectCommands", func() {
|
||||
It("collects all commands and subcommands", func() {
|
||||
cmds := collectCommands(app.Node, "")
|
||||
|
||||
func TestCollectCommands(t *testing.T) {
|
||||
app := getTestApp()
|
||||
cmds := collectCommands(app.Node, "")
|
||||
names := make(map[string]bool)
|
||||
for _, cmd := range cmds {
|
||||
names[cmd.fullName] = true
|
||||
}
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, cmd := range cmds {
|
||||
names[cmd.fullName] = true
|
||||
}
|
||||
|
||||
if !names["run"] {
|
||||
t.Error("missing 'run' command")
|
||||
}
|
||||
if !names["models"] {
|
||||
t.Error("missing 'models' command")
|
||||
}
|
||||
if !names["models list"] {
|
||||
t.Error("missing 'models list' subcommand")
|
||||
}
|
||||
if !names["models install"] {
|
||||
t.Error("missing 'models install' subcommand")
|
||||
}
|
||||
}
|
||||
Expect(names).To(HaveKey("run"))
|
||||
Expect(names).To(HaveKey("models"))
|
||||
Expect(names).To(HaveKey("models list"))
|
||||
Expect(names).To(HaveKey("models install"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -74,7 +74,6 @@ type RunCMD struct {
|
||||
Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"`
|
||||
Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"`
|
||||
Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"`
|
||||
ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"`
|
||||
SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"`
|
||||
MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). Least recently used backends are evicted when limit is reached" group:"backends"`
|
||||
PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"`
|
||||
@@ -438,10 +437,6 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold))
|
||||
}
|
||||
|
||||
if r.ParallelRequests {
|
||||
opts = append(opts, config.EnableParallelBackendRequests)
|
||||
}
|
||||
|
||||
// Handle max active backends (LRU eviction)
|
||||
// MaxActiveBackends takes precedence over SingleActiveBackend
|
||||
if r.MaxActiveBackends > 0 {
|
||||
|
||||
@@ -67,22 +67,28 @@ func isPathAllowed(path string, allowedDirs []string) bool {
|
||||
//
|
||||
// Model loading (LoadModel) is always via direct gRPC — no NATS needed for that.
|
||||
type WorkerCMD struct {
|
||||
Addr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"Address to bind the gRPC server to" group:"server"`
|
||||
// Primary address — the reachable address of this worker.
|
||||
// Host is used for advertise, port is the base for gRPC backends.
|
||||
// HTTP file transfer runs on port-1.
|
||||
Addr string `env:"LOCALAI_ADDR" help:"Address where this worker is reachable (host:port). Port is base for gRPC backends, port-1 for HTTP." group:"server"`
|
||||
ServeAddr string `env:"LOCALAI_SERVE_ADDR" default:"0.0.0.0:50051" help:"(Advanced) gRPC base port bind address" group:"server" hidden:""`
|
||||
|
||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends" group:"server"`
|
||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends" group:"server"`
|
||||
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"server" default:"${backends}"`
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models" group:"server"`
|
||||
|
||||
// HTTP file transfer
|
||||
HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server"`
|
||||
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server"`
|
||||
HTTPAddr string `env:"LOCALAI_HTTP_ADDR" default:"" help:"HTTP file transfer server address (default: gRPC port + 1)" group:"server" hidden:""`
|
||||
AdvertiseHTTPAddr string `env:"LOCALAI_ADVERTISE_HTTP_ADDR" help:"HTTP address the frontend uses to reach this node for file transfer" group:"server" hidden:""`
|
||||
|
||||
// Registration (required)
|
||||
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration"`
|
||||
AdvertiseAddr string `env:"LOCALAI_ADVERTISE_ADDR" help:"Address the frontend uses to reach this node (defaults to hostname:port from Addr)" group:"registration" hidden:""`
|
||||
RegisterTo string `env:"LOCALAI_REGISTER_TO" required:"" help:"Frontend URL for registration" group:"registration"`
|
||||
NodeName string `env:"LOCALAI_NODE_NAME" help:"Node name for registration (defaults to hostname)" group:"registration"`
|
||||
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
|
||||
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
|
||||
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
|
||||
|
||||
// NATS (required)
|
||||
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
|
||||
@@ -96,7 +102,7 @@ type WorkerCMD struct {
|
||||
}
|
||||
|
||||
func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
xlog.Info("Starting worker", "addr", cmd.Addr)
|
||||
xlog.Info("Starting worker", "advertise", cmd.advertiseAddr(), "basePort", cmd.effectiveBasePort())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithModelPath(cmd.ModelsPath),
|
||||
@@ -181,15 +187,7 @@ func (cmd *WorkerCMD) Run(ctx *cliContext.Context) error {
|
||||
}()
|
||||
|
||||
// Process supervisor — manages multiple backend gRPC processes on different ports
|
||||
basePort := 50051
|
||||
if cmd.Addr != "" {
|
||||
// Extract port from addr (e.g., "0.0.0.0:50051" → 50051)
|
||||
if _, portStr, err := net.SplitHostPort(cmd.Addr); err == nil {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
basePort = p
|
||||
}
|
||||
}
|
||||
}
|
||||
basePort := cmd.effectiveBasePort()
|
||||
// Buffered so NATS stop handler can send without blocking
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
@@ -667,10 +665,10 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
|
||||
// Return the gRPC address so the router knows which port to use
|
||||
advertiseAddr := addr
|
||||
if s.cmd.AdvertiseAddr != "" {
|
||||
// Replace 0.0.0.0 with the advertised host but keep the dynamic port
|
||||
advAddr := s.cmd.advertiseAddr()
|
||||
if advAddr != addr { // only remap if advertise differs from bind
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
advertiseHost, _, _ := net.SplitHostPort(s.cmd.AdvertiseAddr)
|
||||
advertiseHost, _, _ := net.SplitHostPort(advAddr)
|
||||
advertiseAddr = net.JoinHostPort(advertiseHost, port)
|
||||
}
|
||||
resp := messaging.BackendInstallReply{Success: true, Address: advertiseAddr}
|
||||
@@ -816,18 +814,31 @@ func (s *backendSupervisor) subscribeLifecycleEvents() {
|
||||
})
|
||||
}
|
||||
|
||||
// effectiveBasePort returns the port used as base for gRPC backend processes.
|
||||
// Priority: Addr port → ServeAddr port → 50051
|
||||
func (cmd *WorkerCMD) effectiveBasePort() int {
|
||||
for _, addr := range []string{cmd.Addr, cmd.ServeAddr} {
|
||||
if addr != "" {
|
||||
if _, portStr, ok := strings.Cut(addr, ":"); ok {
|
||||
if p, _ := strconv.Atoi(portStr); p > 0 {
|
||||
return p
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 50051
|
||||
}
|
||||
|
||||
// advertiseAddr returns the address the frontend should use to reach this node.
|
||||
func (cmd *WorkerCMD) advertiseAddr() string {
|
||||
if cmd.AdvertiseAddr != "" {
|
||||
return cmd.AdvertiseAddr
|
||||
}
|
||||
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||
if ok && (host == "0.0.0.0" || host == "") {
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname + ":" + port
|
||||
}
|
||||
if cmd.Addr != "" {
|
||||
return cmd.Addr
|
||||
}
|
||||
return cmd.Addr
|
||||
hostname, _ := os.Hostname()
|
||||
return fmt.Sprintf("%s:%d", cmp.Or(hostname, "localhost"), cmd.effectiveBasePort())
|
||||
}
|
||||
|
||||
// resolveHTTPAddr returns the address to bind the HTTP file transfer server to.
|
||||
@@ -837,12 +848,7 @@ func (cmd *WorkerCMD) resolveHTTPAddr() string {
|
||||
if cmd.HTTPAddr != "" {
|
||||
return cmd.HTTPAddr
|
||||
}
|
||||
host, port, ok := strings.Cut(cmd.Addr, ":")
|
||||
if !ok {
|
||||
return "0.0.0.0:50050"
|
||||
}
|
||||
portNum, _ := strconv.Atoi(port)
|
||||
return fmt.Sprintf("%s:%d", host, portNum-1)
|
||||
return fmt.Sprintf("0.0.0.0:%d", cmd.effectiveBasePort()-1)
|
||||
}
|
||||
|
||||
// advertiseHTTPAddr returns the HTTP address the frontend should use to reach
|
||||
@@ -851,14 +857,9 @@ func (cmd *WorkerCMD) advertiseHTTPAddr() string {
|
||||
if cmd.AdvertiseHTTPAddr != "" {
|
||||
return cmd.AdvertiseHTTPAddr
|
||||
}
|
||||
httpAddr := cmd.resolveHTTPAddr()
|
||||
host, port, ok := strings.Cut(httpAddr, ":")
|
||||
if ok && (host == "0.0.0.0" || host == "") {
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
return hostname + ":" + port
|
||||
}
|
||||
}
|
||||
return httpAddr
|
||||
advHost, _, _ := strings.Cut(cmd.advertiseAddr(), ":")
|
||||
httpPort := cmd.effectiveBasePort() - 1
|
||||
return fmt.Sprintf("%s:%d", advHost, httpPort)
|
||||
}
|
||||
|
||||
// registrationBody builds the JSON body for node registration.
|
||||
@@ -896,6 +897,21 @@ func (cmd *WorkerCMD) registrationBody() map[string]any {
|
||||
if cmd.RegistrationToken != "" {
|
||||
body["token"] = cmd.RegistrationToken
|
||||
}
|
||||
|
||||
// Parse and add static node labels
|
||||
if cmd.NodeLabels != "" {
|
||||
labels := make(map[string]string)
|
||||
for _, pair := range strings.Split(cmd.NodeLabels, ",") {
|
||||
pair = strings.TrimSpace(pair)
|
||||
if k, v, ok := strings.Cut(pair, "="); ok {
|
||||
labels[strings.TrimSpace(k)] = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if len(labels) > 0 {
|
||||
body["labels"] = labels
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
|
||||
83
core/cli/worker_addr_test.go
Normal file
83
core/cli/worker_addr_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("WorkerCMD address resolution", func() {
|
||||
Describe("effectiveBasePort", func() {
|
||||
DescribeTable("returns the correct port",
|
||||
func(addr, serve string, want int) {
|
||||
cmd := &WorkerCMD{Addr: addr, ServeAddr: serve}
|
||||
Expect(cmd.effectiveBasePort()).To(Equal(want))
|
||||
},
|
||||
Entry("Addr takes priority", "worker1.example.com:60000", "0.0.0.0:50051", 60000),
|
||||
Entry("falls back to ServeAddr", "", "0.0.0.0:50051", 50051),
|
||||
Entry("returns 50051 when neither set", "", "", 50051),
|
||||
Entry("Addr with custom port", "10.0.0.5:7000", "", 7000),
|
||||
Entry("invalid port in Addr falls through to ServeAddr", "host:notanumber", "0.0.0.0:9999", 9999),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("advertiseAddr", func() {
|
||||
It("returns AdvertiseAddr when set", func() {
|
||||
cmd := &WorkerCMD{
|
||||
AdvertiseAddr: "public.example.com:50051",
|
||||
Addr: "10.0.0.5:60000",
|
||||
}
|
||||
Expect(cmd.advertiseAddr()).To(Equal("public.example.com:50051"))
|
||||
})
|
||||
|
||||
It("returns Addr when set", func() {
|
||||
cmd := &WorkerCMD{Addr: "worker1.example.com:60000"}
|
||||
Expect(cmd.advertiseAddr()).To(Equal("worker1.example.com:60000"))
|
||||
})
|
||||
|
||||
It("falls back to hostname:basePort", func() {
|
||||
cmd := &WorkerCMD{ServeAddr: "0.0.0.0:50051"}
|
||||
got := cmd.advertiseAddr()
|
||||
_, port, _ := strings.Cut(got, ":")
|
||||
Expect(port).To(Equal("50051"))
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
if hostname != "" {
|
||||
host, _, _ := strings.Cut(got, ":")
|
||||
Expect(host).To(Equal(hostname))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("resolveHTTPAddr", func() {
|
||||
DescribeTable("returns the correct address",
|
||||
func(httpAddr, addr, serve, want string) {
|
||||
cmd := &WorkerCMD{HTTPAddr: httpAddr, Addr: addr, ServeAddr: serve}
|
||||
Expect(cmd.resolveHTTPAddr()).To(Equal(want))
|
||||
},
|
||||
Entry("HTTPAddr takes priority", "0.0.0.0:8080", "", "", "0.0.0.0:8080"),
|
||||
Entry("derives from Addr port minus 1", "", "worker1:60000", "0.0.0.0:50051", "0.0.0.0:59999"),
|
||||
Entry("derives from ServeAddr port minus 1", "", "", "0.0.0.0:50051", "0.0.0.0:50050"),
|
||||
Entry("default when nothing set", "", "", "", "0.0.0.0:50050"),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("advertiseHTTPAddr", func() {
|
||||
DescribeTable("returns the correct address",
|
||||
func(advertiseHTTP, advertise, addr, serve, want string) {
|
||||
cmd := &WorkerCMD{
|
||||
AdvertiseHTTPAddr: advertiseHTTP,
|
||||
AdvertiseAddr: advertise,
|
||||
Addr: addr,
|
||||
ServeAddr: serve,
|
||||
}
|
||||
Expect(cmd.advertiseHTTPAddr()).To(Equal(want))
|
||||
},
|
||||
Entry("AdvertiseHTTPAddr takes priority", "public.example.com:8080", "", "", "", "public.example.com:8080"),
|
||||
Entry("derives from advertiseAddr host + basePort-1", "", "", "worker1.example.com:60000", "", "worker1.example.com:59999"),
|
||||
Entry("uses AdvertiseAddr host with basePort-1", "", "public.example.com:60000", "10.0.0.5:60000", "", "public.example.com:59999"),
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -59,8 +59,6 @@ type ApplicationConfig struct {
|
||||
|
||||
SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
ParallelBackendRequests bool
|
||||
|
||||
WatchDogIdle bool
|
||||
WatchDogBusy bool
|
||||
WatchDog bool
|
||||
@@ -379,10 +377,6 @@ func WithLRUEvictionRetryInterval(interval time.Duration) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
var EnableParallelBackendRequests = func(o *ApplicationConfig) {
|
||||
o.ParallelBackendRequests = true
|
||||
}
|
||||
|
||||
var EnableGalleriesAutoload = func(o *ApplicationConfig) {
|
||||
o.AutoloadGalleries = true
|
||||
}
|
||||
@@ -842,7 +836,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
watchdogBusy := o.WatchDogBusy
|
||||
singleBackend := o.SingleBackend
|
||||
maxActiveBackends := o.MaxActiveBackends
|
||||
parallelBackendRequests := o.ParallelBackendRequests
|
||||
memoryReclaimerEnabled := o.MemoryReclaimerEnabled
|
||||
memoryReclaimerThreshold := o.MemoryReclaimerThreshold
|
||||
forceEvictionWhenBusy := o.ForceEvictionWhenBusy
|
||||
@@ -915,7 +908,6 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
WatchdogInterval: &watchdogInterval,
|
||||
SingleBackend: &singleBackend,
|
||||
MaxActiveBackends: &maxActiveBackends,
|
||||
ParallelBackendRequests: ¶llelBackendRequests,
|
||||
MemoryReclaimerEnabled: &memoryReclaimerEnabled,
|
||||
MemoryReclaimerThreshold: &memoryReclaimerThreshold,
|
||||
ForceEvictionWhenBusy: &forceEvictionWhenBusy,
|
||||
@@ -1008,9 +1000,6 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
}
|
||||
requireRestart = true
|
||||
}
|
||||
if settings.ParallelBackendRequests != nil {
|
||||
o.ParallelBackendRequests = *settings.ParallelBackendRequests
|
||||
}
|
||||
if settings.MemoryReclaimerEnabled != nil {
|
||||
o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled
|
||||
if *settings.MemoryReclaimerEnabled {
|
||||
|
||||
@@ -18,7 +18,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
WatchDogBusyTimeout: 10 * time.Minute,
|
||||
SingleBackend: false,
|
||||
MaxActiveBackends: 5,
|
||||
ParallelBackendRequests: true,
|
||||
MemoryReclaimerEnabled: true,
|
||||
MemoryReclaimerThreshold: 0.85,
|
||||
Threads: 8,
|
||||
@@ -62,9 +61,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(rs.MaxActiveBackends).ToNot(BeNil())
|
||||
Expect(*rs.MaxActiveBackends).To(Equal(5))
|
||||
|
||||
Expect(rs.ParallelBackendRequests).ToNot(BeNil())
|
||||
Expect(*rs.ParallelBackendRequests).To(BeTrue())
|
||||
|
||||
Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil())
|
||||
Expect(*rs.MemoryReclaimerEnabled).To(BeTrue())
|
||||
|
||||
@@ -455,7 +451,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
WatchDogBusyTimeout: 12 * time.Minute,
|
||||
SingleBackend: false,
|
||||
MaxActiveBackends: 3,
|
||||
ParallelBackendRequests: true,
|
||||
MemoryReclaimerEnabled: true,
|
||||
MemoryReclaimerThreshold: 0.92,
|
||||
Threads: 12,
|
||||
@@ -487,7 +482,6 @@ var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() {
|
||||
Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout))
|
||||
Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout))
|
||||
Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends))
|
||||
Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests))
|
||||
Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled))
|
||||
Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold))
|
||||
Expect(target.Threads).To(Equal(original.Threads))
|
||||
|
||||
@@ -20,8 +20,6 @@ type RuntimeSettings struct {
|
||||
// Backend management
|
||||
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"`
|
||||
|
||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||
MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring
|
||||
MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -37,7 +38,7 @@ func nodeError(code int, message string) schema.ErrorResponse {
|
||||
func ListNodesEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
nodeList, err := registry.List(ctx)
|
||||
nodeList, err := registry.ListWithExtras(ctx)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to list nodes", "error", err)
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list nodes"))
|
||||
@@ -70,7 +71,8 @@ type RegisterNodeRequest struct {
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterNodeEndpoint registers a new backend node.
|
||||
@@ -148,6 +150,14 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to register node"))
|
||||
}
|
||||
|
||||
// Store static labels from worker and apply auto-labels
|
||||
if len(req.Labels) > 0 {
|
||||
if err := registry.SetNodeLabels(ctx, node.ID, req.Labels); err != nil {
|
||||
xlog.Warn("Failed to set node labels", "node", node.ID, "error", err)
|
||||
}
|
||||
}
|
||||
registry.ApplyAutoLabels(ctx, node.ID, node)
|
||||
|
||||
response := map[string]any{
|
||||
"id": node.ID,
|
||||
"name": node.Name,
|
||||
@@ -618,6 +628,178 @@ func NodeBackendLogsWSEndpoint(registry *nodes.NodeRegistry, registrationToken s
|
||||
}
|
||||
}
|
||||
|
||||
// GetNodeLabelsEndpoint returns labels for a specific node.
|
||||
func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
nodeID := c.Param("id")
|
||||
labels, err := registry.GetNodeLabels(ctx, nodeID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get labels"))
|
||||
}
|
||||
// Convert to map for cleaner API response
|
||||
result := make(map[string]string)
|
||||
for _, l := range labels {
|
||||
result[l.Key] = l.Value
|
||||
}
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// SetNodeLabelsEndpoint replaces all labels for a node.
|
||||
func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
nodeID := c.Param("id")
|
||||
if _, err := registry.Get(ctx, nodeID); err != nil {
|
||||
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
||||
}
|
||||
var labels map[string]string
|
||||
if err := c.Bind(&labels); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if err := registry.SetNodeLabels(ctx, nodeID, labels); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set labels"))
|
||||
}
|
||||
return c.JSON(http.StatusOK, labels)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeNodeLabelsEndpoint adds/updates labels without removing existing ones.
|
||||
func MergeNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
nodeID := c.Param("id")
|
||||
if _, err := registry.Get(ctx, nodeID); err != nil {
|
||||
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
|
||||
}
|
||||
var labels map[string]string
|
||||
if err := c.Bind(&labels); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
for k, v := range labels {
|
||||
if err := registry.SetNodeLabel(ctx, nodeID, k, v); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to merge labels"))
|
||||
}
|
||||
}
|
||||
// Return updated labels
|
||||
updated, _ := registry.GetNodeLabels(ctx, nodeID)
|
||||
result := make(map[string]string)
|
||||
for _, l := range updated {
|
||||
result[l.Key] = l.Value
|
||||
}
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteNodeLabelEndpoint removes a single label from a node.
|
||||
func DeleteNodeLabelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
nodeID := c.Param("id")
|
||||
key := c.Param("key")
|
||||
if key == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "label key is required"))
|
||||
}
|
||||
if err := registry.RemoveNodeLabel(ctx, nodeID, key); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to remove label"))
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// ListSchedulingEndpoint returns all model scheduling configs.
|
||||
func ListSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
configs, err := registry.ListModelSchedulings(ctx)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to list scheduling configs"))
|
||||
}
|
||||
return c.JSON(http.StatusOK, configs)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSchedulingEndpoint returns the scheduling config for a specific model.
|
||||
func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
modelName := c.Param("model")
|
||||
config, err := registry.GetModelScheduling(ctx, modelName)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to get scheduling config"))
|
||||
}
|
||||
if config == nil {
|
||||
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "no scheduling config for model"))
|
||||
}
|
||||
return c.JSON(http.StatusOK, config)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSchedulingRequest is the request body for creating/updating a scheduling config.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
}
|
||||
|
||||
// SetSchedulingEndpoint creates or updates a model scheduling config.
|
||||
func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
var req SetSchedulingRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
|
||||
}
|
||||
if req.ModelName == "" {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "model_name is required"))
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be >= 0"))
|
||||
}
|
||||
if req.MaxReplicas < 0 {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "max_replicas must be >= 0"))
|
||||
}
|
||||
if req.MaxReplicas > 0 && req.MinReplicas > req.MaxReplicas {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "min_replicas must be <= max_replicas"))
|
||||
}
|
||||
|
||||
// Serialize node selector to JSON
|
||||
var selectorJSON string
|
||||
if len(req.NodeSelector) > 0 {
|
||||
b, err := json.Marshal(req.NodeSelector)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid node_selector"))
|
||||
}
|
||||
selectorJSON = string(b)
|
||||
}
|
||||
|
||||
config := &nodes.ModelSchedulingConfig{
|
||||
ModelName: req.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
}
|
||||
if err := registry.SetModelScheduling(ctx, config); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to set scheduling config"))
|
||||
}
|
||||
return c.JSON(http.StatusOK, config)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteSchedulingEndpoint removes a model scheduling config.
|
||||
func DeleteSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
modelName := c.Param("model")
|
||||
if err := registry.DeleteModelScheduling(ctx, modelName); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to delete scheduling config"))
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// proxyHTTPToWorker makes a GET request to a worker's HTTP server with bearer token auth.
|
||||
func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error) {
|
||||
reqURL := fmt.Sprintf("http://%s%s", httpAddress, path)
|
||||
|
||||
@@ -159,6 +159,62 @@ function WorkerHintCard({ addToast, activeTab, hasWorkers }) {
|
||||
)
|
||||
}
|
||||
|
||||
function SchedulingForm({ onSave, onCancel }) {
|
||||
const [modelName, setModelName] = useState('')
|
||||
const [selectorText, setSelectorText] = useState('')
|
||||
const [minReplicas, setMinReplicas] = useState(0)
|
||||
const [maxReplicas, setMaxReplicas] = useState(0)
|
||||
|
||||
const handleSubmit = () => {
|
||||
let nodeSelector = null
|
||||
if (selectorText.trim()) {
|
||||
const pairs = {}
|
||||
selectorText.split(',').forEach(p => {
|
||||
const [k, v] = p.split('=').map(s => s.trim())
|
||||
if (k) pairs[k] = v || ''
|
||||
})
|
||||
nodeSelector = pairs
|
||||
}
|
||||
onSave({
|
||||
model_name: modelName,
|
||||
node_selector: nodeSelector ? JSON.stringify(nodeSelector) : '',
|
||||
min_replicas: minReplicas,
|
||||
max_replicas: maxReplicas,
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 'var(--spacing-sm)' }}>
|
||||
<div>
|
||||
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Model Name</label>
|
||||
<input type="text" value={modelName} onChange={e => setModelName(e.target.value)}
|
||||
placeholder="e.g. llama3" style={{ width: '100%' }} />
|
||||
</div>
|
||||
<div>
|
||||
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Node Selector (key=value, comma-separated)</label>
|
||||
<input type="text" value={selectorText} onChange={e => setSelectorText(e.target.value)}
|
||||
placeholder="e.g. gpu.vendor=nvidia,tier=fast" style={{ width: '100%' }} />
|
||||
</div>
|
||||
<div>
|
||||
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Min Replicas (0 = no minimum)</label>
|
||||
<input type="number" min={0} value={minReplicas} onChange={e => setMinReplicas(parseInt(e.target.value) || 0)}
|
||||
style={{ width: '100%' }} />
|
||||
</div>
|
||||
<div>
|
||||
<label style={{ fontSize: '0.75rem', fontWeight: 500 }}>Max Replicas (0 = unlimited)</label>
|
||||
<input type="number" min={0} value={maxReplicas} onChange={e => setMaxReplicas(parseInt(e.target.value) || 0)}
|
||||
style={{ width: '100%' }} />
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-sm)', marginTop: 'var(--spacing-sm)', justifyContent: 'flex-end' }}>
|
||||
<button className="btn btn-secondary btn-sm" onClick={onCancel}>Cancel</button>
|
||||
<button className="btn btn-primary btn-sm" onClick={handleSubmit} disabled={!modelName}>Save</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default function Nodes() {
|
||||
const { addToast } = useOutletContext()
|
||||
const navigate = useNavigate()
|
||||
@@ -170,7 +226,9 @@ export default function Nodes() {
|
||||
const [nodeBackends, setNodeBackends] = useState({})
|
||||
const [confirmDelete, setConfirmDelete] = useState(null)
|
||||
const [showTips, setShowTips] = useState(false)
|
||||
const [activeTab, setActiveTab] = useState('backend') // 'backend' or 'agent'
|
||||
const [activeTab, setActiveTab] = useState('backend') // 'backend', 'agent', or 'scheduling'
|
||||
const [schedulingConfigs, setSchedulingConfigs] = useState([])
|
||||
const [showSchedulingForm, setShowSchedulingForm] = useState(false)
|
||||
|
||||
const fetchNodes = useCallback(async () => {
|
||||
try {
|
||||
@@ -186,11 +244,19 @@ export default function Nodes() {
|
||||
}
|
||||
}, [])
|
||||
|
||||
const fetchScheduling = useCallback(async () => {
|
||||
try {
|
||||
const data = await nodesApi.listScheduling()
|
||||
setSchedulingConfigs(Array.isArray(data) ? data : [])
|
||||
} catch { setSchedulingConfigs([]) }
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
fetchNodes()
|
||||
fetchScheduling()
|
||||
const interval = setInterval(fetchNodes, 5000)
|
||||
return () => clearInterval(interval)
|
||||
}, [fetchNodes])
|
||||
}, [fetchNodes, fetchScheduling])
|
||||
|
||||
const fetchModels = useCallback(async (nodeId) => {
|
||||
try {
|
||||
@@ -254,6 +320,36 @@ export default function Nodes() {
|
||||
}
|
||||
}
|
||||
|
||||
const handleUnloadModel = async (nodeId, modelName) => {
|
||||
try {
|
||||
await nodesApi.unloadModel(nodeId, modelName)
|
||||
addToast(`Model "${modelName}" unloaded`, 'success')
|
||||
fetchModels(nodeId)
|
||||
} catch (err) {
|
||||
addToast(`Failed to unload model: ${err.message}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
const handleAddLabel = async (nodeId, key, value) => {
|
||||
try {
|
||||
await nodesApi.mergeLabels(nodeId, { [key]: value })
|
||||
addToast(`Label "${key}=${value}" added`, 'success')
|
||||
fetchNodes()
|
||||
} catch (err) {
|
||||
addToast(`Failed to add label: ${err.message}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
const handleDeleteLabel = async (nodeId, key) => {
|
||||
try {
|
||||
await nodesApi.deleteLabel(nodeId, key)
|
||||
addToast(`Label "${key}" removed`, 'success')
|
||||
fetchNodes()
|
||||
} catch (err) {
|
||||
addToast(`Failed to remove label: ${err.message}`, 'error')
|
||||
}
|
||||
}
|
||||
|
||||
const handleDelete = async (nodeId) => {
|
||||
try {
|
||||
await nodesApi.delete(nodeId)
|
||||
@@ -422,8 +518,23 @@ export default function Nodes() {
|
||||
<i className="fas fa-robot" style={{ marginRight: 6 }} />
|
||||
Agent Workers ({agentNodes.length})
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setActiveTab('scheduling')}
|
||||
style={{
|
||||
padding: 'var(--spacing-sm) var(--spacing-lg)',
|
||||
border: 'none', cursor: 'pointer', fontWeight: 600, fontSize: '0.875rem',
|
||||
background: 'none',
|
||||
color: activeTab === 'scheduling' ? 'var(--color-primary)' : 'var(--color-text-muted)',
|
||||
borderBottom: activeTab === 'scheduling' ? '2px solid var(--color-primary)' : '2px solid transparent',
|
||||
marginBottom: '-2px',
|
||||
}}
|
||||
>
|
||||
<i className="fas fa-calendar-alt" style={{ marginRight: 6 }} />
|
||||
Scheduling ({schedulingConfigs.length})
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{activeTab !== 'scheduling' && <>
|
||||
{/* Stat cards */}
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-md)', marginBottom: 'var(--spacing-xl)', flexWrap: 'wrap' }}>
|
||||
<StatCard icon={activeTab === 'agent' ? 'fas fa-robot' : 'fas fa-server'} label={`Total ${activeTab === 'agent' ? 'Agent' : 'Backend'} Workers`} value={total} />
|
||||
@@ -433,6 +544,23 @@ export default function Nodes() {
|
||||
{pending > 0 && (
|
||||
<StatCard icon="fas fa-clock" label="Pending" value={pending} color="var(--color-warning)" />
|
||||
)}
|
||||
{activeTab === 'backend' && (() => {
|
||||
const clusterTotalVRAM = backendNodes.reduce((sum, n) => sum + (n.total_vram || 0), 0)
|
||||
const clusterUsedVRAM = backendNodes.reduce((sum, n) => {
|
||||
if (n.total_vram && n.available_vram != null) return sum + (n.total_vram - n.available_vram)
|
||||
return sum
|
||||
}, 0)
|
||||
const totalModelsLoaded = backendNodes.reduce((sum, n) => sum + (n.model_count || 0), 0)
|
||||
return (
|
||||
<>
|
||||
{clusterTotalVRAM > 0 && (
|
||||
<StatCard icon="fas fa-microchip" label="Cluster VRAM"
|
||||
value={`${formatVRAM(clusterUsedVRAM) || '0'} / ${formatVRAM(clusterTotalVRAM)}`} />
|
||||
)}
|
||||
<StatCard icon="fas fa-cube" label="Models Loaded" value={totalModelsLoaded} />
|
||||
</>
|
||||
)
|
||||
})()}
|
||||
</div>
|
||||
|
||||
{/* Worker tips */}
|
||||
@@ -506,6 +634,22 @@ export default function Nodes() {
|
||||
<div style={{ fontSize: '0.75rem', fontFamily: "'JetBrains Mono', monospace", color: 'var(--color-text-muted)' }}>
|
||||
{node.address}
|
||||
</div>
|
||||
{node.labels && Object.keys(node.labels).length > 0 && (
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 3, marginTop: 3 }}>
|
||||
{Object.entries(node.labels).slice(0, 5).map(([k, v]) => (
|
||||
<span key={k} style={{
|
||||
fontSize: '0.625rem', padding: '1px 5px', borderRadius: 3,
|
||||
background: 'var(--color-bg-tertiary)', color: 'var(--color-text-muted)',
|
||||
fontFamily: "'JetBrains Mono', monospace", border: '1px solid var(--color-border-subtle)',
|
||||
}}>{k}={v}</span>
|
||||
))}
|
||||
{Object.keys(node.labels).length > 5 && (
|
||||
<span style={{ fontSize: '0.625rem', color: 'var(--color-text-muted)' }}>
|
||||
+{Object.keys(node.labels).length - 5} more
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
@@ -593,6 +737,7 @@ export default function Nodes() {
|
||||
<th>State</th>
|
||||
<th>In-Flight</th>
|
||||
<th style={{ width: 40 }}>Logs</th>
|
||||
<th style={{ textAlign: 'right' }}>Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
@@ -628,6 +773,21 @@ export default function Nodes() {
|
||||
<i className="fas fa-terminal" />
|
||||
</a>
|
||||
</td>
|
||||
<td style={{ textAlign: 'right' }}>
|
||||
<button
|
||||
className="btn btn-danger btn-sm"
|
||||
disabled={m.in_flight > 0}
|
||||
title={m.in_flight > 0 ? 'Cannot unload while serving requests' : 'Unload model'}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
if (confirm(`Unload "${m.model_name}" from ${node.name}?`)) {
|
||||
handleUnloadModel(node.id, m.model_name)
|
||||
}
|
||||
}}
|
||||
>
|
||||
<i className="fas fa-stop" />
|
||||
</button>
|
||||
</td>
|
||||
</tr>
|
||||
)
|
||||
})}
|
||||
@@ -689,6 +849,50 @@ export default function Nodes() {
|
||||
</tbody>
|
||||
</table>
|
||||
)}
|
||||
|
||||
{/* Labels */}
|
||||
<div style={{ marginTop: 'var(--spacing-md)' }}>
|
||||
<h4 style={{ fontSize: '0.8125rem', fontWeight: 600, marginBottom: 'var(--spacing-sm)', color: 'var(--color-text-secondary)' }}>
|
||||
<i className="fas fa-tags" style={{ marginRight: 6 }} />
|
||||
Labels
|
||||
</h4>
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 'var(--spacing-xs)', marginBottom: 'var(--spacing-sm)' }}>
|
||||
{node.labels && Object.entries(node.labels).map(([k, v]) => (
|
||||
<span key={k} style={{
|
||||
display: 'inline-flex', alignItems: 'center', gap: 4,
|
||||
fontSize: '0.75rem', padding: '2px 8px', borderRadius: 4,
|
||||
background: 'var(--color-bg-tertiary)', border: '1px solid var(--color-border-subtle)',
|
||||
fontFamily: "'JetBrains Mono', monospace",
|
||||
}}>
|
||||
{k}={v}
|
||||
<button
|
||||
onClick={(e) => { e.stopPropagation(); handleDeleteLabel(node.id, k) }}
|
||||
style={{ background: 'none', border: 'none', cursor: 'pointer', color: 'var(--color-text-muted)', fontSize: '0.625rem', padding: 0 }}
|
||||
title="Remove label"
|
||||
>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
{/* Add label form */}
|
||||
<div style={{ display: 'flex', gap: 'var(--spacing-xs)', alignItems: 'center' }}>
|
||||
<input
|
||||
type="text" placeholder="key" style={{ width: 100, fontSize: '0.75rem' }}
|
||||
id={`label-key-${node.id}`}
|
||||
/>
|
||||
<input
|
||||
type="text" placeholder="value" style={{ width: 100, fontSize: '0.75rem' }}
|
||||
id={`label-value-${node.id}`}
|
||||
/>
|
||||
<button className="btn btn-secondary btn-sm" onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
const key = document.getElementById(`label-key-${node.id}`).value.trim()
|
||||
const val = document.getElementById(`label-value-${node.id}`).value.trim()
|
||||
if (key) handleAddLabel(node.id, key, val)
|
||||
}}>Add</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
@@ -700,6 +904,78 @@ export default function Nodes() {
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</>}
|
||||
|
||||
{activeTab === 'scheduling' && (
|
||||
<div>
|
||||
<button className="btn btn-primary btn-sm" style={{ marginBottom: 'var(--spacing-md)' }}
|
||||
onClick={() => setShowSchedulingForm(f => !f)}>
|
||||
<i className="fas fa-plus" style={{ marginRight: 6 }} />
|
||||
Add Scheduling Rule
|
||||
</button>
|
||||
{showSchedulingForm && <SchedulingForm onSave={async (config) => {
|
||||
try {
|
||||
await nodesApi.setScheduling(config)
|
||||
fetchScheduling()
|
||||
setShowSchedulingForm(false)
|
||||
addToast('Scheduling rule saved', 'success')
|
||||
} catch (err) {
|
||||
addToast(`Failed to save rule: ${err.message}`, 'error')
|
||||
}
|
||||
}} onCancel={() => setShowSchedulingForm(false)} />}
|
||||
{schedulingConfigs.length === 0 && !showSchedulingForm ? (
|
||||
<p style={{ fontSize: '0.875rem', color: 'var(--color-text-muted)', textAlign: 'center', padding: 'var(--spacing-xl) 0' }}>
|
||||
No scheduling rules configured. Add a rule to control how models are placed on nodes.
|
||||
</p>
|
||||
) : schedulingConfigs.length > 0 && (
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<thead><tr>
|
||||
<th>Model</th>
|
||||
<th>Node Selector</th>
|
||||
<th>Min Replicas</th>
|
||||
<th>Max Replicas</th>
|
||||
<th style={{ textAlign: 'right' }}>Actions</th>
|
||||
</tr></thead>
|
||||
<tbody>
|
||||
{schedulingConfigs.map(cfg => (
|
||||
<tr key={cfg.id || cfg.model_name}>
|
||||
<td style={{ fontWeight: 600, fontSize: '0.875rem' }}>{cfg.model_name}</td>
|
||||
<td>
|
||||
{cfg.node_selector ? (() => {
|
||||
try {
|
||||
const sel = typeof cfg.node_selector === 'string' ? JSON.parse(cfg.node_selector) : cfg.node_selector
|
||||
return Object.entries(sel).map(([k,v]) => (
|
||||
<span key={k} style={{
|
||||
display: 'inline-block', fontSize: '0.75rem', padding: '2px 6px', borderRadius: 3,
|
||||
background: 'var(--color-bg-tertiary)', border: '1px solid var(--color-border-subtle)',
|
||||
fontFamily: "'JetBrains Mono', monospace", marginRight: 4,
|
||||
}}>{k}={v}</span>
|
||||
))
|
||||
} catch { return <span style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}>{cfg.node_selector}</span> }
|
||||
})() : <span style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}>Any node</span>}
|
||||
</td>
|
||||
<td style={{ fontFamily: "'JetBrains Mono', monospace" }}>{cfg.min_replicas || '-'}</td>
|
||||
<td style={{ fontFamily: "'JetBrains Mono', monospace" }}>{cfg.max_replicas || 'unlimited'}</td>
|
||||
<td style={{ textAlign: 'right' }}>
|
||||
<button className="btn btn-danger btn-sm" onClick={async () => {
|
||||
try {
|
||||
await nodesApi.deleteScheduling(cfg.model_name)
|
||||
fetchScheduling()
|
||||
addToast('Rule deleted', 'success')
|
||||
} catch (err) {
|
||||
addToast(`Failed to delete rule: ${err.message}`, 'error')
|
||||
}
|
||||
}}><i className="fas fa-trash" /></button>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ConfirmDialog
|
||||
open={!!confirmDelete}
|
||||
|
||||
@@ -266,9 +266,6 @@ export default function Settings() {
|
||||
<SettingRow label="Max Active Backends" description="Maximum models to keep loaded simultaneously (0 = unlimited)">
|
||||
<input className="input" type="number" style={{ width: 120 }} value={settings.max_active_backends ?? ''} onChange={(e) => update('max_active_backends', parseInt(e.target.value) || 0)} placeholder="0" />
|
||||
</SettingRow>
|
||||
<SettingRow label="Parallel Backend Requests" description="Enable parallel request handling per backend">
|
||||
<Toggle checked={settings.parallel_backend_requests} onChange={(v) => update('parallel_backend_requests', v)} />
|
||||
</SettingRow>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
7
core/http/react-ui/src/utils/api.js
vendored
7
core/http/react-ui/src/utils/api.js
vendored
@@ -423,6 +423,13 @@ export const nodesApi = {
|
||||
deleteBackend: (id, backend) => postJSON(API_CONFIG.endpoints.nodeBackendsDelete(id), { backend }),
|
||||
getBackendLogs: (id) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogs(id)),
|
||||
getBackendLogLines: (id, modelId) => fetchJSON(API_CONFIG.endpoints.nodeBackendLogsModel(id, modelId)),
|
||||
unloadModel: (id, modelName) => postJSON(API_CONFIG.endpoints.nodeModelsUnload(id), { model_name: modelName }),
|
||||
getLabels: (id) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id)),
|
||||
mergeLabels: (id, labels) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id), { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(labels) }),
|
||||
deleteLabel: (id, key) => fetchJSON(API_CONFIG.endpoints.nodeLabelKey(id, key), { method: 'DELETE' }),
|
||||
listScheduling: () => fetchJSON(API_CONFIG.endpoints.nodesScheduling),
|
||||
setScheduling: (config) => postJSON(API_CONFIG.endpoints.nodesScheduling, config),
|
||||
deleteScheduling: (model) => fetchJSON(API_CONFIG.endpoints.nodesSchedulingModel(model), { method: 'DELETE' }),
|
||||
}
|
||||
|
||||
// File to base64 helper
|
||||
|
||||
5
core/http/react-ui/src/utils/config.js
vendored
5
core/http/react-ui/src/utils/config.js
vendored
@@ -107,5 +107,10 @@ export const API_CONFIG = {
|
||||
nodeBackendsDelete: (id) => `/api/nodes/${id}/backends/delete`,
|
||||
nodeBackendLogs: (id) => `/api/nodes/${id}/backend-logs`,
|
||||
nodeBackendLogsModel: (id, modelId) => `/api/nodes/${id}/backend-logs/${encodeURIComponent(modelId)}`,
|
||||
nodeModelsUnload: (id) => `/api/nodes/${id}/models/unload`,
|
||||
nodeLabels: (id) => `/api/nodes/${id}/labels`,
|
||||
nodeLabelKey: (id, key) => `/api/nodes/${id}/labels/${key}`,
|
||||
nodesScheduling: '/api/nodes/scheduling',
|
||||
nodesSchedulingModel: (model) => `/api/nodes/scheduling/${encodeURIComponent(model)}`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -61,6 +61,13 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
|
||||
admin := e.Group("/api/nodes", readyMw, adminMw)
|
||||
admin.GET("", localai.ListNodesEndpoint(registry))
|
||||
|
||||
// Model scheduling (registered before /:id to avoid route conflicts)
|
||||
admin.GET("/scheduling", localai.ListSchedulingEndpoint(registry))
|
||||
admin.GET("/scheduling/:model", localai.GetSchedulingEndpoint(registry))
|
||||
admin.POST("/scheduling", localai.SetSchedulingEndpoint(registry))
|
||||
admin.DELETE("/scheduling/:model", localai.DeleteSchedulingEndpoint(registry))
|
||||
|
||||
admin.GET("/:id", localai.GetNodeEndpoint(registry))
|
||||
admin.GET("/:id/models", localai.GetNodeModelsEndpoint(registry))
|
||||
admin.DELETE("/:id", localai.DeregisterNodeEndpoint(registry))
|
||||
@@ -80,6 +87,12 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
|
||||
admin.GET("/:id/backend-logs", localai.NodeBackendLogsListEndpoint(registry, registrationToken))
|
||||
admin.GET("/:id/backend-logs/:modelId", localai.NodeBackendLogsLinesEndpoint(registry, registrationToken))
|
||||
|
||||
// Label management
|
||||
admin.GET("/:id/labels", localai.GetNodeLabelsEndpoint(registry))
|
||||
admin.PUT("/:id/labels", localai.SetNodeLabelsEndpoint(registry))
|
||||
admin.PATCH("/:id/labels", localai.MergeNodeLabelsEndpoint(registry))
|
||||
admin.DELETE("/:id/labels/:key", localai.DeleteNodeLabelEndpoint(registry))
|
||||
|
||||
// WebSocket proxy for real-time log streaming from workers
|
||||
e.GET("/ws/nodes/:id/backend-logs/:modelId", localai.NodeBackendLogsWSEndpoint(registry, registrationToken), readyMw, adminMw)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,12 @@ type ModelRouter interface {
|
||||
FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (*NodeModel, error)
|
||||
FindLRUModel(ctx context.Context, nodeID string) (*NodeModel, error)
|
||||
Get(ctx context.Context, nodeID string) (*BackendNode, error)
|
||||
GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error)
|
||||
FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error)
|
||||
FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error)
|
||||
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
|
||||
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
|
||||
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
|
||||
}
|
||||
|
||||
// NodeHealthStore is used by HealthMonitor for node status management.
|
||||
|
||||
@@ -72,6 +72,24 @@ func (f *fakeModelRouterForSmartRouter) Get(_ context.Context, nodeID string) (*
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
|
||||
|
||||
236
core/services/nodes/reconciler.go
Normal file
236
core/services/nodes/reconciler.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/advisorylock"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ReplicaReconciler periodically ensures model replica counts match their
// scheduling configs. It scales up replicas when below MinReplicas or when
// all replicas are busy (up to MaxReplicas), and scales down idle replicas
// above MinReplicas.
//
// Only processes models with auto-scaling enabled (MinReplicas > 0 or MaxReplicas > 0).
type ReplicaReconciler struct {
	registry       *NodeRegistry     // source of truth for nodes, loaded models and scheduling configs
	scheduler      ModelScheduler    // interface for scheduling new models
	unloader       NodeCommandSender // sends unload commands to workers during scale-down
	db             *gorm.DB          // optional; when non-nil, an advisory lock serializes reconciliation across instances
	interval       time.Duration     // how often a reconciliation pass runs (defaulted in NewReplicaReconciler)
	scaleDownDelay time.Duration     // how long a replica must sit idle before it may be scaled down
}
|
||||
|
||||
// ModelScheduler abstracts the scheduling logic needed by the reconciler.
// SmartRouter implements this interface.
type ModelScheduler interface {
	// ScheduleAndLoadModel picks a node (optionally from candidateNodeIDs),
	// installs the backend, and loads the model. Returns the node used.
	// An empty/nil candidateNodeIDs means any node is eligible.
	ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error)
}
|
||||
|
||||
// ReplicaReconcilerOptions holds configuration for creating a ReplicaReconciler.
type ReplicaReconcilerOptions struct {
	Registry       *NodeRegistry     // required: node/model registry backing all queries
	Scheduler      ModelScheduler    // used for scale-up; when nil, scale-up is skipped with a warning
	Unloader       NodeCommandSender // used for scale-down; when nil, scale-down is skipped
	DB             *gorm.DB          // when set, reconciliation runs under a cross-instance advisory lock
	Interval       time.Duration     // default 30s
	ScaleDownDelay time.Duration     // default 5m
}
|
||||
|
||||
// NewReplicaReconciler creates a new ReplicaReconciler.
|
||||
func NewReplicaReconciler(opts ReplicaReconcilerOptions) *ReplicaReconciler {
|
||||
interval := opts.Interval
|
||||
if interval == 0 {
|
||||
interval = 30 * time.Second
|
||||
}
|
||||
scaleDownDelay := opts.ScaleDownDelay
|
||||
if scaleDownDelay == 0 {
|
||||
scaleDownDelay = 5 * time.Minute
|
||||
}
|
||||
return &ReplicaReconciler{
|
||||
registry: opts.Registry,
|
||||
scheduler: opts.Scheduler,
|
||||
unloader: opts.Unloader,
|
||||
db: opts.DB,
|
||||
interval: interval,
|
||||
scaleDownDelay: scaleDownDelay,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the reconciliation loop. It blocks until ctx is cancelled.
|
||||
func (rc *ReplicaReconciler) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(rc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.reconcileOnce(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileOnce performs a single reconciliation pass.
|
||||
// Uses an advisory lock so only one frontend instance reconciles at a time.
|
||||
func (rc *ReplicaReconciler) reconcileOnce(ctx context.Context) {
|
||||
if rc.db != nil {
|
||||
lockKey := advisorylock.KeyFromString("replica-reconciler")
|
||||
_ = advisorylock.WithLockCtx(ctx, rc.db, lockKey, func() error {
|
||||
rc.reconcile(ctx)
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
rc.reconcile(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *ReplicaReconciler) reconcile(ctx context.Context) {
|
||||
configs, err := rc.registry.ListAutoScalingConfigs(ctx)
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: failed to list auto-scaling configs", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, cfg := range configs {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return // context cancelled
|
||||
}
|
||||
rc.reconcileModel(ctx, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) {
|
||||
current, err := rc.registry.CountLoadedReplicas(ctx, cfg.ModelName)
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: failed to count replicas", "model", cfg.ModelName, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 1. Ensure minimum replicas
|
||||
if cfg.MinReplicas > 0 && int(current) < cfg.MinReplicas {
|
||||
needed := cfg.MinReplicas - int(current)
|
||||
xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
|
||||
"current", current, "min", cfg.MinReplicas, "adding", needed)
|
||||
rc.scaleUp(ctx, cfg, needed)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Auto-scale up if all replicas are busy
|
||||
if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
|
||||
if rc.allReplicasBusy(ctx, cfg.ModelName) {
|
||||
xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
|
||||
"current", current)
|
||||
rc.scaleUp(ctx, cfg, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Scale down idle replicas above minimum
|
||||
floor := cfg.MinReplicas
|
||||
if floor < 1 {
|
||||
floor = 1
|
||||
}
|
||||
if int(current) > floor {
|
||||
rc.scaleDownIdle(ctx, cfg, int(current), floor)
|
||||
}
|
||||
}
|
||||
|
||||
// scaleUp schedules additional replicas of the model.
|
||||
func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
|
||||
if rc.scheduler == nil {
|
||||
xlog.Warn("Reconciler: no scheduler available, cannot scale up")
|
||||
return
|
||||
}
|
||||
|
||||
// Determine candidate nodes from selector
|
||||
var candidateNodeIDs []string
|
||||
if cfg.NodeSelector != "" {
|
||||
selector := parseSelector(cfg.NodeSelector)
|
||||
if len(selector) > 0 {
|
||||
candidates, err := rc.registry.FindNodesBySelector(ctx, selector)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName,
|
||||
"selector", cfg.NodeSelector)
|
||||
return
|
||||
}
|
||||
candidateNodeIDs = make([]string, len(candidates))
|
||||
for i, n := range candidates {
|
||||
candidateNodeIDs[i] = n.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
node, err := rc.scheduler.ScheduleAndLoadModel(ctx, cfg.ModelName, candidateNodeIDs)
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: failed to scale up replica", "model", cfg.ModelName,
|
||||
"attempt", i+1, "error", err)
|
||||
return // stop trying on first failure
|
||||
}
|
||||
xlog.Info("Reconciler: scaled up replica", "model", cfg.ModelName, "node", node.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// scaleDownIdle removes idle replicas above the floor.
|
||||
func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedulingConfig, current, floor int) {
|
||||
if rc.unloader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Find idle replicas that have been unused for longer than scaleDownDelay
|
||||
cutoff := time.Now().Add(-rc.scaleDownDelay)
|
||||
var idleModels []NodeModel
|
||||
rc.registry.db.WithContext(ctx).
|
||||
Where("model_name = ? AND state = ? AND in_flight = 0 AND last_used < ?",
|
||||
cfg.ModelName, "loaded", cutoff).
|
||||
Order("last_used ASC").
|
||||
Find(&idleModels)
|
||||
|
||||
toRemove := current - floor
|
||||
removed := 0
|
||||
for _, nm := range idleModels {
|
||||
if removed >= toRemove {
|
||||
break
|
||||
}
|
||||
// Remove from registry
|
||||
if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName); err != nil {
|
||||
xlog.Warn("Reconciler: failed to remove model record", "error", err)
|
||||
continue
|
||||
}
|
||||
// Unload from worker
|
||||
if err := rc.unloader.UnloadModelOnNode(nm.NodeID, nm.ModelName); err != nil {
|
||||
xlog.Warn("Reconciler: unload failed (model already removed from registry)", "error", err)
|
||||
}
|
||||
xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
// allReplicasBusy returns true if all loaded replicas of a model have in-flight requests.
|
||||
func (rc *ReplicaReconciler) allReplicasBusy(ctx context.Context, modelName string) bool {
|
||||
var idleCount int64
|
||||
rc.registry.db.WithContext(ctx).Model(&NodeModel{}).
|
||||
Where("model_name = ? AND state = ? AND in_flight = 0", modelName, "loaded").
|
||||
Count(&idleCount)
|
||||
return idleCount == 0
|
||||
}
|
||||
|
||||
// parseSelector decodes a JSON node selector string into a map.
|
||||
func parseSelector(selectorJSON string) map[string]string {
|
||||
if selectorJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var selector map[string]string
|
||||
if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil {
|
||||
xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
|
||||
return nil
|
||||
}
|
||||
return selector
|
||||
}
|
||||
241
core/services/nodes/reconciler_test.go
Normal file
241
core/services/nodes/reconciler_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fake ModelScheduler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// fakeScheduler is a test double for ModelScheduler that records every
// scheduling request and returns a canned node/error pair.
type fakeScheduler struct {
	scheduleNode  *BackendNode   // node returned from every ScheduleAndLoadModel call
	scheduleErr   error          // error returned from every call
	scheduleCalls []scheduleCall // recorded call arguments, in order
}
|
||||
|
||||
// scheduleCall captures the arguments of one ScheduleAndLoadModel invocation.
type scheduleCall struct {
	modelName    string
	candidateIDs []string
}
|
||||
|
||||
// ScheduleAndLoadModel records the call and returns the configured canned
// result; it never contacts any node.
func (f *fakeScheduler) ScheduleAndLoadModel(_ context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
	f.scheduleCalls = append(f.scheduleCalls, scheduleCall{modelName, candidateNodeIDs})
	return f.scheduleNode, f.scheduleErr
}
|
||||
|
||||
// Integration spec for ReplicaReconciler covering all scaling branches:
// scale-up to min, busy-driven scale-up, max-replica cap, idle scale-down,
// min-replica floor, and node-selector candidate filtering. Each example
// runs a single reconcile() pass against a real test database.
var _ = Describe("ReplicaReconciler", func() {
	var (
		db       *gorm.DB
		registry *NodeRegistry
	)

	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()
		var err error
		registry, err = NewNodeRegistry(db)
		Expect(err).ToNot(HaveOccurred())
	})

	// Helper to register a healthy node.
	registerNode := func(name, address string) *BackendNode {
		node := &BackendNode{
			Name:     name,
			NodeType: NodeTypeBackend,
			Address:  address,
		}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		return node
	}

	// Helper to set up a scheduling config.
	setSchedulingConfig := func(modelName string, minReplicas, maxReplicas int, nodeSelector string) {
		cfg := &ModelSchedulingConfig{
			ModelName:    modelName,
			MinReplicas:  minReplicas,
			MaxReplicas:  maxReplicas,
			NodeSelector: nodeSelector,
		}
		Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
	}

	Context("model below min_replicas", func() {
		It("scales up to min_replicas", func() {
			node := registerNode("node-1", "10.0.0.1:50051")
			setSchedulingConfig("model-a", 2, 4, "")

			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})

			// No replicas loaded — should schedule 2
			reconciler.reconcile(context.Background())

			Expect(scheduler.scheduleCalls).To(HaveLen(2))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-a"))
			Expect(scheduler.scheduleCalls[1].modelName).To(Equal("model-a"))
		})
	})

	Context("all replicas busy and below max_replicas", func() {
		It("scales up by 1", func() {
			node := registerNode("node-busy", "10.0.0.2:50051")
			setSchedulingConfig("model-b", 1, 4, "")

			// Load 2 replicas, both busy (in_flight > 0)
			Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", "loaded", "addr1", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b")).To(Succeed())

			node2 := registerNode("node-busy-2", "10.0.0.3:50051")
			Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", "loaded", "addr2", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b")).To(Succeed())

			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})

			reconciler.reconcile(context.Background())

			Expect(scheduler.scheduleCalls).To(HaveLen(1))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-b"))
		})
	})

	Context("all replicas busy and at max_replicas", func() {
		It("does not scale up", func() {
			node := registerNode("node-max", "10.0.0.4:50051")
			setSchedulingConfig("model-c", 1, 2, "")

			// Load 2 replicas (at max), both busy
			Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", "loaded", "addr1", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c")).To(Succeed())

			node2 := registerNode("node-max-2", "10.0.0.5:50051")
			Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", "loaded", "addr2", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c")).To(Succeed())

			scheduler := &fakeScheduler{
				scheduleNode: node,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})

			reconciler.reconcile(context.Background())

			Expect(scheduler.scheduleCalls).To(BeEmpty())
		})
	})

	Context("idle replicas above min_replicas", func() {
		It("scales down after idle delay", func() {
			node1 := registerNode("node-idle-1", "10.0.0.6:50051")
			node2 := registerNode("node-idle-2", "10.0.0.7:50051")
			node3 := registerNode("node-idle-3", "10.0.0.8:50051")
			setSchedulingConfig("model-d", 1, 4, "")

			// Load 3 replicas, all idle with last_used in the past
			pastTime := time.Now().Add(-10 * time.Minute)
			for _, n := range []*BackendNode{node1, node2, node3} {
				Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", "loaded", "", 0)).To(Succeed())
				// Set last_used to past time to trigger scale-down
				db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-d").
					Update("last_used", pastTime)
			}

			unloader := &fakeUnloader{}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:       registry,
				Unloader:       unloader,
				DB:             db,
				ScaleDownDelay: 1 * time.Minute, // short delay for test
			})

			reconciler.reconcile(context.Background())

			// Should scale down 2 replicas (3 - floor of 1)
			Expect(unloader.unloadCalls).To(HaveLen(2))
		})
	})

	Context("idle replicas at min_replicas", func() {
		It("does not scale down", func() {
			node1 := registerNode("node-keep-1", "10.0.0.9:50051")
			node2 := registerNode("node-keep-2", "10.0.0.10:50051")
			setSchedulingConfig("model-e", 2, 4, "")

			// Load exactly 2 replicas (at min), both idle with past last_used
			pastTime := time.Now().Add(-10 * time.Minute)
			for _, n := range []*BackendNode{node1, node2} {
				Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", "loaded", "", 0)).To(Succeed())
				db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-e").
					Update("last_used", pastTime)
			}

			unloader := &fakeUnloader{}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:       registry,
				Unloader:       unloader,
				DB:             db,
				ScaleDownDelay: 1 * time.Minute,
			})

			reconciler.reconcile(context.Background())

			Expect(unloader.unloadCalls).To(BeEmpty())
		})
	})

	Context("model with node_selector", func() {
		It("passes candidate node IDs to scheduler", func() {
			node1 := registerNode("gpu-node", "10.0.0.11:50051")
			node2 := registerNode("cpu-node", "10.0.0.12:50051")

			// Add labels — only node1 matches the selector
			Expect(registry.SetNodeLabel(context.Background(), node1.ID, "gpu.vendor", "nvidia")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node2.ID, "gpu.vendor", "none")).To(Succeed())

			setSchedulingConfig("model-f", 1, 2, `{"gpu.vendor":"nvidia"}`)

			scheduler := &fakeScheduler{
				scheduleNode: node1,
			}
			reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
				Registry:  registry,
				Scheduler: scheduler,
				DB:        db,
			})

			// No replicas loaded — should schedule 1 with candidate node IDs
			reconciler.reconcile(context.Background())

			Expect(scheduler.scheduleCalls).To(HaveLen(1))
			Expect(scheduler.scheduleCalls[0].modelName).To(Equal("model-f"))
			Expect(scheduler.scheduleCalls[0].candidateIDs).To(ContainElement(node1.ID))
			Expect(scheduler.scheduleCalls[0].candidateIDs).ToNot(ContainElement(node2.ID))
		})
	})
})
|
||||
@@ -68,6 +68,39 @@ type NodeModel struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// NodeLabel is a key-value label on a node (like K8s labels).
// The (NodeID, Key) pair is unique, so a node holds at most one value per key.
type NodeLabel struct {
	ID     string `gorm:"primaryKey;size:36" json:"id"`
	NodeID string `gorm:"uniqueIndex:idx_node_label;size:36" json:"node_id"`
	Key    string `gorm:"uniqueIndex:idx_node_label;size:128" json:"key"`
	Value  string `gorm:"size:255" json:"value"`
}
|
||||
|
||||
// ModelSchedulingConfig defines how a model should be scheduled across the cluster.
// All fields are optional:
//   - NodeSelector only → constrain nodes, single replica
//   - MinReplicas/MaxReplicas only → auto-scale on any node
//   - Both → auto-scale on matching nodes
//   - Neither → no-op (default behavior)
//
// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0.
type ModelSchedulingConfig struct {
	ID           string    `gorm:"primaryKey;size:36" json:"id"`
	ModelName    string    `gorm:"uniqueIndex;size:255" json:"model_name"` // one config per model
	NodeSelector string    `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
	MinReplicas  int       `gorm:"default:0" json:"min_replicas"` // 0 = no minimum
	MaxReplicas  int       `gorm:"default:0" json:"max_replicas"` // 0 = unlimited
	CreatedAt    time.Time `json:"created_at"`
	UpdatedAt    time.Time `json:"updated_at"`
}
|
||||
|
||||
// NodeWithExtras extends BackendNode with computed fields for list views.
type NodeWithExtras struct {
	BackendNode
	ModelCount int               `json:"model_count"`      // number of models associated with the node
	Labels     map[string]string `json:"labels,omitempty"` // flattened key/value labels
}
|
||||
|
||||
// NodeRegistry manages backend node registration and lookup in PostgreSQL.
|
||||
type NodeRegistry struct {
|
||||
db *gorm.DB
|
||||
@@ -78,7 +111,7 @@ type NodeRegistry struct {
|
||||
// when multiple instances (frontend + workers) start at the same time.
|
||||
func NewNodeRegistry(db *gorm.DB) (*NodeRegistry, error) {
|
||||
if err := advisorylock.WithLockCtx(context.Background(), db, advisorylock.KeySchemaMigrate, func() error {
|
||||
return db.AutoMigrate(&BackendNode{}, &NodeModel{})
|
||||
return db.AutoMigrate(&BackendNode{}, &NodeModel{}, &NodeLabel{}, &ModelSchedulingConfig{})
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("migrating node tables: %w", err)
|
||||
}
|
||||
@@ -207,18 +240,19 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*
|
||||
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
||||
Group("node_id")
|
||||
|
||||
// Try idle nodes with enough VRAM first
|
||||
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
|
||||
var node BackendNode
|
||||
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
|
||||
Order("available_vram DESC").
|
||||
First(&node).Error
|
||||
if err == nil {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// Fall back to least-loaded nodes with enough VRAM
|
||||
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
|
||||
err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes).
|
||||
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
||||
Order("COALESCE(load.total_inflight, 0) ASC").
|
||||
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err)
|
||||
@@ -407,9 +441,12 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
|
||||
var node BackendNode
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// Order by in_flight ASC (least busy replica), then by available_vram DESC
|
||||
// (prefer nodes with more free VRAM to spread load across the cluster).
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("model_name = ? AND state = ?", modelName, "loaded").
|
||||
Order("in_flight ASC").
|
||||
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
||||
Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded").
|
||||
Order("node_models.in_flight ASC, backend_nodes.available_vram DESC").
|
||||
First(&nm).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -463,7 +500,7 @@ func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, e
|
||||
Group("node_id")
|
||||
|
||||
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
||||
Order("COALESCE(load.total_inflight, 0) ASC").
|
||||
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding least loaded node: %w", err)
|
||||
@@ -482,6 +519,7 @@ func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) {
|
||||
Where("state = ?", "loaded").
|
||||
Group("node_id")
|
||||
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, loadedModels).
|
||||
Order("available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -578,3 +616,287 @@ func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (
|
||||
}
|
||||
return &nm, nil
|
||||
}
|
||||
|
||||
// --- NodeLabel operations ---
|
||||
|
||||
// SetNodeLabel upserts a single label on a node.
|
||||
func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error {
|
||||
label := NodeLabel{
|
||||
ID: uuid.New().String(),
|
||||
NodeID: nodeID,
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
return r.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "node_id"}, {Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value"}),
|
||||
}).
|
||||
Create(&label).Error
|
||||
}
|
||||
|
||||
// SetNodeLabels replaces all labels for a node with the given map.
|
||||
func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range labels {
|
||||
label := NodeLabel{ID: uuid.New().String(), NodeID: nodeID, Key: k, Value: v}
|
||||
if err := tx.Create(&label).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveNodeLabel removes a single label (identified by key) from a node.
func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error {
	return r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error
}
|
||||
|
||||
// GetNodeLabels returns all labels for a node. An unknown node yields an
// empty slice, not an error.
func (r *NodeRegistry) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) {
	var labels []NodeLabel
	err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Find(&labels).Error
	return labels, err
}
|
||||
|
||||
// GetAllNodeLabelsMap returns all labels grouped by node ID.
|
||||
func (r *NodeRegistry) GetAllNodeLabelsMap(ctx context.Context) (map[string]map[string]string, error) {
|
||||
var labels []NodeLabel
|
||||
if err := r.db.WithContext(ctx).Find(&labels).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]map[string]string)
|
||||
for _, l := range labels {
|
||||
if result[l.NodeID] == nil {
|
||||
result[l.NodeID] = make(map[string]string)
|
||||
}
|
||||
result[l.NodeID][l.Key] = l.Value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// --- Selector-based queries ---
|
||||
|
||||
// FindNodesBySelector returns healthy backend nodes matching ALL key-value pairs in the selector.
|
||||
func (r *NodeRegistry) FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error) {
|
||||
if len(selector) == 0 {
|
||||
// Empty selector matches all healthy backend nodes
|
||||
var nodes []BackendNode
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend).Find(&nodes).Error
|
||||
return nodes, err
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
|
||||
for k, v := range selector {
|
||||
db = db.Where("EXISTS (SELECT 1 FROM node_labels WHERE node_labels.node_id = backend_nodes.id AND node_labels.key = ? AND node_labels.value = ?)", k, v)
|
||||
}
|
||||
|
||||
var nodes []BackendNode
|
||||
err := db.Find(&nodes).Error
|
||||
return nodes, err
|
||||
}
|
||||
|
||||
// FindNodeWithVRAMFromSet is like FindNodeWithVRAM but restricted to the given node IDs.
|
||||
func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error) {
|
||||
db := r.db.WithContext(ctx)
|
||||
|
||||
loadedModels := db.Model(&NodeModel{}).
|
||||
Select("node_id").
|
||||
Where("state = ?", "loaded").
|
||||
Group("node_id")
|
||||
|
||||
subquery := db.Model(&NodeModel{}).
|
||||
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
||||
Group("node_id")
|
||||
|
||||
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
|
||||
var node BackendNode
|
||||
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
|
||||
Order("available_vram DESC").
|
||||
First(&node).Error
|
||||
if err == nil {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
|
||||
err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
|
||||
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
||||
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err)
|
||||
}
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// FindIdleNodeFromSet is like FindIdleNode but restricted to the given node IDs.
|
||||
func (r *NodeRegistry) FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
|
||||
db := r.db.WithContext(ctx)
|
||||
|
||||
var node BackendNode
|
||||
loadedModels := db.Model(&NodeModel{}).
|
||||
Select("node_id").
|
||||
Where("state = ?", "loaded").
|
||||
Group("node_id")
|
||||
err := db.Where("status = ? AND node_type = ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, loadedModels, nodeIDs).
|
||||
Order("available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// FindLeastLoadedNodeFromSet is like FindLeastLoadedNode but restricted to the given node IDs.
|
||||
func (r *NodeRegistry) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) {
|
||||
db := r.db.WithContext(ctx)
|
||||
|
||||
var node BackendNode
|
||||
query := db.Where("status = ? AND node_type = ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, nodeIDs)
|
||||
// Order by total in-flight across all models on the node
|
||||
subquery := db.Model(&NodeModel{}).
|
||||
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
|
||||
Group("node_id")
|
||||
|
||||
err := query.Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
|
||||
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
|
||||
First(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding least loaded node in set: %w", err)
|
||||
}
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// --- ModelSchedulingConfig operations ---
|
||||
|
||||
// SetModelScheduling creates or updates a scheduling config for a model.
|
||||
func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSchedulingConfig) error {
|
||||
if config.ID == "" {
|
||||
config.ID = uuid.New().String()
|
||||
}
|
||||
return r.db.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"node_selector", "min_replicas", "max_replicas", "updated_at"}),
|
||||
}).
|
||||
Create(config).Error
|
||||
}
|
||||
|
||||
// GetModelScheduling returns the scheduling config for a model, or nil if none exists.
|
||||
func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) {
|
||||
var config ModelSchedulingConfig
|
||||
err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&config).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// ListModelSchedulings returns all scheduling configs.
|
||||
func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedulingConfig, error) {
|
||||
var configs []ModelSchedulingConfig
|
||||
err := r.db.WithContext(ctx).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
// ListAutoScalingConfigs returns scheduling configs where auto-scaling is enabled.
|
||||
func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) {
|
||||
var configs []ModelSchedulingConfig
|
||||
err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0").Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
// DeleteModelScheduling removes a scheduling config by model name.
|
||||
func (r *NodeRegistry) DeleteModelScheduling(ctx context.Context, modelName string) error {
|
||||
return r.db.WithContext(ctx).Where("model_name = ?", modelName).Delete(&ModelSchedulingConfig{}).Error
|
||||
}
|
||||
|
||||
// CountLoadedReplicas returns the number of loaded replicas for a model.
|
||||
func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&NodeModel{}).Where("model_name = ? AND state = ?", modelName, "loaded").Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// --- Composite queries ---
|
||||
|
||||
// ListWithExtras returns all nodes with model counts and labels.
|
||||
func (r *NodeRegistry) ListWithExtras(ctx context.Context) ([]NodeWithExtras, error) {
|
||||
// Get all nodes
|
||||
var nodes []BackendNode
|
||||
if err := r.db.WithContext(ctx).Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get model counts per node
|
||||
type modelCount struct {
|
||||
NodeID string
|
||||
Count int
|
||||
}
|
||||
var counts []modelCount
|
||||
if err := r.db.WithContext(ctx).Model(&NodeModel{}).
|
||||
Select("node_id, COUNT(*) as count").
|
||||
Where("state = ?", "loaded").
|
||||
Group("node_id").
|
||||
Find(&counts).Error; err != nil {
|
||||
xlog.Warn("ListWithExtras: failed to get model counts", "error", err)
|
||||
}
|
||||
|
||||
countMap := make(map[string]int)
|
||||
for _, c := range counts {
|
||||
countMap[c.NodeID] = c.Count
|
||||
}
|
||||
|
||||
// Get all labels
|
||||
labelsMap, err := r.GetAllNodeLabelsMap(ctx)
|
||||
if err != nil {
|
||||
xlog.Warn("ListWithExtras: failed to get labels", "error", err)
|
||||
}
|
||||
|
||||
// Build result
|
||||
result := make([]NodeWithExtras, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = NodeWithExtras{
|
||||
BackendNode: n,
|
||||
ModelCount: countMap[n.ID],
|
||||
Labels: labelsMap[n.ID],
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ApplyAutoLabels sets automatic labels based on node hardware info.
|
||||
func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node *BackendNode) {
|
||||
if node.GPUVendor != "" {
|
||||
_ = r.SetNodeLabel(ctx, nodeID, "gpu.vendor", node.GPUVendor)
|
||||
}
|
||||
if node.TotalVRAM > 0 {
|
||||
gb := node.TotalVRAM / (1024 * 1024 * 1024)
|
||||
var bucket string
|
||||
switch {
|
||||
case gb >= 80:
|
||||
bucket = "80GB+"
|
||||
case gb >= 48:
|
||||
bucket = "48GB"
|
||||
case gb >= 24:
|
||||
bucket = "24GB"
|
||||
case gb >= 16:
|
||||
bucket = "16GB"
|
||||
case gb >= 8:
|
||||
bucket = "8GB"
|
||||
default:
|
||||
bucket = fmt.Sprintf("%dGB", gb)
|
||||
}
|
||||
_ = r.SetNodeLabel(ctx, nodeID, "gpu.vram", bucket)
|
||||
}
|
||||
if node.Name != "" {
|
||||
_ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -305,6 +305,252 @@ var _ = Describe("NodeRegistry", func() {
|
||||
})
|
||||
})
|
||||
|
||||
	// Exercises the per-node label store: set, overwrite, remove and bulk replace.
	Describe("NodeLabel CRUD", func() {
		It("sets and retrieves labels for a node", func() {
			node := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())

			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())

			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(2))

			// Fold rows into a map so the assertions don't depend on row order.
			labelMap := make(map[string]string)
			for _, l := range labels {
				labelMap[l.Key] = l.Value
			}
			Expect(labelMap["env"]).To(Equal("prod"))
			Expect(labelMap["region"]).To(Equal("us-east"))
		})

		It("overwrites existing label with same key", func() {
			node := makeNode("label-overwrite", "10.0.0.71:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())

			// A second SetNodeLabel with the same key must upsert, not duplicate.
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "dev")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())

			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(1))
			Expect(labels[0].Value).To(Equal("prod"))
		})

		It("removes a single label by key", func() {
			node := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())

			Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed())

			// Only "env" should be deleted; "region" must survive.
			Expect(registry.RemoveNodeLabel(context.Background(), node.ID, "env")).To(Succeed())

			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(1))
			Expect(labels[0].Key).To(Equal("region"))
		})

		It("SetNodeLabels replaces all labels", func() {
			node := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())

			Expect(registry.SetNodeLabel(context.Background(), node.ID, "old-key", "old-val")).To(Succeed())

			// Bulk replace: the pre-existing "old-key" label must be gone afterwards.
			newLabels := map[string]string{"new-a": "val-a", "new-b": "val-b"}
			Expect(registry.SetNodeLabels(context.Background(), node.ID, newLabels)).To(Succeed())

			labels, err := registry.GetNodeLabels(context.Background(), node.ID)
			Expect(err).ToNot(HaveOccurred())
			Expect(labels).To(HaveLen(2))

			labelMap := make(map[string]string)
			for _, l := range labels {
				labelMap[l.Key] = l.Value
			}
			Expect(labelMap).To(Equal(newLabels))
		})
	})
|
||||
|
||||
	// Covers selector semantics: AND across keys, superset labels on the node,
	// health filtering, and the empty-selector wildcard.
	Describe("FindNodesBySelector", func() {
		It("returns nodes matching all labels in selector", func() {
			n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000)
			n2 := makeNode("sel-nomatch", "10.0.0.81:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())

			Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed())

			// Both selector entries must match: only n1 qualifies.
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(HaveLen(1))
			Expect(nodes[0].Name).To(Equal("sel-match"))
		})

		It("returns empty when no nodes match", func() {
			n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed())

			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(BeEmpty())
		})

		It("ignores unhealthy nodes", func() {
			n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
			// Labels match, but an unhealthy node must never be selected.
			Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed())

			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(BeEmpty())
		})

		It("matches nodes with more labels than selector requires", func() {
			n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed())
			Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed())

			// Extra labels on the node must not disqualify it.
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
			Expect(err).ToNot(HaveOccurred())
			Expect(nodes).To(HaveLen(1))
			Expect(nodes[0].Name).To(Equal("sel-superset"))
		})

		It("returns all healthy nodes for empty selector", func() {
			n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000)
			n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())

			// ">=" rather than an exact count: nodes registered by sibling
			// specs may still be present in the registry.
			nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{})
			Expect(err).ToNot(HaveOccurred())
			Expect(len(nodes)).To(BeNumerically(">=", 2))
		})
	})
|
||||
|
||||
	// Covers create/read/update/delete of per-model scheduling configs plus
	// the auto-scaling filter.
	Describe("ModelSchedulingConfig CRUD", func() {
		It("creates and retrieves a scheduling config", func() {
			config := &ModelSchedulingConfig{
				ModelName:    "llama-7b",
				NodeSelector: `{"gpu.vendor":"nvidia"}`,
				MinReplicas:  1,
				MaxReplicas:  3,
			}
			Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())
			// SetModelScheduling must backfill a generated ID on the passed struct.
			Expect(config.ID).ToNot(BeEmpty())

			fetched, err := registry.GetModelScheduling(context.Background(), "llama-7b")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).ToNot(BeNil())
			Expect(fetched.ModelName).To(Equal("llama-7b"))
			Expect(fetched.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`))
			Expect(fetched.MinReplicas).To(Equal(1))
			Expect(fetched.MaxReplicas).To(Equal(3))
		})

		It("updates existing config via SetModelScheduling", func() {
			config := &ModelSchedulingConfig{
				ModelName:   "update-model",
				MinReplicas: 1,
				MaxReplicas: 2,
			}
			Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed())

			// Same model name: this must update in place, not insert a duplicate.
			config2 := &ModelSchedulingConfig{
				ModelName:   "update-model",
				MinReplicas: 2,
				MaxReplicas: 5,
			}
			Expect(registry.SetModelScheduling(context.Background(), config2)).To(Succeed())

			fetched, err := registry.GetModelScheduling(context.Background(), "update-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched.MinReplicas).To(Equal(2))
			Expect(fetched.MaxReplicas).To(Equal(5))
		})

		It("lists all configs", func() {
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())

			// ">=" because configs created by sibling specs may also be present.
			configs, err := registry.ListModelSchedulings(context.Background())
			Expect(err).ToNot(HaveOccurred())
			Expect(len(configs)).To(BeNumerically(">=", 2))
		})

		It("lists only auto-scaling configs", func() {
			// Auto-scaling is on when either replica bound is non-zero;
			// "no-auto" only has a selector and must be filtered out.
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed())
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed())
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed())

			configs, err := registry.ListAutoScalingConfigs(context.Background())
			Expect(err).ToNot(HaveOccurred())

			names := make([]string, len(configs))
			for i, c := range configs {
				names[i] = c.ModelName
			}
			Expect(names).To(ContainElement("auto-a"))
			Expect(names).To(ContainElement("auto-b"))
			Expect(names).ToNot(ContainElement("no-auto"))
		})

		It("deletes a config", func() {
			Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed())

			Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed())

			// A deleted config reads back as nil without an error.
			fetched, err := registry.GetModelScheduling(context.Background(), "delete-me")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).To(BeNil())
		})

		It("returns nil for non-existent model", func() {
			fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist")
			Expect(err).ToNot(HaveOccurred())
			Expect(fetched).To(BeNil())
		})
	})
|
||||
|
||||
	// Covers replica counting across nodes and filtering by model state.
	Describe("CountLoadedReplicas", func() {
		It("returns correct count of loaded replicas", func() {
			n1 := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000)
			n2 := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())

			// The same model loaded on two nodes counts as two replicas.
			Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", "loaded", "", 0)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", "loaded", "", 0)).To(Succeed())

			count, err := registry.CountLoadedReplicas(context.Background(), "counted-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(count).To(Equal(int64(2)))
		})

		It("excludes non-loaded states", func() {
			n1 := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000)
			n2 := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000)
			Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
			Expect(registry.Register(context.Background(), n2, true)).To(Succeed())

			// Only the "loaded" replica counts; the "loading" one is excluded.
			Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", "loaded", "", 0)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", "loading", "", 0)).To(Succeed())

			count, err := registry.CountLoadedReplicas(context.Background(), "state-model")
			Expect(err).ToNot(HaveOccurred())
			Expect(count).To(Equal(int64(1)))
		})
	})
|
||||
|
||||
Describe("DecrementInFlight", func() {
|
||||
It("does not go below zero", func() {
|
||||
node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000)
|
||||
|
||||
@@ -2,6 +2,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -69,6 +70,18 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
// Unloader returns the remote unloader adapter for external use.
|
||||
func (r *SmartRouter) Unloader() NodeCommandSender { return r.unloader }
|
||||
|
||||
// ScheduleAndLoadModel implements ModelScheduler for the reconciler.
// It schedules a model on a suitable node (optionally from candidates) and loads it.
//
// NOTE(review): candidateNodeIDs is currently not used in this body —
// scheduleNewModel re-derives candidate nodes from the model's stored
// scheduling config (node selector) on its own. Confirm whether the
// reconciler's explicit candidate list was meant to take precedence.
func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string, candidateNodeIDs []string) (*BackendNode, error) {
	// Use scheduleNewModel with empty backend type and nil model options.
	// The reconciler doesn't know the backend type — it will be determined by the model config.
	node, _, err := r.scheduleNewModel(ctx, "", modelName, nil)
	if err != nil {
		return nil, err
	}
	return node, nil
}
|
||||
|
||||
// RouteResult contains the routing decision.
|
||||
type RouteResult struct {
|
||||
Node *BackendNode
|
||||
@@ -110,18 +123,26 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
xlog.Warn("Backend not reachable for cached model, falling through to reload",
|
||||
"node", node.Name, "model", modelName)
|
||||
} else {
|
||||
// Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
|
||||
// and Release decrements it. No InFlightTrackingClient to avoid double-counting.
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
return &RouteResult{
|
||||
Node: node,
|
||||
Client: grpcClient,
|
||||
Release: func() {
|
||||
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
|
||||
closeClient(grpcClient)
|
||||
},
|
||||
}, nil
|
||||
// Verify node still matches scheduling constraints
|
||||
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
|
||||
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
|
||||
xlog.Info("Cached model on node that no longer matches selector, falling through",
|
||||
"node", node.Name, "model", trackingKey)
|
||||
// Fall through to step 2 (scheduleNewModel)
|
||||
} else {
|
||||
// Node is alive — use raw client; FindAndLockNodeWithModel already incremented in-flight,
|
||||
// and Release decrements it. No InFlightTrackingClient to avoid double-counting.
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
return &RouteResult{
|
||||
Node: node,
|
||||
Client: grpcClient,
|
||||
Release: func() {
|
||||
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
|
||||
closeClient(grpcClient)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,17 +164,25 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
xlog.Warn("Backend not reachable for cached model inside lock, proceeding to load",
|
||||
"node", node.Name, "model", modelName)
|
||||
} else {
|
||||
// Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
return &RouteResult{
|
||||
Node: node,
|
||||
Client: grpcClient,
|
||||
Release: func() {
|
||||
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
|
||||
closeClient(grpcClient)
|
||||
},
|
||||
}, nil
|
||||
// Verify node still matches scheduling constraints
|
||||
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
|
||||
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
|
||||
xlog.Info("Cached model on node that no longer matches selector, falling through",
|
||||
"node", node.Name, "model", trackingKey)
|
||||
// Fall through to scheduling below
|
||||
} else {
|
||||
// Model loaded while we waited — reuse it; no InFlightTrackingClient to avoid double-counting
|
||||
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
|
||||
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
|
||||
return &RouteResult{
|
||||
Node: node,
|
||||
Client: grpcClient,
|
||||
Release: func() {
|
||||
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
|
||||
closeClient(grpcClient)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,6 +252,59 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
|
||||
return loadModel()
|
||||
}
|
||||
|
||||
// parseSelectorJSON decodes a JSON node selector string into a map.
|
||||
func parseSelectorJSON(selectorJSON string) map[string]string {
|
||||
if selectorJSON == "" {
|
||||
return nil
|
||||
}
|
||||
var selector map[string]string
|
||||
if err := json.Unmarshal([]byte(selectorJSON), &selector); err != nil {
|
||||
xlog.Warn("Failed to parse node selector", "selector", selectorJSON, "error", err)
|
||||
return nil
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
func extractNodeIDs(nodes []BackendNode) []string {
|
||||
ids := make([]string, len(nodes))
|
||||
for i, n := range nodes {
|
||||
ids[i] = n.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
|
||||
// Returns true if no constraints exist or the node matches all selector labels.
|
||||
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
|
||||
sched, err := r.registry.GetModelScheduling(ctx, modelName)
|
||||
if err != nil || sched == nil || sched.NodeSelector == "" {
|
||||
return true // no constraints
|
||||
}
|
||||
|
||||
selector := parseSelectorJSON(sched.NodeSelector)
|
||||
if len(selector) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
labels, err := r.registry.GetNodeLabels(ctx, node.ID)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to get node labels for selector check", "node", node.ID, "error", err)
|
||||
return true // fail open
|
||||
}
|
||||
|
||||
labelMap := make(map[string]string)
|
||||
for _, l := range labels {
|
||||
labelMap[l.Key] = l.Value
|
||||
}
|
||||
|
||||
for k, v := range selector {
|
||||
if labelMap[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// scheduleNewModel picks the best node for loading a new model.
|
||||
// Strategy: VRAM-aware → idle-first → least-loaded.
|
||||
// Sends backend.install via NATS so the chosen node has the right backend running.
|
||||
@@ -233,12 +315,30 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
|
||||
estimatedVRAM = r.estimateModelVRAM(ctx, modelOpts)
|
||||
}
|
||||
|
||||
// Check for scheduling constraints (node selector)
|
||||
sched, _ := r.registry.GetModelScheduling(ctx, modelID)
|
||||
var candidateNodeIDs []string // nil = all nodes eligible
|
||||
|
||||
if sched != nil && sched.NodeSelector != "" {
|
||||
selector := parseSelectorJSON(sched.NodeSelector)
|
||||
if len(selector) > 0 {
|
||||
candidates, err := r.registry.FindNodesBySelector(ctx, selector)
|
||||
if err != nil || len(candidates) == 0 {
|
||||
return nil, "", fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector)
|
||||
}
|
||||
candidateNodeIDs = extractNodeIDs(candidates)
|
||||
}
|
||||
}
|
||||
|
||||
var node *BackendNode
|
||||
var err error
|
||||
|
||||
if estimatedVRAM > 0 {
|
||||
// 1. Prefer nodes with enough VRAM (idle-first, then least-loaded)
|
||||
node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
|
||||
if candidateNodeIDs != nil {
|
||||
node, err = r.registry.FindNodeWithVRAMFromSet(ctx, estimatedVRAM, candidateNodeIDs)
|
||||
} else {
|
||||
node, err = r.registry.FindNodeWithVRAM(ctx, estimatedVRAM)
|
||||
}
|
||||
if err != nil {
|
||||
xlog.Warn("No nodes with enough VRAM, falling back to standard scheduling",
|
||||
"required_vram", vram.FormatBytes(estimatedVRAM), "error", err)
|
||||
@@ -246,11 +346,16 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
|
||||
}
|
||||
|
||||
if node == nil {
|
||||
// 2. Prefer truly idle nodes (no loaded models, no in-flight)
|
||||
node, err = r.registry.FindIdleNode(ctx)
|
||||
if err != nil {
|
||||
// 3. Fall back to least-loaded node (can run an additional backend process)
|
||||
node, err = r.registry.FindLeastLoadedNode(ctx)
|
||||
if candidateNodeIDs != nil {
|
||||
node, err = r.registry.FindIdleNodeFromSet(ctx, candidateNodeIDs)
|
||||
if err != nil {
|
||||
node, err = r.registry.FindLeastLoadedNodeFromSet(ctx, candidateNodeIDs)
|
||||
}
|
||||
} else {
|
||||
node, err = r.registry.FindIdleNode(ctx)
|
||||
if err != nil {
|
||||
node, err = r.registry.FindLeastLoadedNode(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,7 +715,12 @@ func (r *SmartRouter) evictLRUAndFreeNode(ctx context.Context) (*BackendNode, er
|
||||
// Lock the row so no other frontend can evict the same model
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
|
||||
Where("node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?", "loaded", StatusHealthy).
|
||||
Where(`node_models.in_flight = 0 AND node_models.state = ? AND backend_nodes.status = ?
|
||||
AND (
|
||||
NOT EXISTS (SELECT 1 FROM model_scheduling_configs sc WHERE sc.model_name = node_models.model_name AND (sc.min_replicas > 0 OR sc.max_replicas > 0))
|
||||
OR (SELECT COUNT(*) FROM node_models nm2 WHERE nm2.model_name = node_models.model_name AND nm2.state = 'loaded')
|
||||
> COALESCE((SELECT sc2.min_replicas FROM model_scheduling_configs sc2 WHERE sc2.model_name = node_models.model_name), 1)
|
||||
)`, "loaded", StatusHealthy).
|
||||
Order("node_models.last_used ASC").
|
||||
First(&lru).Error; err != nil {
|
||||
return err
|
||||
|
||||
@@ -86,6 +86,26 @@ type fakeModelRouter struct {
|
||||
getNode *BackendNode
|
||||
getErr error
|
||||
|
||||
// GetModelScheduling returns
|
||||
getModelScheduling *ModelSchedulingConfig
|
||||
getModelSchedErr error
|
||||
|
||||
// FindNodesBySelector returns
|
||||
findBySelectorNodes []BackendNode
|
||||
findBySelectorErr error
|
||||
|
||||
// *FromSet variants
|
||||
findVRAMFromSetNode *BackendNode
|
||||
findVRAMFromSetErr error
|
||||
findIdleFromSetNode *BackendNode
|
||||
findIdleFromSetErr error
|
||||
findLeastLoadedFromSetNode *BackendNode
|
||||
findLeastLoadedFromSetErr error
|
||||
|
||||
// GetNodeLabels returns
|
||||
getNodeLabels []NodeLabel
|
||||
getNodeLabelsErr error
|
||||
|
||||
// Track calls for assertions
|
||||
decrementCalls []string // "nodeID:modelName"
|
||||
incrementCalls []string
|
||||
@@ -146,6 +166,30 @@ func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error)
|
||||
return f.getNode, f.getErr
|
||||
}
|
||||
|
||||
// GetModelScheduling is a test double: it ignores its arguments and returns
// the preconfigured scheduling config and error set on the fake.
func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
	return f.getModelScheduling, f.getModelSchedErr
}
|
||||
|
||||
// FindNodesBySelector is a test double: it ignores the selector and returns
// the preconfigured node slice and error set on the fake.
func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
	return f.findBySelectorNodes, f.findBySelectorErr
}
|
||||
|
||||
// FindNodeWithVRAMFromSet is a test double: it ignores the VRAM requirement
// and candidate set and returns the preconfigured node and error.
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
	return f.findVRAMFromSetNode, f.findVRAMFromSetErr
}
|
||||
|
||||
// FindIdleNodeFromSet is a test double: it ignores the candidate node IDs and
// returns the preconfigured idle node and error.
func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return f.findIdleFromSetNode, f.findIdleFromSetErr
}
|
||||
|
||||
// FindLeastLoadedNodeFromSet is a test double: it ignores the candidate node
// IDs and returns the preconfigured least-loaded node and error.
func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
}
|
||||
|
||||
// GetNodeLabels is a test double: it ignores the node ID and returns the
// preconfigured label slice and error set on the fake.
func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
	return f.getNodeLabels, f.getNodeLabelsErr
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fake BackendClientFactory + Backend
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -478,6 +522,135 @@ var _ = Describe("SmartRouter", func() {
|
||||
})
|
||||
})
|
||||
|
||||
Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
|
||||
var (
|
||||
reg *fakeModelRouter
|
||||
backend *stubBackend
|
||||
factory *stubClientFactory
|
||||
unloader *fakeUnloader
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
reg = &fakeModelRouter{
|
||||
findAndLockErr: errors.New("not found"),
|
||||
}
|
||||
backend = &stubBackend{
|
||||
loadResult: &pb.Result{Success: true},
|
||||
}
|
||||
factory = &stubClientFactory{client: backend}
|
||||
unloader = &fakeUnloader{
|
||||
installReply: &messaging.BackendInstallReply{
|
||||
Success: true,
|
||||
Address: "10.0.0.1:9001",
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
It("uses *FromSet methods when model has a node selector", func() {
|
||||
gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
|
||||
reg.getModelScheduling = &ModelSchedulingConfig{
|
||||
ModelName: "selector-model",
|
||||
NodeSelector: `{"gpu.vendor":"nvidia"}`,
|
||||
}
|
||||
reg.findBySelectorNodes = []BackendNode{*gpuNode}
|
||||
reg.findIdleFromSetNode = gpuNode
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
})
|
||||
|
||||
result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).ToNot(BeNil())
|
||||
Expect(result.Node.ID).To(Equal("gpu-1"))
|
||||
})
|
||||
|
||||
It("returns error when no nodes match selector", func() {
|
||||
reg.getModelScheduling = &ModelSchedulingConfig{
|
||||
ModelName: "no-match-model",
|
||||
NodeSelector: `{"gpu.vendor":"tpu"}`,
|
||||
}
|
||||
reg.findBySelectorNodes = nil
|
||||
reg.findBySelectorErr = nil
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
})
|
||||
|
||||
_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
|
||||
})
|
||||
|
||||
It("uses regular methods when model has no scheduling config", func() {
|
||||
reg.getModelScheduling = nil
|
||||
idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
|
||||
reg.findIdleNode = idleNode
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
})
|
||||
|
||||
result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).ToNot(BeNil())
|
||||
Expect(result.Node.ID).To(Equal("regular-1"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Route with selector validation on cached model (mock-based)", func() {
|
||||
It("falls through when cached node no longer matches selector", func() {
|
||||
cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
|
||||
newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}
|
||||
|
||||
backend := &stubBackend{
|
||||
healthResult: true,
|
||||
loadResult: &pb.Result{Success: true},
|
||||
}
|
||||
factory := &stubClientFactory{client: backend}
|
||||
unloader := &fakeUnloader{
|
||||
installReply: &messaging.BackendInstallReply{
|
||||
Success: true,
|
||||
Address: "10.0.0.71:9001",
|
||||
},
|
||||
}
|
||||
|
||||
reg := &fakeModelRouter{
|
||||
// Step 1: cached model found on old node
|
||||
findAndLockNode: cachedNode,
|
||||
findAndLockNM: &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
|
||||
// Scheduling config with selector that old node does NOT match
|
||||
getModelScheduling: &ModelSchedulingConfig{
|
||||
ModelName: "sel-model",
|
||||
NodeSelector: `{"gpu.vendor":"nvidia"}`,
|
||||
},
|
||||
// Old node has no labels matching the selector
|
||||
getNodeLabels: []NodeLabel{
|
||||
{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
|
||||
},
|
||||
// For scheduling fallthrough: selector matches new node
|
||||
findBySelectorNodes: []BackendNode{*newNode},
|
||||
findIdleFromSetNode: newNode,
|
||||
}
|
||||
|
||||
router := NewSmartRouter(reg, SmartRouterOptions{
|
||||
Unloader: unloader,
|
||||
ClientFactory: factory,
|
||||
})
|
||||
|
||||
result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).ToNot(BeNil())
|
||||
// Should have fallen through to the new node
|
||||
Expect(result.Node.ID).To(Equal("n-new"))
|
||||
// Old node should have had its in-flight decremented
|
||||
Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
|
||||
})
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Integration tests using real PostgreSQL (existing)
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
@@ -134,6 +134,47 @@ local-ai worker \
|
||||
**HTTP file transfer:** Each worker also runs a small HTTP server for file transfer (model files, configs). By default it listens on the gRPC base port - 1 (e.g., if gRPC base is 50051, HTTP is on 50050). gRPC ports grow upward from the base port as additional models are loaded. Set `--advertise-http-addr` if the auto-detected address is not routable from the frontend.
|
||||
{{% /notice %}}
|
||||
|
||||
### Worker Address Configuration
|
||||
|
||||
The simplest way to configure a worker's network address is with a single variable:
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `LOCALAI_ADDR` | Reachable address of this worker (`host:port`). The port is used as the base for gRPC backend processes, and `port-1` for the HTTP file transfer server. |
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
environment:
|
||||
LOCALAI_ADDR: "192.168.1.100:50051"
|
||||
LOCALAI_NATS_URL: "nats://frontend:4222"
|
||||
LOCALAI_REGISTER_TO: "http://frontend:8080"
|
||||
LOCALAI_REGISTRATION_TOKEN: "my-secret"
|
||||
```
|
||||
|
||||
For advanced networking scenarios (NAT, load balancers, separate gRPC/HTTP ports), the following override variables are available:
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `LOCALAI_SERVE_ADDR` | gRPC base port bind address | `0.0.0.0:50051` |
|
||||
| `LOCALAI_HTTP_ADDR` | HTTP file transfer bind address | `0.0.0.0:{gRPC port - 1}` |
|
||||
| `LOCALAI_ADVERTISE_ADDR` | Public gRPC address (if different from `LOCALAI_ADDR`) | Derived from `LOCALAI_ADDR` |
|
||||
| `LOCALAI_ADVERTISE_HTTP_ADDR` | Public HTTP address (if different from gRPC host) | Derived from advertise host + HTTP port |
|
||||
|
||||
### Node Labels
|
||||
|
||||
Workers can declare labels at startup for scheduling constraints:
|
||||
|
||||
| Variable | Description | Example |
|
||||
|----------|-------------|---------|
|
||||
| `LOCALAI_NODE_LABELS` | Comma-separated `key=value` labels | `tier=premium,gpu=a100,zone=us-east` |
|
||||
|
||||
Labels can also be managed via the admin API (see [Label Management API](#label-management-api) below).
|
||||
|
||||
The system automatically applies hardware-detected labels on registration:
|
||||
- `gpu.vendor` -- GPU vendor (nvidia, amd, intel, vulkan)
|
||||
- `gpu.vram` -- GPU VRAM bucket (8GB, 16GB, 24GB, 48GB, 80GB+)
|
||||
- `node.name` -- The node's registered name
|
||||
|
||||
### How Workers Operate
|
||||
|
||||
Workers start as generic processes with no backend installed. When the SmartRouter needs to load a model on a worker, it sends a NATS `backend.install` event with the backend name and model ID. The worker:
|
||||
@@ -262,6 +303,75 @@ local-ai worker \
|
||||
|
||||
**Multiple frontend replicas:** Run multiple LocalAI frontends behind a load balancer. Since all state is in PostgreSQL and coordination is via NATS, frontends are fully stateless and interchangeable.
|
||||
|
||||
## Model Scheduling
|
||||
|
||||
Model scheduling controls where models are placed and how many replicas are maintained. It combines two optional features:
|
||||
|
||||
### Node Selectors
|
||||
|
||||
Pin models to nodes with specific labels. Only nodes matching **all** selector labels are eligible:
|
||||
|
||||
```bash
|
||||
# Only schedule on NVIDIA nodes in the us-east zone
|
||||
curl -X POST http://frontend:8080/api/nodes/scheduling \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model_name": "llama3", "node_selector": {"gpu.vendor": "nvidia", "zone": "us-east"}}'
|
||||
```
|
||||
|
||||
Without a node selector, models can be scheduled on any healthy node (the default behavior).
|
||||
|
||||
### Replica Auto-Scaling
|
||||
|
||||
Control the number of model replicas across the cluster:
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `min_replicas` | Minimum number of replicas to maintain (`0` = no minimum; a single replica by default) |
|
||||
| `max_replicas` | Maximum replicas allowed (0 = unlimited) |
|
||||
|
||||
Auto-scaling is **only active** when `min_replicas > 0` or `max_replicas > 0`.
|
||||
|
||||
```bash
|
||||
# Scale llama3 between 2 and 4 replicas on NVIDIA nodes
|
||||
curl -X POST http://frontend:8080/api/nodes/scheduling \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model_name": "llama3",
|
||||
"node_selector": {"gpu.vendor": "nvidia"},
|
||||
"min_replicas": 2,
|
||||
"max_replicas": 4
|
||||
}'
|
||||
```
|
||||
|
||||
The **Replica Reconciler** runs as a background process on the frontend:
|
||||
- **Scale up**: Adds replicas when all existing replicas are busy (have in-flight requests)
|
||||
- **Scale down**: Removes idle replicas after 5 minutes of inactivity
|
||||
- **Maintain minimum**: Ensures `min_replicas` are always loaded (recovers from node failures)
|
||||
- **Eviction protection**: Models with auto-scaling enabled are never evicted below `min_replicas`
|
||||
|
||||
All fields are optional and composable:
|
||||
- Node selector only: pin model to matching nodes, single replica
|
||||
- Replicas only: auto-scale across all nodes
|
||||
- Both: auto-scale on matching nodes only
|
||||
|
||||
## Label Management API
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `GET` | `/api/nodes/:id/labels` | Get labels for a node |
|
||||
| `PUT` | `/api/nodes/:id/labels` | Replace all labels (JSON object) |
|
||||
| `PATCH` | `/api/nodes/:id/labels` | Merge labels (add/update) |
|
||||
| `DELETE` | `/api/nodes/:id/labels/:key` | Remove a single label |
|
||||
|
||||
## Scheduling API
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| `GET` | `/api/nodes/scheduling` | List all scheduling configs |
|
||||
| `GET` | `/api/nodes/scheduling/:model` | Get config for a model |
|
||||
| `POST` | `/api/nodes/scheduling` | Create/update config |
|
||||
| `DELETE` | `/api/nodes/scheduling/:model` | Remove config |
|
||||
|
||||
## Comparison with P2P
|
||||
|
||||
| | P2P / Federation | Distributed Mode |
|
||||
|
||||
@@ -28,7 +28,6 @@ Changes to watchdog settings are applied immediately by restarting the watchdog
|
||||
### Backend Configuration
|
||||
|
||||
- **Max Active Backends**: Maximum number of active backends (loaded models). When exceeded, the least recently used model is automatically evicted. Set to `0` for unlimited, `1` for single-backend mode
|
||||
- **Parallel Backend Requests**: Enable backends to handle multiple requests in parallel if supported
|
||||
- **Force Eviction When Busy**: Allow evicting models even when they have active API calls (default: disabled for safety). **Warning:** Enabling this can interrupt active requests
|
||||
- **LRU Eviction Max Retries**: Maximum number of retries when waiting for busy models to become idle before eviction (default: 30)
|
||||
- **LRU Eviction Retry Interval**: Interval between retries when waiting for busy models (default: `1s`)
|
||||
@@ -123,7 +122,6 @@ The `runtime_settings.json` file follows this structure:
|
||||
"watchdog_idle_timeout": "15m",
|
||||
"watchdog_busy_timeout": "5m",
|
||||
"max_active_backends": 0,
|
||||
"parallel_backend_requests": true,
|
||||
"force_eviction_when_busy": false,
|
||||
"lru_eviction_max_retries": 30,
|
||||
"lru_eviction_retry_interval": "1s",
|
||||
|
||||
@@ -39,7 +39,6 @@ Complete reference for all LocalAI command-line interface (CLI) parameters and e
|
||||
| `--external-grpc-backends` | | A list of external gRPC backends (format: `BACKEND_NAME:URI`) | `$LOCALAI_EXTERNAL_GRPC_BACKENDS`, `$EXTERNAL_GRPC_BACKENDS` |
|
||||
| `--backend-galleries` | | JSON list of backend galleries | `$LOCALAI_BACKEND_GALLERIES`, `$BACKEND_GALLERIES` |
|
||||
| `--autoload-backend-galleries` | `true` | Automatically load backend galleries on startup | `$LOCALAI_AUTOLOAD_BACKEND_GALLERIES`, `$AUTOLOAD_BACKEND_GALLERIES` |
|
||||
| `--parallel-requests` | `false` | Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm) | `$LOCALAI_PARALLEL_REQUESTS`, `$PARALLEL_REQUESTS` |
|
||||
| `--max-active-backends` | `0` | Maximum number of active backends (loaded models). When exceeded, the least recently used model is evicted. Set to `0` for unlimited, `1` for single-backend mode | `$LOCALAI_MAX_ACTIVE_BACKENDS`, `$MAX_ACTIVE_BACKENDS` |
|
||||
| `--single-active-backend` | `false` | **DEPRECATED** - Use `--max-active-backends=1` instead. Allow only one backend to be run at a time | `$LOCALAI_SINGLE_ACTIVE_BACKEND`, `$SINGLE_ACTIVE_BACKEND` |
|
||||
| `--preload-backend-only` | `false` | Do not launch the API services, only the preloaded models/backends are started (useful for multi-node setups) | `$LOCALAI_PRELOAD_BACKEND_ONLY`, `$PRELOAD_BACKEND_ONLY` |
|
||||
|
||||
Reference in New Issue
Block a user