diff --git a/core/application/distributed.go b/core/application/distributed.go
index 3c8c6ec32..00c39422d 100644
--- a/core/application/distributed.go
+++ b/core/application/distributed.go
@@ -161,6 +161,21 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
 	}
 	xlog.Info("Node registry initialized")
 
+	// Seed declarative per-model scheduling config (LOCALAI_MODEL_SCHEDULING /
+	// LOCALAI_MODEL_SCHEDULING_CONFIG). Authoritative: overwrites matching models
+	// on every boot. Runs before the reconciler starts so the first tick already
+	// sees the desired state. Models not listed are left untouched.
+	if cfg.Distributed.ModelSchedulingJSON != "" || cfg.Distributed.ModelSchedulingConfigPath != "" {
+		schedConfigs, err := nodes.ParseSchedulingSeed(cfg.Distributed.ModelSchedulingJSON, cfg.Distributed.ModelSchedulingConfigPath)
+		if err != nil {
+			return nil, fmt.Errorf("parsing declarative model scheduling config: %w", err)
+		}
+		if err := registry.SeedModelScheduling(context.Background(), schedConfigs); err != nil {
+			return nil, fmt.Errorf("seeding declarative model scheduling config: %w", err)
+		}
+		xlog.Info("Applied declarative model scheduling config", "models", len(schedConfigs))
+	}
+
 	// Collect SmartRouter option values; the router itself is created after all
 	// dependencies (including FileStager and Unloader) are ready.
 	var routerAuthToken string
diff --git a/core/cli/run.go b/core/cli/run.go
index 65cb07c0f..d011f3293 100644
--- a/core/cli/run.go
+++ b/core/cli/run.go
@@ -172,6 +172,8 @@ type RunCMD struct {
 	NatsTLSCert string `env:"LOCALAI_NATS_TLS_CERT" type:"existingfile" help:"Client certificate for NATS mTLS" group:"distributed"`
 	NatsTLSKey string `env:"LOCALAI_NATS_TLS_KEY" type:"existingfile" help:"Client private key for NATS mTLS" group:"distributed"`
 	ExposeNodeHeader bool `env:"LOCALAI_EXPOSE_NODE_HEADER" default:"false" help:"Set the X-LocalAI-Node response header on inference responses (OpenAI chat/completions/embeddings, Anthropic /v1/messages, Ollama /api/chat,/api/generate,/api/embed) with the ID of the worker that served the request. Disabled by default: the node ID reveals internal topology and should not be exposed on a public endpoint. Best-effort: under heavy concurrency the header may reflect a recent routing decision rather than this exact request's." group:"distributed"`
+	ModelScheduling string `env:"LOCALAI_MODEL_SCHEDULING" help:"Declarative per-model scheduling config applied at startup (inline JSON list of {model_name,node_selector,min_replicas,max_replicas,replicas:\"all\"}). Authoritative: overwrites matching models on every boot. Distributed mode only." group:"distributed"`
+	ModelSchedulingConfig string `env:"LOCALAI_MODEL_SCHEDULING_CONFIG" help:"Path to a YAML file with the same per-model scheduling list as LOCALAI_MODEL_SCHEDULING. Distributed mode only." group:"distributed"`
 
 	Version bool
 
@@ -347,6 +349,15 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 	if r.ExposeNodeHeader {
 		opts = append(opts, config.WithExposeNodeHeader(true))
 	}
+	if r.ModelScheduling != "" {
+		opts = append(opts, config.WithModelSchedulingJSON(r.ModelScheduling))
+	}
+	if r.ModelSchedulingConfig != "" {
+		opts = append(opts, config.WithModelSchedulingConfigPath(r.ModelSchedulingConfig))
+	}
+	if !r.Distributed && (r.ModelScheduling != "" || r.ModelSchedulingConfig != "") {
+		xlog.Warn("LOCALAI_MODEL_SCHEDULING / LOCALAI_MODEL_SCHEDULING_CONFIG is set but distributed mode is disabled (LOCALAI_DISTRIBUTED=false) - ignoring")
+	}
 
 	if r.DisableMetricsEndpoint {
 		opts = append(opts, config.DisableMetricsEndpoint)
diff --git a/core/config/distributed_config.go b/core/config/distributed_config.go
index d07e2a825..3403487a9 100644
--- a/core/config/distributed_config.go
+++ b/core/config/distributed_config.go
@@ -84,6 +84,12 @@ type DistributedConfig struct {
 	// drives the background eviction cadence (eviction runs every TTL/2). Zero
 	// means use the prefixcache package default (5m).
 	PrefixCacheTTL time.Duration
+	// ModelSchedulingJSON is an inline JSON list of per-model scheduling configs
+	// applied authoritatively at startup (LOCALAI_MODEL_SCHEDULING).
+	ModelSchedulingJSON string
+	// ModelSchedulingConfigPath is a path to a YAML file with the same list
+	// (LOCALAI_MODEL_SCHEDULING_CONFIG).
+	ModelSchedulingConfigPath string
 }
 
 // Validate checks that the distributed configuration is internally consistent.
@@ -290,6 +296,21 @@ func WithPrefixCacheTTL(d time.Duration) AppOption {
 	}
 }
 
+// WithModelSchedulingJSON sets the inline-JSON declarative scheduling config.
+func WithModelSchedulingJSON(s string) AppOption {
+	return func(o *ApplicationConfig) {
+		o.Distributed.ModelSchedulingJSON = s
+	}
+}
+
+// WithModelSchedulingConfigPath sets the path to a YAML declarative scheduling
+// config file.
+func WithModelSchedulingConfigPath(path string) AppOption {
+	return func(o *ApplicationConfig) {
+		o.Distributed.ModelSchedulingConfigPath = path
+	}
+}
+
 // Flag names for distributed timeout / interval configuration. These are
 // the kebab-case identifiers kong derives from the matching RunCMD struct
 // fields; they appear in Validate error messages and any other operator-
diff --git a/core/http/endpoints/localai/nodes.go b/core/http/endpoints/localai/nodes.go
index 930070506..5a6edab22 100644
--- a/core/http/endpoints/localai/nodes.go
+++ b/core/http/endpoints/localai/nodes.go
@@ -937,12 +937,13 @@ func GetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
 // distinguishable from an explicit zero. On update, an omitted prefix-cache
 // field preserves the model's previously-configured value instead of resetting
 // it (see SetSchedulingEndpoint's PATCH-style merge). ModelName, NodeSelector,
-// MinReplicas and MaxReplicas keep their full-replace PUT semantics.
+// MinReplicas, MaxReplicas and SpreadAll keep their full-replace PUT semantics.
 type SetSchedulingRequest struct {
 	ModelName string `json:"model_name"`
 	NodeSelector map[string]string `json:"node_selector,omitempty"`
 	MinReplicas int `json:"min_replicas"`
 	MaxReplicas int `json:"max_replicas"`
+	SpreadAll bool `json:"spread_all,omitempty"`
 	RoutePolicy *string `json:"route_policy,omitempty"`
 	BalanceAbsThreshold *int `json:"balance_abs_threshold,omitempty"`
 	BalanceRelThreshold *float64 `json:"balance_rel_threshold,omitempty"`
@@ -959,6 +960,9 @@ func validateSchedulingRequest(req SetSchedulingRequest, routePolicy string, abs
 	if req.ModelName == "" {
 		return errors.New("model_name is required")
 	}
+	if req.SpreadAll && (req.MinReplicas != 0 || req.MaxReplicas != 0) {
+		return errors.New("spread_all and min_replicas/max_replicas are mutually exclusive")
+	}
 	if req.MinReplicas < 0 {
 		return errors.New("min_replicas must be >= 0")
 	}
@@ -1045,6 +1049,7 @@ func SetSchedulingEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
 			NodeSelector: selectorJSON,
 			MinReplicas: req.MinReplicas,
 			MaxReplicas: req.MaxReplicas,
+			SpreadAll: req.SpreadAll,
 			RoutePolicy: routePolicy,
 			BalanceAbsThreshold: absThr,
 			BalanceRelThreshold: relThr,
diff --git a/core/http/endpoints/localai/nodes_scheduling_test.go b/core/http/endpoints/localai/nodes_scheduling_test.go
new file mode 100644
index 000000000..927a41be2
--- /dev/null
+++ b/core/http/endpoints/localai/nodes_scheduling_test.go
@@ -0,0 +1,22 @@
+package localai
+
+import (
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("validateSchedulingRequest spread_all", func() {
+	It("rejects spread_all combined with min_replicas", func() {
+		err := validateSchedulingRequest(SetSchedulingRequest{
+			ModelName: "m", SpreadAll: true, MinReplicas: 2,
+		}, "", 0, 0, 0)
+		Expect(err).To(MatchError(ContainSubstring("mutually exclusive")))
+	})
+
+	It("accepts spread_all alone", func() {
+		err := validateSchedulingRequest(SetSchedulingRequest{
+			ModelName: "m", SpreadAll: true,
+		}, "", 0, 0, 0)
+		Expect(err).ToNot(HaveOccurred())
+	})
+})
diff --git a/core/http/react-ui/src/pages/Nodes.jsx b/core/http/react-ui/src/pages/Nodes.jsx
index f2eb9d955..06372f4b1 100644
--- a/core/http/react-ui/src/pages/Nodes.jsx
+++ b/core/http/react-ui/src/pages/Nodes.jsx
@@ -506,6 +506,7 @@ function SchedulingForm({ onSave, onCancel }) {
   const isValid = () => {
     if (!modelName) return false
     if (mode === 'placement') return hasSelector
+    if (mode === 'spread') return true
     return minReplicas > 0 || maxReplicas > 0
   }
 
@@ -513,8 +514,9 @@ function SchedulingForm({ onSave, onCancel }) {
     onSave({
       model_name: modelName,
       node_selector: hasSelector ? selector : undefined,
-      min_replicas: mode === 'placement' ? 0 : minReplicas,
-      max_replicas: mode === 'placement' ? 0 : maxReplicas,
+      min_replicas: mode === 'autoscaling' ? minReplicas : 0,
+      max_replicas: mode === 'autoscaling' ? maxReplicas : 0,
+      spread_all: mode === 'spread',
       route_policy: routePolicy,
       balance_abs_threshold: balanceAbsThreshold,
       balance_rel_threshold: balanceRelThreshold,
@@ -542,10 +544,19 @@ function SchedulingForm({ onSave, onCancel }) {
 >