mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-18 21:45:01 -04:00
feat(concurrency-groups): per-model exclusive groups for backend loading (#9662)
* feat(concurrency-groups): per-model exclusive groups for backend loading Adds `concurrency_groups: [...]` to model YAML configs. Two models that share a group cannot be loaded concurrently on the same node — loading one evicts the others, reusing the existing pinned/busy/retry policy from LRU eviction. Layered design: - Watchdog (pkg/model): per-node correctness floor — on every Load(), evict any loaded model that shares a group with the requested one. Pinned skips surface NeedMore so the loader retries (and ultimately logs a clear warning), instead of silently allowing the rule to be violated. - Distributed scheduler (core/services/nodes): soft anti-affinity hint — scheduleNewModel prefers nodes that don't already host a same-group model, falling back to eviction only if every candidate has a conflict. Composes with NodeSelector at the same point in the candidate pipeline. Per-node, not cluster-wide: VRAM is a node-local resource, and two heavy models running on different nodes is fine. The ConfigLoader is wired into SmartRouter via a small ConcurrencyConflictResolver interface so the nodes package keeps a narrow surface on core/config. Refactors the inner LRU eviction body into a shared collectEvictionsLocked helper and the loader retry loop into retryEnforce(fn, maxRetries, interval), so both LRU and group enforcement share busy/pinned/retry semantics. Closes #9659. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(watchdog): sync pinned + concurrency_groups at startup The startup-time watchdog setup lives in initializeWatchdog (startup.go), not in startWatchdog (watchdog.go). The latter is only invoked from the runtime-settings RestartWatchdog path. As a result, neither SyncPinnedModelsToWatchdog nor SyncModelGroupsToWatchdog ran at boot, so `pinned: true` and `concurrency_groups: [...]` only became effective after a settings-driven watchdog restart. 
Fix by adding both sync calls to initializeWatchdog. Confirmed end-to-end: loading model A in group "heavy", then C with no group (coexists), then B in group "heavy" now correctly evicts A and leaves [B, C]. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(test): satisfy errcheck on new os.Remove in concurrency_groups spec CI lint runs new-from-merge-base, so the pre-existing `defer os.Remove(tmp.Name())` lines are baseline-grandfathered, but the one introduced by the concurrency_groups YAML round-trip test is held to errcheck. Wrap the remove in a closure that discards the error. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
22ae415695
commit
bbcaebc1ef
@@ -71,7 +71,9 @@ func (ds *DistributedServices) Shutdown() {
|
||||
// initDistributed validates distributed mode prerequisites and initializes
|
||||
// NATS, object storage, node registry, and instance identity.
|
||||
// Returns nil if distributed mode is not enabled.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
|
||||
// configLoader is used by the SmartRouter to compute concurrency-group
|
||||
// anti-affinity at placement time (#9659); it may be nil in tests.
|
||||
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) {
|
||||
if !cfg.Distributed.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -234,12 +236,17 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
|
||||
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
|
||||
|
||||
// All dependencies ready — build SmartRouter with all options at once
|
||||
var conflictResolver nodes.ConcurrencyConflictResolver
|
||||
if configLoader != nil {
|
||||
conflictResolver = configLoader
|
||||
}
|
||||
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
Unloader: remoteUnloader,
|
||||
FileStager: fileStager,
|
||||
GalleriesJSON: routerGalleriesJSON,
|
||||
AuthToken: routerAuthToken,
|
||||
DB: authDB,
|
||||
ConflictResolver: conflictResolver,
|
||||
})
|
||||
|
||||
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
|
||||
|
||||
@@ -139,7 +139,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
}
|
||||
|
||||
// Initialize distributed mode services (NATS, object storage, node registry)
|
||||
distSvc, err := initDistributed(options, application.authDB)
|
||||
distSvc, err := initDistributed(options, application.authDB, application.ModelConfigLoader())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
|
||||
}
|
||||
@@ -680,6 +680,12 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
|
||||
options.LRUEvictionRetryInterval,
|
||||
)
|
||||
|
||||
// Sync per-model state from configs to the watchdog. Without this,
|
||||
// `pinned: true` and `concurrency_groups:` are only honored after a
|
||||
// settings-driven RestartWatchdog and never at boot.
|
||||
application.SyncPinnedModelsToWatchdog()
|
||||
application.SyncModelGroupsToWatchdog()
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
// But memory reclaimer needs the Run() loop for periodic checking
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
@@ -26,6 +27,40 @@ func (a *Application) SyncPinnedModelsToWatchdog() {
|
||||
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
|
||||
}
|
||||
|
||||
// SyncModelGroupsToWatchdog reads concurrency_groups from all model configs and
|
||||
// updates the watchdog so EnforceGroupExclusivity has the current view.
|
||||
func (a *Application) SyncModelGroupsToWatchdog() {
|
||||
cl := a.ModelConfigLoader()
|
||||
if cl == nil {
|
||||
return
|
||||
}
|
||||
wd := a.modelLoader.GetWatchDog()
|
||||
if wd == nil {
|
||||
return
|
||||
}
|
||||
groups := extractModelGroupsFromConfigs(cl.GetAllModelsConfigs())
|
||||
wd.ReplaceModelGroups(groups)
|
||||
xlog.Debug("Synced concurrency groups to watchdog", "count", len(groups))
|
||||
}
|
||||
|
||||
// extractModelGroupsFromConfigs builds the model→groups map the watchdog
|
||||
// expects. Disabled models are skipped — their declared groups should not
|
||||
// block other models from loading.
|
||||
func extractModelGroupsFromConfigs(configs []config.ModelConfig) map[string][]string {
|
||||
out := make(map[string][]string)
|
||||
for _, cfg := range configs {
|
||||
if cfg.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
gs := cfg.GetConcurrencyGroups()
|
||||
if len(gs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[cfg.Name] = gs
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (a *Application) StopWatchdog() error {
|
||||
if a.watchdogStop != nil {
|
||||
close(a.watchdogStop)
|
||||
@@ -65,8 +100,9 @@ func (a *Application) startWatchdog() error {
|
||||
// Set the watchdog on the model loader
|
||||
a.modelLoader.SetWatchDog(wd)
|
||||
|
||||
// Sync pinned models from config to the watchdog
|
||||
// Sync pinned models and concurrency groups from config to the watchdog
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
a.SyncModelGroupsToWatchdog()
|
||||
|
||||
// Start watchdog goroutine if any periodic checks are enabled
|
||||
// LRU eviction doesn't need the Run() loop - it's triggered on model load
|
||||
@@ -148,8 +184,9 @@ func (a *Application) RestartWatchdog() error {
|
||||
newWD.RestoreState(oldState)
|
||||
}
|
||||
|
||||
// Re-sync pinned models after restart
|
||||
// Re-sync pinned models and concurrency groups after restart
|
||||
a.SyncPinnedModelsToWatchdog()
|
||||
a.SyncModelGroupsToWatchdog()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
47
core/application/watchdog_test.go
Normal file
47
core/application/watchdog_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("extractModelGroupsFromConfigs", func() {
|
||||
It("returns an empty map when no config declares groups", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a"},
|
||||
{Name: "b"},
|
||||
})
|
||||
Expect(out).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns each model's normalized groups", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a", ConcurrencyGroups: []string{" heavy ", "vision", "heavy"}},
|
||||
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
|
||||
{Name: "c"}, // no groups → omitted
|
||||
})
|
||||
Expect(out).To(HaveLen(2))
|
||||
Expect(out["a"]).To(Equal([]string{"heavy", "vision"}))
|
||||
Expect(out["b"]).To(Equal([]string{"heavy"}))
|
||||
Expect(out).ToNot(HaveKey("c"))
|
||||
})
|
||||
|
||||
It("omits models whose groups normalize to empty", func() {
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "blanks", ConcurrencyGroups: []string{"", " "}},
|
||||
})
|
||||
Expect(out).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips disabled models so they cannot block loading after re-enable", func() {
|
||||
disabled := true
|
||||
out := extractModelGroupsFromConfigs([]config.ModelConfig{
|
||||
{Name: "a", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled},
|
||||
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
|
||||
})
|
||||
Expect(out).To(HaveLen(1))
|
||||
Expect(out).To(HaveKey("b"))
|
||||
Expect(out).ToNot(HaveKey("a"))
|
||||
})
|
||||
})
|
||||
@@ -87,6 +87,11 @@ type ModelConfig struct {
|
||||
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
|
||||
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
|
||||
|
||||
// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
|
||||
// cannot be loaded alongside another model that shares any group name.
|
||||
// See docs/content/advanced/vram-management.md for usage.
|
||||
ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`
|
||||
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
|
||||
@@ -587,6 +592,28 @@ func (c *ModelConfig) IsPinned() bool {
|
||||
return c.Pinned != nil && *c.Pinned
|
||||
}
|
||||
|
||||
// GetConcurrencyGroups returns the model's concurrency groups, normalized:
|
||||
// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no
|
||||
// effective groups remain. The result is a fresh slice; the caller may
|
||||
// mutate it without affecting the config.
|
||||
func (c *ModelConfig) GetConcurrencyGroups() []string {
|
||||
if len(c.ConcurrencyGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(c.ConcurrencyGroups))
|
||||
for _, g := range c.ConcurrencyGroups {
|
||||
g = strings.TrimSpace(g)
|
||||
if g == "" || slices.Contains(out, g) {
|
||||
continue
|
||||
}
|
||||
out = append(out, g)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type ModelConfigUsecase int
|
||||
|
||||
const (
|
||||
|
||||
@@ -249,6 +249,40 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
|
||||
delete(bcl.configs, m)
|
||||
}
|
||||
|
||||
// GetModelsConflictingWith returns the names of every other configured (and
|
||||
// not-disabled) model that shares at least one concurrency group with the
|
||||
// named model. Returns nil if the named model has no groups, is unknown, or
|
||||
// has no peers in any of its groups. The result excludes the queried name.
|
||||
func (bcl *ModelConfigLoader) GetModelsConflictingWith(name string) []string {
|
||||
bcl.Lock()
|
||||
defer bcl.Unlock()
|
||||
target, ok := bcl.configs[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
targetGroups := target.GetConcurrencyGroups()
|
||||
if len(targetGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for n, cfg := range bcl.configs {
|
||||
if n == name || cfg.IsDisabled() {
|
||||
continue
|
||||
}
|
||||
other := cfg.GetConcurrencyGroups()
|
||||
if len(other) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, g := range targetGroups {
|
||||
if slices.Contains(other, g) {
|
||||
conflicts = append(conflicts, n)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// UpdateModelConfig updates an existing model config in the loader.
|
||||
// This is useful for updating runtime-detected properties like thinking support.
|
||||
func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {
|
||||
|
||||
63
core/config/model_config_loader_test.go
Normal file
63
core/config/model_config_loader_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ModelConfigLoader.GetModelsConflictingWith", func() {
|
||||
var bcl *ModelConfigLoader
|
||||
|
||||
BeforeEach(func() {
|
||||
bcl = NewModelConfigLoader("/tmp/conflict-test-models")
|
||||
})
|
||||
|
||||
insert := func(cfg ModelConfig) {
|
||||
bcl.Lock()
|
||||
bcl.configs[cfg.Name] = cfg
|
||||
bcl.Unlock()
|
||||
}
|
||||
|
||||
It("returns nil when the named model has no groups", func() {
|
||||
insert(ModelConfig{Name: "loner"})
|
||||
Expect(bcl.GetModelsConflictingWith("loner")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when the named model is unknown", func() {
|
||||
Expect(bcl.GetModelsConflictingWith("ghost")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when no other model shares a group", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"vision"}})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns models that share at least one group", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "c", ConcurrencyGroups: []string{"vision"}})
|
||||
insert(ModelConfig{Name: "d", ConcurrencyGroups: []string{"heavy", "vision"}})
|
||||
|
||||
conflicts := bcl.GetModelsConflictingWith("a")
|
||||
Expect(conflicts).To(ConsistOf("b", "d"))
|
||||
})
|
||||
|
||||
It("never lists the queried model itself", func() {
|
||||
insert(ModelConfig{Name: "self", ConcurrencyGroups: []string{"heavy"}})
|
||||
Expect(bcl.GetModelsConflictingWith("self")).To(BeNil())
|
||||
})
|
||||
|
||||
It("ignores disabled conflicting models", func() {
|
||||
disabled := true
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
|
||||
})
|
||||
|
||||
It("normalizes groups so whitespace and duplicates do not break overlap", func() {
|
||||
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{" heavy "}})
|
||||
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy", "heavy"}})
|
||||
Expect(bcl.GetModelsConflictingWith("a")).To(ConsistOf("b"))
|
||||
})
|
||||
})
|
||||
@@ -264,4 +264,53 @@ mcp:
|
||||
Expect(err).To(BeNil())
|
||||
Expect(valid).To(BeTrue())
|
||||
})
|
||||
// Specs for ModelConfig.GetConcurrencyGroups and for reading the
// concurrency_groups key from a YAML model config file.
Context("ConcurrencyGroups", func() {
|
||||
It("returns nil when no groups are configured", func() {
|
||||
cfg := &ModelConfig{Name: "no-groups"}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
||||
})
|
||||
It("returns nil when all entries are blank", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "blanks",
|
||||
ConcurrencyGroups: []string{"", " ", "\t"},
|
||||
}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
|
||||
})
|
||||
It("trims whitespace, drops empty entries, and dedupes", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "messy",
|
||||
ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "},
|
||||
}
|
||||
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"}))
|
||||
})
|
||||
It("returns a defensive copy", func() {
|
||||
cfg := &ModelConfig{
|
||||
Name: "copy",
|
||||
ConcurrencyGroups: []string{"heavy"},
|
||||
}
|
||||
got := cfg.GetConcurrencyGroups()
|
||||
// Mutate the returned slice; a second call must still see the original value.
got[0] = "tampered"
|
||||
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"}))
|
||||
})
|
||||
It("parses concurrency_groups from YAML", func() {
|
||||
tmp, err := os.CreateTemp("", "concgroups.yaml")
|
||||
Expect(err).To(BeNil())
|
||||
// The removal error is discarded explicitly; the file is only a test fixture.
defer func() { _ = os.Remove(tmp.Name()) }()
|
||||
// The raw string that follows is the YAML document written to the temp file.
// NOTE(review): tmp is never closed in this spec before the file is re-read
// by name below — confirm whether a tmp.Close() is wanted here.
_, err = tmp.WriteString(
|
||||
`name: heavy-a
|
||||
backend: llama-cpp
|
||||
parameters:
|
||||
model: heavy-a.gguf
|
||||
concurrency_groups:
|
||||
- vram-heavy
|
||||
- "120b"
|
||||
`)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
configs, err := readModelConfigsFromFile(tmp.Name())
|
||||
Expect(err).To(BeNil())
|
||||
Expect(configs).To(HaveLen(1))
|
||||
// Both the raw field and the normalized accessor must report the two groups.
Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"}))
|
||||
Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -35,6 +35,15 @@ type ModelRouter interface {
|
||||
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
|
||||
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
|
||||
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
|
||||
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
|
||||
}
|
||||
|
||||
// ConcurrencyConflictResolver is the narrow view of model configuration that
// the SmartRouter needs for concurrency-group-aware placement.
//
// *config.ModelConfigLoader satisfies it; depending on this interface rather
// than on the loader keeps the nodes package from importing the config
// package's full surface.
type ConcurrencyConflictResolver interface {
	// GetModelsConflictingWith returns the names of the configured models
	// that share at least one concurrency group with modelName.
	GetModelsConflictingWith(modelName string) []string
}
|
||||
|
||||
// NodeHealthStore is used by HealthMonitor for node status management.
|
||||
|
||||
@@ -115,6 +115,9 @@ func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Con
|
||||
func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Compile-time check
|
||||
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)
|
||||
|
||||
@@ -37,18 +37,24 @@ type SmartRouterOptions struct {
|
||||
AuthToken string
|
||||
ClientFactory BackendClientFactory // optional; defaults to tokenClientFactory
|
||||
DB *gorm.DB // for advisory locks during routing
|
||||
// ConflictResolver, when set, lets the scheduler narrow placement
|
||||
// candidates by per-model concurrency_groups (#9659). When nil, group
|
||||
// anti-affinity is disabled at the scheduler layer; the per-node
|
||||
// watchdog still enforces the rule on arrival.
|
||||
ConflictResolver ConcurrencyConflictResolver
|
||||
}
|
||||
|
||||
// SmartRouter routes inference requests to the best available backend node.
|
||||
// It uses the ModelRouter interface (backed by NodeRegistry in production) for routing decisions.
|
||||
type SmartRouter struct {
|
||||
registry ModelRouter
|
||||
unloader NodeCommandSender // optional, for NATS-driven load/unload
|
||||
fileStager FileStager // optional, for distributed file transfer
|
||||
galleriesJSON string // backend gallery config for dynamic installation
|
||||
clientFactory BackendClientFactory // creates gRPC backend clients
|
||||
db *gorm.DB // for advisory locks during routing
|
||||
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
|
||||
registry ModelRouter
|
||||
unloader NodeCommandSender // optional, for NATS-driven load/unload
|
||||
fileStager FileStager // optional, for distributed file transfer
|
||||
galleriesJSON string // backend gallery config for dynamic installation
|
||||
clientFactory BackendClientFactory // creates gRPC backend clients
|
||||
db *gorm.DB // for advisory locks during routing
|
||||
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
|
||||
conflictResolver ConcurrencyConflictResolver
|
||||
}
|
||||
|
||||
// NewSmartRouter creates a new SmartRouter backed by the given ModelRouter.
|
||||
@@ -59,13 +65,14 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
|
||||
factory = &tokenClientFactory{token: opts.AuthToken}
|
||||
}
|
||||
return &SmartRouter{
|
||||
registry: registry,
|
||||
unloader: opts.Unloader,
|
||||
fileStager: opts.FileStager,
|
||||
galleriesJSON: opts.GalleriesJSON,
|
||||
clientFactory: factory,
|
||||
db: opts.DB,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
registry: registry,
|
||||
unloader: opts.Unloader,
|
||||
fileStager: opts.FileStager,
|
||||
galleriesJSON: opts.GalleriesJSON,
|
||||
clientFactory: factory,
|
||||
db: opts.DB,
|
||||
stagingTracker: NewStagingTracker(),
|
||||
conflictResolver: opts.ConflictResolver,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -382,6 +389,60 @@ func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID str
|
||||
return extractNodeIDs(candidates), nil
|
||||
}
|
||||
|
||||
// narrowByGroupAntiAffinity removes candidate nodes that already host a model
|
||||
// declared as concurrent-conflicting with modelID via concurrency_groups
|
||||
// (#9659). This is a soft filter: when *every* candidate would be excluded,
|
||||
// the original set is returned and the per-node watchdog evicts on arrival.
|
||||
//
|
||||
// candidates may be nil ("any healthy node" — registry helpers treat nil as
|
||||
// no filter). nil is returned unchanged: hard-narrowing the implicit "all
|
||||
// nodes" set would silently exclude every node we know nothing about.
|
||||
func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID string, candidates []string) ([]string, error) {
|
||||
if r.conflictResolver == nil || candidates == nil {
|
||||
return candidates, nil
|
||||
}
|
||||
conflicts := r.conflictResolver.GetModelsConflictingWith(modelID)
|
||||
if len(conflicts) == 0 {
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
excluded := make(map[string]struct{})
|
||||
for _, name := range conflicts {
|
||||
nodes, err := r.registry.FindNodesWithModel(ctx, name)
|
||||
if err != nil {
|
||||
// Best-effort: a single lookup failure shouldn't fail placement.
|
||||
// Log and move on — the watchdog still enforces the rule on arrival.
|
||||
xlog.Warn("Group anti-affinity: lookup failed, skipping", "model", name, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, n := range nodes {
|
||||
excluded[n.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(excluded) == 0 {
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
narrowed := candidates[:0:0]
|
||||
for _, id := range candidates {
|
||||
if _, bad := excluded[id]; bad {
|
||||
continue
|
||||
}
|
||||
narrowed = append(narrowed, id)
|
||||
}
|
||||
if len(narrowed) == 0 {
|
||||
// Soft fallback: every candidate has a conflict. Return the original
|
||||
// set and let the per-node watchdog evict on arrival rather than
|
||||
// failing the request.
|
||||
xlog.Debug("Group anti-affinity: all candidates conflict, falling back to original set",
|
||||
"model", modelID, "conflicts", conflicts)
|
||||
return candidates, nil
|
||||
}
|
||||
xlog.Debug("Group anti-affinity narrowed candidates",
|
||||
"model", modelID, "before", len(candidates), "after", len(narrowed))
|
||||
return narrowed, nil
|
||||
}
|
||||
|
||||
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
|
||||
// Returns true if no constraints exist or the node matches all selector labels.
|
||||
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
|
||||
@@ -438,6 +499,15 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
// Apply concurrency-group anti-affinity (#9659): prefer nodes that don't
|
||||
// already host a model declared exclusive with this one. Soft filter — if
|
||||
// every candidate has a conflict, the original set is returned and the
|
||||
// per-node watchdog evicts on arrival.
|
||||
candidateNodeIDs, err = r.narrowByGroupAntiAffinity(ctx, modelID, candidateNodeIDs)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
// Narrow candidates to nodes that still have a free replica slot for this
|
||||
// model. Without this filter, the scheduler would happily pick a node
|
||||
// already at capacity for this model (e.g. when MinReplicas > free
|
||||
|
||||
@@ -106,6 +106,10 @@ type fakeModelRouter struct {
|
||||
getNodeLabels []NodeLabel
|
||||
getNodeLabelsErr error
|
||||
|
||||
// FindNodesWithModel returns (keyed by model name)
|
||||
findNodesWithModelByName map[string][]BackendNode
|
||||
findNodesWithModelErr error
|
||||
|
||||
// Track calls for assertions
|
||||
decrementCalls []string // "nodeID:modelName"
|
||||
incrementCalls []string
|
||||
@@ -228,6 +232,25 @@ func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabe
|
||||
return f.getNodeLabels, f.getNodeLabelsErr
|
||||
}
|
||||
|
||||
func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) {
|
||||
if f.findNodesWithModelErr != nil {
|
||||
return nil, f.findNodesWithModelErr
|
||||
}
|
||||
return f.findNodesWithModelByName[modelName], nil
|
||||
}
|
||||
|
||||
// fakeConflictResolver is a ConcurrencyConflictResolver test double that
// answers from a fixed model-name → conflicting-model-names table.
type fakeConflictResolver struct {
	conflicts map[string][]string
}

// GetModelsConflictingWith returns the table entry for name. A nil receiver
// and an unknown name both yield nil.
func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string {
	if f != nil {
		return f.conflicts[name]
	}
	return nil
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fake BackendClientFactory + Backend
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -847,4 +870,84 @@ var _ = Describe("SmartRouter", func() {
|
||||
Expect(original.MMProj).To(Equal(origMMProj))
|
||||
})
|
||||
})
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// narrowByGroupAntiAffinity
|
||||
// -----------------------------------------------------------------------
|
||||
// Specs for SmartRouter.narrowByGroupAntiAffinity. Each spec wires a fake
// registry (which nodes host which model) and a fake conflict resolver (which
// models conflict) into a fresh router.
Describe("narrowByGroupAntiAffinity", func() {
	var (
		ctx       context.Context
		registry  *fakeModelRouter
		conflicts *fakeConflictResolver
		sut       *SmartRouter
	)

	BeforeEach(func() {
		ctx = context.Background()
		registry = &fakeModelRouter{}
		conflicts = &fakeConflictResolver{conflicts: map[string][]string{}}
		sut = NewSmartRouter(registry, SmartRouterOptions{
			ConflictResolver: conflicts,
		})
	})

	It("returns the input set unchanged when the model has no conflicts", func() {
		input := []string{"n1", "n2", "n3"}

		narrowed, err := sut.narrowByGroupAntiAffinity(ctx, "lonely", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})

	It("removes nodes that already host a conflicting model", func() {
		conflicts.conflicts["b"] = []string{"a"}
		registry.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
		}

		narrowed, err := sut.narrowByGroupAntiAffinity(ctx, "b", []string{"n1", "n2"})

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(ConsistOf("n2"))
	})

	It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() {
		conflicts.conflicts["b"] = []string{"a"}
		registry.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}
		input := []string{"n1", "n2"}

		narrowed, err := sut.narrowByGroupAntiAffinity(ctx, "b", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})

	It("removes nodes hosting any of multiple conflicting models", func() {
		conflicts.conflicts["c"] = []string{"a", "b"}
		registry.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
			"b": {{ID: "n2"}},
		}

		narrowed, err := sut.narrowByGroupAntiAffinity(ctx, "c", []string{"n1", "n2", "n3"})

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(ConsistOf("n3"))
	})

	It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() {
		conflicts.conflicts["b"] = []string{"a"}
		registry.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}

		narrowed, err := sut.narrowByGroupAntiAffinity(ctx, "b", nil)

		Expect(err).ToNot(HaveOccurred())
		// nil in, nil out: the caller's "any healthy node" meaning is kept.
		// Hard-narrowing nil would silently exclude every other node.
		Expect(narrowed).To(BeNil())
	})

	It("is a no-op when no resolver is configured", func() {
		withoutResolver := NewSmartRouter(registry, SmartRouterOptions{})
		input := []string{"n1", "n2"}

		narrowed, err := withoutResolver.narrowByGroupAntiAffinity(ctx, "b", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user