mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-14 11:49:33 -04:00
feat(distributed): declarative per-model scheduling via env/args (#10308)
* feat(distributed): add SpreadAll column and authoritative scheduling seeding
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* feat(distributed): parse declarative model scheduling config (env/file)
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* feat(distributed): reconcile spread_all to one replica per matching node
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* feat(distributed): wire LOCALAI_MODEL_SCHEDULING env/args and startup seeding
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* feat(distributed): expose spread_all on the scheduling API endpoint
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* feat(distributed): add spread-to-all-nodes mode to the scheduling UI
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* docs(distributed): document LOCALAI_MODEL_SCHEDULING env/args
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
* docs(distributed): clarify replica modes and all-nodes spread in scheduling config
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---------
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -161,6 +161,21 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
|
||||
}
|
||||
xlog.Info("Node registry initialized")
|
||||
|
||||
// Seed declarative per-model scheduling config (LOCALAI_MODEL_SCHEDULING /
|
||||
// LOCALAI_MODEL_SCHEDULING_CONFIG). Authoritative: overwrites matching models
|
||||
// on every boot. Runs before the reconciler starts so the first tick already
|
||||
// sees the desired state. Models not listed are left untouched.
|
||||
if cfg.Distributed.ModelSchedulingJSON != "" || cfg.Distributed.ModelSchedulingConfigPath != "" {
|
||||
schedConfigs, err := nodes.ParseSchedulingSeed(cfg.Distributed.ModelSchedulingJSON, cfg.Distributed.ModelSchedulingConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing declarative model scheduling config: %w", err)
|
||||
}
|
||||
if err := registry.SeedModelScheduling(context.Background(), schedConfigs); err != nil {
|
||||
return nil, fmt.Errorf("seeding declarative model scheduling config: %w", err)
|
||||
}
|
||||
xlog.Info("Applied declarative model scheduling config", "models", len(schedConfigs))
|
||||
}
|
||||
|
||||
// Collect SmartRouter option values; the router itself is created after all
|
||||
// dependencies (including FileStager and Unloader) are ready.
|
||||
var routerAuthToken string
|
||||
|
||||
@@ -172,6 +172,8 @@ type RunCMD struct {
|
||||
NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
|
||||
NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
|
||||
ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
|
||||
ModelScheduling string `env:"LOCALAI_MODEL_SCHEDULING" help:"Declarative per-model scheduling config applied at startup (inline JSON list of {model_name,node_selector,min_replicas,max_replicas,replicas:\"all\"}). Authoritative: overwrites matching models on every boot. Distributed mode only." group:"distributed"`
|
||||
ModelSchedulingConfig string `env:"LOCALAI_MODEL_SCHEDULING_CONFIG" help:"Path to a YAML file with the same per-model scheduling list as LOCALAI_MODEL_SCHEDULING. Distributed mode only." group:"distributed"`
|
||||
|
||||
Version bool
|
||||
|
||||
@@ -347,6 +349,15 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
if r.ExposeNodeHeader {
|
||||
opts = append(opts, config.WithExposeNodeHeader(true))
|
||||
}
|
||||
if r.ModelScheduling != "" {
|
||||
opts = append(opts, config.WithModelSchedulingJSON(r.ModelScheduling))
|
||||
}
|
||||
if r.ModelSchedulingConfig != "" {
|
||||
opts = append(opts, config.WithModelSchedulingConfigPath(r.ModelSchedulingConfig))
|
||||
}
|
||||
if !r.Distributed && (r.ModelScheduling != "" || r.ModelSchedulingConfig != "") {
|
||||
xlog.Warn("LOCALAI_MODEL_SCHEDULING / LOCALAI_MODEL_SCHEDULING_CONFIG is set but distributed mode is disabled (LOCALAI_DISTRIBUTED=false) - ignoring")
|
||||
}
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
|
||||
@@ -84,6 +84,12 @@ type DistributedConfig struct {
|
||||
// drives the background eviction cadence (eviction runs every TTL/2). Zero
|
||||
// means use the prefixcache package default (5m).
|
||||
PrefixCacheTTL time.Duration
|
||||
// ModelSchedulingJSON is an inline JSON list of per-model scheduling configs
|
||||
// applied authoritatively at startup (LOCALAI_MODEL_SCHEDULING).
|
||||
ModelSchedulingJSON string
|
||||
// ModelSchedulingConfigPath is a path to a YAML file with the same list
|
||||
// (LOCALAI_MODEL_SCHEDULING_CONFIG).
|
||||
ModelSchedulingConfigPath string
|
||||
}
|
||||
|
||||
// Validate checks that the distributed configuration is internally consistent.
|
||||
@@ -290,6 +296,21 @@ func WithPrefixCacheTTL(d time.Duration) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithModelSchedulingJSON sets the inline-JSON declarative scheduling config.
|
||||
func WithModelSchedulingJSON(s string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.ModelSchedulingJSON = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithModelSchedulingConfigPath sets the path to a YAML declarative scheduling
|
||||
// config file.
|
||||
func WithModelSchedulingConfigPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.Distributed.ModelSchedulingConfigPath = path
|
||||
}
|
||||
}
|
||||
|
||||
// Flag names for distributed timeout / interval configuration. These are
|
||||
// the kebab-case identifiers kong derives from the matching RunCMD struct
|
||||
// fields; they appear in Validate error messages and any other operator-
|
||||
|
||||
@@ -937,12 +937,13 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
// distinguishable from an explicit zero. On update, an omitted prefix-cache
|
||||
// field preserves the model's previously-configured value instead of resetting
|
||||
// it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
|
||||
// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
|
||||
// MinReplicas, MaxReplicas and SpreadAll keep their full-replace PUT semantics.
|
||||
type SetSchedulingRequest struct {
|
||||
ModelName string `json:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas"`
|
||||
MaxReplicas int `json:"max_replicas"`
|
||||
SpreadAll bool `json:"spread_all,omitempty"`
|
||||
RoutePolicy *string `json:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
|
||||
@@ -959,6 +960,9 @@ func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, abs
|
||||
if req.ModelName == "" {
|
||||
return errors.New("model_name is required")
|
||||
}
|
||||
if req.SpreadAll && (req.MinReplicas != 0 || req.MaxReplicas != 0) {
|
||||
return errors.New("spread_all and min_replicas/max_replicas are mutually exclusive")
|
||||
}
|
||||
if req.MinReplicas < 0 {
|
||||
return errors.New("min_replicas must be >= 0")
|
||||
}
|
||||
@@ -1045,6 +1049,7 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: req.MinReplicas,
|
||||
MaxReplicas: req.MaxReplicas,
|
||||
SpreadAll: req.SpreadAll,
|
||||
RoutePolicy: routePolicy,
|
||||
BalanceAbsThreshold: absThr,
|
||||
BalanceRelThreshold: relThr,
|
||||
|
||||
22
core/http/endpoints/localai/nodes_scheduling_test.go
Normal file
22
core/http/endpoints/localai/nodes_scheduling_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("validateSchedulingRequest spread_all", func() {
|
||||
It("rejects spread_all combined with min_replicas", func() {
|
||||
err := validateSchedulingRequest(SetSchedulingRequest{
|
||||
ModelName: "m", SpreadAll: true, MinReplicas: 2,
|
||||
}, "", 0, 0, 0)
|
||||
Expect(err).To(MatchError(ContainSubstring("mutually exclusive")))
|
||||
})
|
||||
|
||||
It("accepts spread_all alone", func() {
|
||||
err := validateSchedulingRequest(SetSchedulingRequest{
|
||||
ModelName: "m", SpreadAll: true,
|
||||
}, "", 0, 0, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
@@ -506,6 +506,7 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
const isValid = () => {
|
||||
if (!modelName) return false
|
||||
if (mode === 'placement') return hasSelector
|
||||
if (mode === 'spread') return true
|
||||
return minReplicas > 0 || maxReplicas > 0
|
||||
}
|
||||
|
||||
@@ -513,8 +514,9 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
onSave({
|
||||
model_name: modelName,
|
||||
node_selector: hasSelector ? selector : undefined,
|
||||
min_replicas: mode === 'placement' ? 0 : minReplicas,
|
||||
max_replicas: mode === 'placement' ? 0 : maxReplicas,
|
||||
min_replicas: mode === 'autoscaling' ? minReplicas : 0,
|
||||
max_replicas: mode === 'autoscaling' ? maxReplicas : 0,
|
||||
spread_all: mode === 'spread',
|
||||
route_policy: routePolicy,
|
||||
balance_abs_threshold: balanceAbsThreshold,
|
||||
balance_rel_threshold: balanceRelThreshold,
|
||||
@@ -542,10 +544,19 @@ function SchedulingForm({ onSave, onCancel }) {
|
||||
>
|
||||
<i className="fas fa-arrows-up-down" aria-hidden="true" /> Auto-scale
|
||||
</button>
|
||||
<button
|
||||
type="button" role="radio" aria-checked={mode === 'spread'}
|
||||
className={`segmented__item${mode === 'spread' ? ' is-active' : ''}`}
|
||||
onClick={() => setMode('spread')}
|
||||
>
|
||||
<i className="fas fa-network-wired" aria-hidden="true" /> Spread to all
|
||||
</button>
|
||||
</div>
|
||||
<p style={{ fontSize: '0.8125rem', color: 'var(--color-text-muted)', margin: '0 0 var(--spacing-lg) 0' }}>
|
||||
{mode === 'placement'
|
||||
? 'Restrict this model to specific nodes. Loaded on demand, evictable when idle.'
|
||||
: mode === 'spread'
|
||||
? 'Run one replica on every node matching the selector (all healthy nodes when empty). Tracks nodes joining and leaving.'
|
||||
: 'Maintain a target replica count across the cluster. Min \u2265 1 protects from eviction.'}
|
||||
</p>
|
||||
|
||||
@@ -1563,10 +1574,11 @@ export default function Nodes() {
|
||||
</tr></thead>
|
||||
<tbody>
|
||||
{schedulingConfigs.map(cfg => {
|
||||
const isAutoScaling = cfg.min_replicas > 0 || cfg.max_replicas > 0
|
||||
const isSpread = !!cfg.spread_all
|
||||
const isAutoScaling = !isSpread && (cfg.min_replicas > 0 || cfg.max_replicas > 0)
|
||||
const hasSelector = !!cfg.node_selector
|
||||
const modeLabel = isAutoScaling ? 'Auto-scaling' : hasSelector ? 'Placement' : 'Inactive'
|
||||
const modeColor = isAutoScaling ? 'var(--color-success)' : hasSelector ? 'var(--color-primary)' : 'var(--color-text-muted)'
|
||||
const modeLabel = isSpread ? 'Spread' : isAutoScaling ? 'Auto-scaling' : hasSelector ? 'Placement' : 'Inactive'
|
||||
const modeColor = isSpread ? 'var(--color-warning)' : isAutoScaling ? 'var(--color-success)' : hasSelector ? 'var(--color-primary)' : 'var(--color-text-muted)'
|
||||
// Cooldown: reconciler tripped the circuit breaker because cluster
|
||||
// capacity is exhausted. Surface so the operator sees it instead
|
||||
// of the model silently failing to scale.
|
||||
@@ -1597,10 +1609,16 @@ export default function Nodes() {
|
||||
})() : <span style={{ color: 'var(--color-text-muted)', fontSize: '0.8125rem' }}>Any node</span>}
|
||||
</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)' }}>
|
||||
{isAutoScaling ? cfg.min_replicas : '-'}
|
||||
{isSpread
|
||||
? <span style={{
|
||||
display: 'inline-block', fontSize: '0.75rem', padding: '2px 8px', borderRadius: "var(--radius-sm)",
|
||||
background: 'var(--color-bg-tertiary)', border: '1px solid var(--color-warning)',
|
||||
color: 'var(--color-warning)', fontWeight: 600, fontFamily: 'var(--font-sans)',
|
||||
}}>Spread: all matching nodes</span>
|
||||
: isAutoScaling ? cfg.min_replicas : '-'}
|
||||
</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)' }}>
|
||||
{isAutoScaling ? (cfg.max_replicas || 'no limit') : '-'}
|
||||
{isSpread ? '-' : isAutoScaling ? (cfg.max_replicas || 'no limit') : '-'}
|
||||
</td>
|
||||
<td style={{ fontSize: '0.8125rem' }}>
|
||||
{cfg.route_policy || 'default'}
|
||||
|
||||
@@ -399,6 +399,28 @@ func (rc *ReplicaReconciler) candidateNodeIDsForSelector(ctx context.Context, cf
|
||||
}
|
||||
|
||||
func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) {
|
||||
// spread_all: derive a dynamic replica target equal to the number of nodes
|
||||
// currently matching the selector (all healthy backend nodes when the
|
||||
// selector is empty). Feeding it through Min==Max==target reuses every
|
||||
// existing path: the floor scales up toward target (capped at capacity),
|
||||
// Max==target stops busy-burst/pressure overshooting, and idle scale-down
|
||||
// trims above target. The target re-tracks node join/leave each tick. cfg is
|
||||
// a by-value copy, so mutating it here is local to this tick.
|
||||
if cfg.SpreadAll {
|
||||
matched, err := rc.registry.FindNodesBySelector(ctx, parseSelector(cfg.NodeSelector))
|
||||
if err != nil {
|
||||
xlog.Warn("Reconciler: spread_all failed to resolve matching nodes", "model", cfg.ModelName, "error", err)
|
||||
return
|
||||
}
|
||||
if len(matched) == 0 {
|
||||
xlog.Info("Reconciler: spread_all has no matching nodes; nothing to schedule",
|
||||
"model", cfg.ModelName, "selector", cfg.NodeSelector)
|
||||
return
|
||||
}
|
||||
cfg.MinReplicas = len(matched)
|
||||
cfg.MaxReplicas = len(matched)
|
||||
}
|
||||
|
||||
// Cooldown gate: if we previously decided this config is unsatisfiable,
|
||||
// don't even bother checking until the cooldown expires. ClearAllUnsatisfiable
|
||||
// (fired by node lifecycle events) bypasses this by zeroing the column.
|
||||
|
||||
@@ -34,6 +34,13 @@ func (f *fakeScheduler) ScheduleAndLoadModel(_ context.Context, modelName string
|
||||
return f.scheduleNode, f.scheduleErr
|
||||
}
|
||||
|
||||
func mustGetSched(r *NodeRegistry, model string) ModelSchedulingConfig {
|
||||
cfg, err := r.GetModelScheduling(context.Background(), model)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg).ToNot(BeNil())
|
||||
return *cfg
|
||||
}
|
||||
|
||||
var _ = Describe("ReplicaReconciler", func() {
|
||||
var (
|
||||
db *gorm.DB
|
||||
@@ -78,6 +85,45 @@ var _ = Describe("ReplicaReconciler", func() {
|
||||
Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
|
||||
}
|
||||
|
||||
// spread_all mode: the reconciler derives its replica target from the set of
// nodes matching the selector instead of from min_replicas/max_replicas.
Context("spread_all mode", func() {
	It("targets one replica per matching node (empty selector = all nodes)", func() {
		n1 := registerNode("s1", "10.1.0.1:50051")
		registerNode("s2", "10.1.0.2:50051")
		// spread config, no selector -> all healthy backend nodes (2)
		Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
			ModelName: "spread-model", SpreadAll: true,
		})).To(Succeed())

		scheduler := &fakeScheduler{scheduleNode: n1}
		reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
			Registry:  registry,
			Scheduler: scheduler,
		})

		reconciler.reconcileModel(context.Background(), mustGetSched(registry, "spread-model"))

		// With current==0 and a target of 2, the MinReplicas floor path
		// schedules up to cluster capacity (2 nodes).
		Expect(scheduler.scheduleCalls).To(HaveLen(2))
	})

	It("is a no-op when no nodes match", func() {
		Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
			ModelName: "spread-model", SpreadAll: true,
			NodeSelector: `{"tier":"nope"}`,
		})).To(Succeed())

		scheduler := &fakeScheduler{}
		reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
			Registry:  registry,
			Scheduler: scheduler,
		})

		reconciler.reconcileModel(context.Background(), mustGetSched(registry, "spread-model"))
		Expect(scheduler.scheduleCalls).To(BeEmpty())
	})
})
|
||||
|
||||
Context("model below min_replicas", func() {
|
||||
It("scales up to min_replicas", func() {
|
||||
node := registerNode("node-1", "10.0.0.1:50051")
|
||||
|
||||
@@ -135,13 +135,18 @@ type NodeLabel struct {
|
||||
// - Both → auto-scale on matching nodes
|
||||
// - Neither → no-op (default behavior)
|
||||
//
|
||||
// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0.
|
||||
// Auto-scaling is enabled when MinReplicas > 0, MaxReplicas > 0, or SpreadAll is set.
|
||||
type ModelSchedulingConfig struct {
|
||||
ID string `gorm:"primaryKey;size:36" json:"id"`
|
||||
ModelName string `gorm:"uniqueIndex;size:255" json:"model_name"`
|
||||
NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
|
||||
MinReplicas int `gorm:"default:0" json:"min_replicas"`
|
||||
MaxReplicas int `gorm:"default:0" json:"max_replicas"`
|
||||
// SpreadAll requests one replica on every node matching NodeSelector
|
||||
// (every healthy backend node when the selector is empty), tracked as
|
||||
// nodes join and leave. Mutually exclusive with MinReplicas/MaxReplicas.
|
||||
// The reconciler turns this into a dynamic Min==Max target each tick.
|
||||
SpreadAll bool `gorm:"column:spread_all;default:false" json:"spread_all,omitempty"`
|
||||
// Prefix-cache-aware routing (epic #10063). RoutePolicy "" means inherit
|
||||
// the cluster-wide default. Thresholds are per-model overrides; 0 means
|
||||
// inherit the global default.
|
||||
@@ -1392,7 +1397,7 @@ func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSche
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "model_name"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{
|
||||
"node_selector", "min_replicas", "max_replicas",
|
||||
"node_selector", "min_replicas", "max_replicas", "spread_all",
|
||||
"route_policy", "balance_abs_threshold", "balance_rel_threshold", "min_prefix_match",
|
||||
"updated_at",
|
||||
}),
|
||||
@@ -1400,6 +1405,20 @@ func (r *NodeRegistry) SetModelScheduling(ctx context.Context, config *ModelSche
|
||||
Create(config).Error
|
||||
}
|
||||
|
||||
// SeedModelScheduling authoritatively applies a batch of scheduling configs at
|
||||
// startup. Each config is upserted (full-replace on model_name), overwriting any
|
||||
// prior row for that model. Models not present in configs are left untouched.
|
||||
func (r *NodeRegistry) SeedModelScheduling(ctx context.Context, configs []ModelSchedulingConfig) error {
|
||||
for i := range configs {
|
||||
if err := r.SetModelScheduling(ctx, &configs[i]); err != nil {
|
||||
return fmt.Errorf("seeding scheduling config for model %q: %w", configs[i].ModelName, err)
|
||||
}
|
||||
xlog.Info("Seeded model scheduling config", "model", configs[i].ModelName,
|
||||
"spread_all", configs[i].SpreadAll, "min", configs[i].MinReplicas, "max", configs[i].MaxReplicas)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelScheduling returns the scheduling config for a model, or nil if none exists.
|
||||
func (r *NodeRegistry) GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error) {
|
||||
var config ModelSchedulingConfig
|
||||
@@ -1423,7 +1442,7 @@ func (r *NodeRegistry) ListModelSchedulings(ctx context.Context) ([]ModelSchedul
|
||||
// ListAutoScalingConfigs returns scheduling configs where auto-scaling is enabled.
|
||||
func (r *NodeRegistry) ListAutoScalingConfigs(ctx context.Context) ([]ModelSchedulingConfig, error) {
|
||||
var configs []ModelSchedulingConfig
|
||||
err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0").Find(&configs).Error
|
||||
err := r.db.WithContext(ctx).Where("min_replicas > 0 OR max_replicas > 0 OR spread_all = ?", true).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1489,3 +1489,59 @@ var _ = Describe("NodeRegistry", func() {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("ModelScheduling spread + seeding", func() {
|
||||
var (
|
||||
db *gorm.DB
|
||||
registry *NodeRegistry
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
Skip("testcontainers requires Docker, not available on macOS CI")
|
||||
}
|
||||
db = testutil.SetupTestDB()
|
||||
var err error
|
||||
registry, err = NewNodeRegistry(db)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("persists and round-trips SpreadAll", func() {
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
|
||||
ModelName: "m", SpreadAll: true,
|
||||
})).To(Succeed())
|
||||
got, err := registry.GetModelScheduling(context.Background(), "m")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got.SpreadAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("includes SpreadAll configs in ListAutoScalingConfigs", func() {
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
|
||||
ModelName: "m", SpreadAll: true,
|
||||
})).To(Succeed())
|
||||
configs, err := registry.ListAutoScalingConfigs(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs).To(HaveLen(1))
|
||||
Expect(configs[0].ModelName).To(Equal("m"))
|
||||
})
|
||||
|
||||
It("seeds configs with authoritative upsert", func() {
|
||||
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
|
||||
ModelName: "m", MinReplicas: 9,
|
||||
})).To(Succeed())
|
||||
|
||||
err := registry.SeedModelScheduling(context.Background(), []ModelSchedulingConfig{
|
||||
{ModelName: "m", MinReplicas: 1, MaxReplicas: 2},
|
||||
{ModelName: "n", SpreadAll: true},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
m, _ := registry.GetModelScheduling(context.Background(), "m")
|
||||
Expect(m.MinReplicas).To(Equal(1))
|
||||
Expect(m.MaxReplicas).To(Equal(2))
|
||||
Expect(m.SpreadAll).To(BeFalse())
|
||||
|
||||
n, _ := registry.GetModelScheduling(context.Background(), "n")
|
||||
Expect(n.SpreadAll).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
171
core/services/nodes/scheduling_seed.go
Normal file
171
core/services/nodes/scheduling_seed.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ReplicasSpec parses the "replicas" convenience field used in the env/file
// scheduling config. It accepts the string "all" (or boolean true) to mean
// "spread one replica onto every matching node". The strings "" / "auto" and
// boolean false leave SpreadAll unset and defer to min_replicas/max_replicas.
// A numeric value is rejected with a hint pointing at min/max_replicas, which
// are the dedicated fields for fixed counts.
type ReplicasSpec struct {
	SpreadAll bool
}

// set applies an already-decoded JSON/YAML value to r.
//
// Accepted inputs:
//   - nil or false            -> SpreadAll = false
//   - true                    -> SpreadAll = true
//   - "all"                   -> SpreadAll = true
//   - "" or "auto"            -> SpreadAll = false
//
// String matching ignores case and surrounding whitespace. Any other string,
// and any value of another type (numbers, lists, maps), yields an error and
// leaves r unchanged.
func (r *ReplicasSpec) set(v any) error {
	switch t := v.(type) {
	case nil:
		r.SpreadAll = false
	case bool:
		r.SpreadAll = t
	case string:
		switch strings.ToLower(strings.TrimSpace(t)) {
		case "all":
			r.SpreadAll = true
		case "", "auto":
			r.SpreadAll = false
		default:
			return fmt.Errorf("invalid replicas value %q (expected \"all\" or \"auto\")", t)
		}
	default:
		// Numbers land here; steer the user to the fields meant for counts.
		return fmt.Errorf("invalid replicas value %v (use min_replicas/max_replicas for a fixed count, or \"all\" to spread)", v)
	}
	return nil
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for the replicas alias.
|
||||
func (r *ReplicasSpec) UnmarshalJSON(b []byte) error {
|
||||
var v any
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.set(v)
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements yaml.Unmarshaler for the replicas alias.
|
||||
func (r *ReplicasSpec) UnmarshalYAML(value *yaml.Node) error {
|
||||
var v any
|
||||
if err := value.Decode(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.set(v)
|
||||
}
|
||||
|
||||
// SeedSchedulingEntry is one entry in the env/file scheduling config. It mirrors
|
||||
// the API's SetSchedulingRequest shape, plus the "replicas" alias and the
|
||||
// canonical "spread_all" boolean.
|
||||
type SeedSchedulingEntry struct {
|
||||
ModelName string `json:"model_name" yaml:"model_name"`
|
||||
NodeSelector map[string]string `json:"node_selector,omitempty" yaml:"node_selector,omitempty"`
|
||||
MinReplicas int `json:"min_replicas,omitempty" yaml:"min_replicas,omitempty"`
|
||||
MaxReplicas int `json:"max_replicas,omitempty" yaml:"max_replicas,omitempty"`
|
||||
Replicas *ReplicasSpec `json:"replicas,omitempty" yaml:"replicas,omitempty"`
|
||||
SpreadAll bool `json:"spread_all,omitempty" yaml:"spread_all,omitempty"`
|
||||
|
||||
RoutePolicy string `json:"route_policy,omitempty" yaml:"route_policy,omitempty"`
|
||||
BalanceAbsThreshold int `json:"balance_abs_threshold,omitempty" yaml:"balance_abs_threshold,omitempty"`
|
||||
BalanceRelThreshold float64 `json:"balance_rel_threshold,omitempty" yaml:"balance_rel_threshold,omitempty"`
|
||||
MinPrefixMatch float64 `json:"min_prefix_match,omitempty" yaml:"min_prefix_match,omitempty"`
|
||||
}
|
||||
|
||||
// spread reports whether this entry requests spread-to-all-matching-nodes mode,
|
||||
// via either the canonical spread_all field or the replicas alias.
|
||||
func (e SeedSchedulingEntry) spread() bool {
|
||||
return e.SpreadAll || (e.Replicas != nil && e.Replicas.SpreadAll)
|
||||
}
|
||||
|
||||
// ValidateSeedEntry enforces the invariants of a single scheduling entry. It
|
||||
// mirrors the API's validateSchedulingRequest, with the added rule that spread
|
||||
// mode is mutually exclusive with explicit min/max replica counts.
|
||||
func ValidateSeedEntry(e SeedSchedulingEntry) error {
|
||||
if e.ModelName == "" {
|
||||
return fmt.Errorf("model_name is required")
|
||||
}
|
||||
if e.MinReplicas < 0 {
|
||||
return fmt.Errorf("min_replicas must be >= 0 (model %q)", e.ModelName)
|
||||
}
|
||||
if e.MaxReplicas < 0 {
|
||||
return fmt.Errorf("max_replicas must be >= 0 (model %q)", e.ModelName)
|
||||
}
|
||||
if e.spread() && (e.MinReplicas != 0 || e.MaxReplicas != 0) {
|
||||
return fmt.Errorf("spread (replicas: all) and min_replicas/max_replicas are mutually exclusive (model %q)", e.ModelName)
|
||||
}
|
||||
if e.MaxReplicas > 0 && e.MinReplicas > e.MaxReplicas {
|
||||
return fmt.Errorf("min_replicas must be <= max_replicas (model %q)", e.ModelName)
|
||||
}
|
||||
if err := prefixcache.ValidateThresholds(e.RoutePolicy, e.BalanceAbsThreshold, e.BalanceRelThreshold, e.MinPrefixMatch); err != nil {
|
||||
return fmt.Errorf("%w (model %q)", err, e.ModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e SeedSchedulingEntry) toConfig() (ModelSchedulingConfig, error) {
|
||||
selectorJSON := ""
|
||||
if len(e.NodeSelector) > 0 {
|
||||
b, err := json.Marshal(e.NodeSelector)
|
||||
if err != nil {
|
||||
return ModelSchedulingConfig{}, fmt.Errorf("serializing node_selector for model %q: %w", e.ModelName, err)
|
||||
}
|
||||
selectorJSON = string(b)
|
||||
}
|
||||
return ModelSchedulingConfig{
|
||||
ModelName: e.ModelName,
|
||||
NodeSelector: selectorJSON,
|
||||
MinReplicas: e.MinReplicas,
|
||||
MaxReplicas: e.MaxReplicas,
|
||||
SpreadAll: e.spread(),
|
||||
RoutePolicy: e.RoutePolicy,
|
||||
BalanceAbsThreshold: e.BalanceAbsThreshold,
|
||||
BalanceRelThreshold: e.BalanceRelThreshold,
|
||||
MinPrefixMatch: e.MinPrefixMatch,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseSchedulingSeed parses the inline-JSON and/or YAML-file scheduling config
|
||||
// into validated ModelSchedulingConfig rows ready to upsert. Entries from both
|
||||
// sources are concatenated (jsonStr first, then the file). Either argument may
|
||||
// be empty.
|
||||
func ParseSchedulingSeed(jsonStr, configPath string) ([]ModelSchedulingConfig, error) {
|
||||
var entries []SeedSchedulingEntry
|
||||
|
||||
if strings.TrimSpace(jsonStr) != "" {
|
||||
var fromJSON []SeedSchedulingEntry
|
||||
if err := json.Unmarshal([]byte(jsonStr), &fromJSON); err != nil {
|
||||
return nil, fmt.Errorf("parsing LOCALAI_MODEL_SCHEDULING JSON: %w", err)
|
||||
}
|
||||
entries = append(entries, fromJSON...)
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading model scheduling config %q: %w", configPath, err)
|
||||
}
|
||||
var fromYAML []SeedSchedulingEntry
|
||||
if err := yaml.Unmarshal(data, &fromYAML); err != nil {
|
||||
return nil, fmt.Errorf("parsing model scheduling config %q: %w", configPath, err)
|
||||
}
|
||||
entries = append(entries, fromYAML...)
|
||||
}
|
||||
|
||||
configs := make([]ModelSchedulingConfig, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if err := ValidateSeedEntry(e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg, err := e.toConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configs = append(configs, cfg)
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
75
core/services/nodes/scheduling_seed_test.go
Normal file
75
core/services/nodes/scheduling_seed_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ParseSchedulingSeed", func() {
|
||||
It("parses inline JSON with static min/max replicas", func() {
|
||||
configs, err := ParseSchedulingSeed(`[{"model_name":"m","node_selector":{"tier":"gpu"},"min_replicas":1,"max_replicas":4}]`, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs).To(HaveLen(1))
|
||||
Expect(configs[0].ModelName).To(Equal("m"))
|
||||
Expect(configs[0].MinReplicas).To(Equal(1))
|
||||
Expect(configs[0].MaxReplicas).To(Equal(4))
|
||||
Expect(configs[0].SpreadAll).To(BeFalse())
|
||||
Expect(configs[0].NodeSelector).To(Equal(`{"tier":"gpu"}`))
|
||||
})
|
||||
|
||||
It("maps replicas: all to SpreadAll", func() {
|
||||
configs, err := ParseSchedulingSeed(`[{"model_name":"m","replicas":"all"}]`, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs[0].SpreadAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("maps replicas: true to SpreadAll", func() {
|
||||
configs, err := ParseSchedulingSeed(`[{"model_name":"m","replicas":true}]`, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs[0].SpreadAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("accepts the spread_all field directly", func() {
|
||||
configs, err := ParseSchedulingSeed(`[{"model_name":"m","spread_all":true}]`, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs[0].SpreadAll).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects spread_all combined with min/max replicas", func() {
|
||||
_, err := ParseSchedulingSeed(`[{"model_name":"m","replicas":"all","min_replicas":2}]`, "")
|
||||
Expect(err).To(MatchError(ContainSubstring("mutually exclusive")))
|
||||
})
|
||||
|
||||
It("rejects a missing model_name", func() {
|
||||
_, err := ParseSchedulingSeed(`[{"min_replicas":1}]`, "")
|
||||
Expect(err).To(MatchError(ContainSubstring("model_name is required")))
|
||||
})
|
||||
|
||||
It("rejects a numeric replicas value pointing the user at min/max", func() {
|
||||
_, err := ParseSchedulingSeed(`[{"model_name":"m","replicas":3}]`, "")
|
||||
Expect(err).To(MatchError(ContainSubstring("min_replicas")))
|
||||
})
|
||||
|
||||
It("returns no configs for empty input", func() {
|
||||
configs, err := ParseSchedulingSeed("", "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("parses a YAML file with replicas: all and a node_selector", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "scheduling.yaml")
|
||||
yaml := "- model_name: m\n replicas: all\n node_selector:\n tier: gpu\n"
|
||||
Expect(os.WriteFile(path, []byte(yaml), 0o600)).To(Succeed())
|
||||
|
||||
configs, err := ParseSchedulingSeed("", path)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(configs).To(HaveLen(1))
|
||||
Expect(configs[0].ModelName).To(Equal("m"))
|
||||
Expect(configs[0].SpreadAll).To(BeTrue())
|
||||
Expect(configs[0].NodeSelector).To(Equal(`{"tier":"gpu"}`))
|
||||
})
|
||||
})
|
||||
@@ -604,6 +604,91 @@ All fields are optional and composable:
|
||||
- Replicas only: auto-scale across all nodes
|
||||
- Both: auto-scale on matching nodes only
|
||||
|
||||
### Declarative per-model scheduling (unattended installs)
|
||||
|
||||
In distributed mode you can declare per-model scheduling at startup, instead of
|
||||
using the WebUI/API. Config is **authoritative**: it is re-applied on every boot
|
||||
and overwrites the listed models (models not listed are left untouched).
|
||||
|
||||
| Variable | Description |
|----------|-------------|
| `LOCALAI_MODEL_SCHEDULING` | Inline JSON list of scheduling entries |
| `LOCALAI_MODEL_SCHEDULING_CONFIG` | Path to a YAML file with the same list |
|
||||
|
||||
Entry fields: `model_name` (required), `node_selector` (a label map; **omit it to
|
||||
match every node**), and then **one of two replica modes** (they are mutually
|
||||
exclusive):
|
||||
|
||||
- **`replicas: all`** - static spread: place exactly **one replica on every
  matching node**, proactively, regardless of load, and keep it in sync as nodes
  join and leave. Use this for "run model X everywhere (with this label)".
- **`min_replicas` / `max_replicas`** - elastic auto-scaling: keep at least
  `min_replicas` running, and burst **up to** `max_replicas` only when all
  replicas are busy, scaling back down to the minimum when idle. `max_replicas: 0`
  means **no upper bound** (grow to cluster capacity). To enable this mode you
  must set `min_replicas >= 1` or `max_replicas >= 1` - an entry with only
  `max_replicas: 0` (and no `replicas: all`) does nothing.
|
||||
|
||||
Net effect at a glance:
|
||||
|
||||
| Config | Behavior |
|--------|----------|
| `replicas: all` | One replica per matching node, placed immediately, tracks join/leave |
| `min_replicas: 1, max_replicas: 0` | Always >=1, bursts to cluster capacity under load, back to 1 when idle |
| `min_replicas: 2, max_replicas: 4` | Always >=2, bursts to at most 4 under load |
|
||||
|
||||
`node_selector` constrains which nodes a model may use; with no selector the
|
||||
model may use **all** healthy nodes. So "spread model X across all nodes" is just
|
||||
`replicas: all` with no `node_selector`. `replicas: all` targets one replica per
|
||||
matching node; with the default per-node cap of one replica per model this lands
|
||||
exactly one on each node (see the note below about `LOCALAI_MAX_REPLICAS_PER_MODEL`).
|
||||
|
||||
YAML example (`scheduling.yaml`):
|
||||
|
||||
```yaml
# One replica on every GPU-labelled node (static spread, tracks join/leave):
- model_name: gpt-oss
  node_selector:
    tier: gpu
  replicas: all

# One replica on EVERY node in the cluster (no selector = all nodes):
- model_name: embeddings
  replicas: all

# Elastic on CPU nodes: always >=1, burst to capacity under load, 0 = no cap:
- model_name: whisper
  node_selector:
    tier: cpu
  min_replicas: 1
  max_replicas: 0
```
|
||||
|
||||
```bash
|
||||
LOCALAI_DISTRIBUTED=true \
|
||||
LOCALAI_MODEL_SCHEDULING_CONFIG=/etc/localai/scheduling.yaml \
|
||||
local-ai run
|
||||
```
|
||||
|
||||
Inline equivalent:
|
||||
|
||||
```bash
|
||||
LOCALAI_MODEL_SCHEDULING='[{"model_name":"gpt-oss","node_selector":{"tier":"gpu"},"replicas":"all"}]'
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- Because the config is authoritative, each listed model's **entire** scheduling
|
||||
row is replaced on every boot, including the optional prefix-cache routing
|
||||
overrides (`route_policy`, `balance_abs_threshold`, `balance_rel_threshold`,
|
||||
`min_prefix_match`). For a model you manage via this config, set those fields
|
||||
here too if you need non-default values; values set only through the API are
|
||||
reset on the next restart. Models not listed in the config are never touched.
|
||||
- `replicas: all` places one replica per matching node by relying on the default
|
||||
per-node cap of one replica per model. If you raise `LOCALAI_MAX_REPLICAS_PER_MODEL`
|
||||
on a worker above 1, the target count can be met by stacking replicas on fewer
|
||||
nodes rather than spreading one to each.
|
||||
|
||||
## Label Management API
|
||||
|
||||
| Method | Path | Description |
|
||||
|
||||
Reference in New Issue
Block a user