feat(concurrency-groups): per-model exclusive groups for backend loading (#9662)

* feat(concurrency-groups): per-model exclusive groups for backend loading

Adds `concurrency_groups: [...]` to model YAML configs. Two models that share
a group cannot be loaded concurrently on the same node — loading one evicts
the others, reusing the existing pinned/busy/retry policy from LRU eviction.

Layered design:
- Watchdog (pkg/model): per-node correctness floor — on every Load(), evict
  any loaded model that shares a group with the requested one. Pinned skips
  surface NeedMore so the loader retries (and ultimately logs a clear
  warning), instead of silently allowing the rule to be violated.
- Distributed scheduler (core/services/nodes): soft anti-affinity hint —
  scheduleNewModel prefers nodes that don't already host a same-group
  model, falling back to eviction only if every candidate has a conflict.
  Composes with NodeSelector at the same point in the candidate pipeline.

Per-node, not cluster-wide: VRAM is a node-local resource, and two heavy
models running on different nodes is fine. The ConfigLoader is wired into
SmartRouter via a small ConcurrencyConflictResolver interface so the nodes
package keeps a narrow surface on core/config.

Refactors the inner LRU eviction body into a shared collectEvictionsLocked
helper and the loader retry loop into retryEnforce(fn, maxRetries, interval),
so both LRU and group enforcement share busy/pinned/retry semantics.

Closes #9659.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(watchdog): sync pinned + concurrency_groups at startup

The startup-time watchdog setup lives in initializeWatchdog (startup.go),
not in startWatchdog (watchdog.go). The latter is only invoked from the
runtime-settings RestartWatchdog path. As a result, neither
SyncPinnedModelsToWatchdog nor SyncModelGroupsToWatchdog ran at boot,
so `pinned: true` and `concurrency_groups: [...]` only became effective
after a settings-driven watchdog restart.

Fix by adding both sync calls to initializeWatchdog. Confirmed end-to-end:
loading model A in group "heavy", then C with no group (coexists),
then B in group "heavy" now correctly evicts A and leaves [B, C].

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(test): satisfy errcheck on new os.Remove in concurrency_groups spec

CI lint runs new-from-merge-base, so the existing pre-existing
`defer os.Remove(tmp.Name())` lines are baseline-grandfathered but the
one introduced by the concurrency_groups YAML round-trip test is held
to errcheck. Wrap the remove in a closure that discards the error.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-05 08:42:50 +02:00
committed by GitHub
parent 22ae415695
commit bbcaebc1ef
17 changed files with 981 additions and 76 deletions

View File

@@ -71,7 +71,9 @@ func (ds *DistributedServices) Shutdown() {
// initDistributed validates distributed mode prerequisites and initializes
// NATS, object storage, node registry, and instance identity.
// Returns nil if distributed mode is not enabled.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
// configLoader is used by the SmartRouter to compute concurrency-group
// anti-affinity at placement time (#9659); it may be nil in tests.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) {
if !cfg.Distributed.Enabled {
return nil, nil
}
@@ -234,12 +236,17 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)
// All dependencies ready — build SmartRouter with all options at once
var conflictResolver nodes.ConcurrencyConflictResolver
if configLoader != nil {
conflictResolver = configLoader
}
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
Unloader: remoteUnloader,
FileStager: fileStager,
GalleriesJSON: routerGalleriesJSON,
AuthToken: routerAuthToken,
DB: authDB,
Unloader: remoteUnloader,
FileStager: fileStager,
GalleriesJSON: routerGalleriesJSON,
AuthToken: routerAuthToken,
DB: authDB,
ConflictResolver: conflictResolver,
})
// Create ReplicaReconciler for auto-scaling model replicas. Adapter +

View File

@@ -139,7 +139,7 @@ func New(opts ...config.AppOption) (*Application, error) {
}
// Initialize distributed mode services (NATS, object storage, node registry)
distSvc, err := initDistributed(options, application.authDB)
distSvc, err := initDistributed(options, application.authDB, application.ModelConfigLoader())
if err != nil {
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
}
@@ -680,6 +680,12 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
options.LRUEvictionRetryInterval,
)
// Sync per-model state from configs to the watchdog. Without this,
// `pinned: true` and `concurrency_groups:` are only honored after a
// settings-driven RestartWatchdog and never at boot.
application.SyncPinnedModelsToWatchdog()
application.SyncModelGroupsToWatchdog()
// Start watchdog goroutine if any periodic checks are enabled
// LRU eviction doesn't need the Run() loop - it's triggered on model load
// But memory reclaimer needs the Run() loop for periodic checking

View File

@@ -1,6 +1,7 @@
package application
import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
@@ -26,6 +27,40 @@ func (a *Application) SyncPinnedModelsToWatchdog() {
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
}
// SyncModelGroupsToWatchdog reads concurrency_groups from all model configs and
// updates the watchdog so EnforceGroupExclusivity has the current view.
func (a *Application) SyncModelGroupsToWatchdog() {
cl := a.ModelConfigLoader()
if cl == nil {
return
}
wd := a.modelLoader.GetWatchDog()
if wd == nil {
return
}
groups := extractModelGroupsFromConfigs(cl.GetAllModelsConfigs())
wd.ReplaceModelGroups(groups)
xlog.Debug("Synced concurrency groups to watchdog", "count", len(groups))
}
// extractModelGroupsFromConfigs builds the model→groups map the watchdog
// expects. Disabled models are skipped — their declared groups should not
// block other models from loading.
func extractModelGroupsFromConfigs(configs []config.ModelConfig) map[string][]string {
out := make(map[string][]string)
for _, cfg := range configs {
if cfg.IsDisabled() {
continue
}
gs := cfg.GetConcurrencyGroups()
if len(gs) == 0 {
continue
}
out[cfg.Name] = gs
}
return out
}
func (a *Application) StopWatchdog() error {
if a.watchdogStop != nil {
close(a.watchdogStop)
@@ -65,8 +100,9 @@ func (a *Application) startWatchdog() error {
// Set the watchdog on the model loader
a.modelLoader.SetWatchDog(wd)
// Sync pinned models from config to the watchdog
// Sync pinned models and concurrency groups from config to the watchdog
a.SyncPinnedModelsToWatchdog()
a.SyncModelGroupsToWatchdog()
// Start watchdog goroutine if any periodic checks are enabled
// LRU eviction doesn't need the Run() loop - it's triggered on model load
@@ -148,8 +184,9 @@ func (a *Application) RestartWatchdog() error {
newWD.RestoreState(oldState)
}
// Re-sync pinned models after restart
// Re-sync pinned models and concurrency groups after restart
a.SyncPinnedModelsToWatchdog()
a.SyncModelGroupsToWatchdog()
return nil
}

View File

@@ -0,0 +1,47 @@
package application
import (
"github.com/mudler/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("extractModelGroupsFromConfigs", func() {
It("returns an empty map when no config declares groups", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a"},
{Name: "b"},
})
Expect(out).To(BeEmpty())
})
It("returns each model's normalized groups", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a", ConcurrencyGroups: []string{" heavy ", "vision", "heavy"}},
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
{Name: "c"}, // no groups → omitted
})
Expect(out).To(HaveLen(2))
Expect(out["a"]).To(Equal([]string{"heavy", "vision"}))
Expect(out["b"]).To(Equal([]string{"heavy"}))
Expect(out).ToNot(HaveKey("c"))
})
It("omits models whose groups normalize to empty", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "blanks", ConcurrencyGroups: []string{"", " "}},
})
Expect(out).To(BeEmpty())
})
It("skips disabled models so they cannot block loading after re-enable", func() {
disabled := true
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled},
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
})
Expect(out).To(HaveLen(1))
Expect(out).To(HaveKey("b"))
Expect(out).ToNot(HaveKey("a"))
})
})

View File

@@ -87,6 +87,11 @@ type ModelConfig struct {
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`
// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
// cannot be loaded alongside another model that shares any group name.
// See docs/content/advanced/vram-management.md for usage.
ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
@@ -587,6 +592,28 @@ func (c *ModelConfig) IsPinned() bool {
return c.Pinned != nil && *c.Pinned
}
// GetConcurrencyGroups returns the model's concurrency groups, normalized:
// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no
// effective groups remain. The result is a fresh slice; the caller may
// mutate it without affecting the config.
func (c *ModelConfig) GetConcurrencyGroups() []string {
if len(c.ConcurrencyGroups) == 0 {
return nil
}
out := make([]string, 0, len(c.ConcurrencyGroups))
for _, g := range c.ConcurrencyGroups {
g = strings.TrimSpace(g)
if g == "" || slices.Contains(out, g) {
continue
}
out = append(out, g)
}
if len(out) == 0 {
return nil
}
return out
}
type ModelConfigUsecase int
const (

View File

@@ -249,6 +249,40 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
delete(bcl.configs, m)
}
// GetModelsConflictingWith returns the names of every other configured (and
// not-disabled) model that shares at least one concurrency group with the
// named model. Returns nil if the named model has no groups, is unknown, or
// has no peers in any of its groups. The result excludes the queried name.
func (bcl *ModelConfigLoader) GetModelsConflictingWith(name string) []string {
bcl.Lock()
defer bcl.Unlock()
target, ok := bcl.configs[name]
if !ok {
return nil
}
targetGroups := target.GetConcurrencyGroups()
if len(targetGroups) == 0 {
return nil
}
var conflicts []string
for n, cfg := range bcl.configs {
if n == name || cfg.IsDisabled() {
continue
}
other := cfg.GetConcurrencyGroups()
if len(other) == 0 {
continue
}
for _, g := range targetGroups {
if slices.Contains(other, g) {
conflicts = append(conflicts, n)
break
}
}
}
return conflicts
}
// UpdateModelConfig updates an existing model config in the loader.
// This is useful for updating runtime-detected properties like thinking support.
func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {

View File

@@ -0,0 +1,63 @@
package config
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ModelConfigLoader.GetModelsConflictingWith", func() {
var bcl *ModelConfigLoader
BeforeEach(func() {
bcl = NewModelConfigLoader("/tmp/conflict-test-models")
})
insert := func(cfg ModelConfig) {
bcl.Lock()
bcl.configs[cfg.Name] = cfg
bcl.Unlock()
}
It("returns nil when the named model has no groups", func() {
insert(ModelConfig{Name: "loner"})
Expect(bcl.GetModelsConflictingWith("loner")).To(BeNil())
})
It("returns nil when the named model is unknown", func() {
Expect(bcl.GetModelsConflictingWith("ghost")).To(BeNil())
})
It("returns nil when no other model shares a group", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"vision"}})
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
})
It("returns models that share at least one group", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "c", ConcurrencyGroups: []string{"vision"}})
insert(ModelConfig{Name: "d", ConcurrencyGroups: []string{"heavy", "vision"}})
conflicts := bcl.GetModelsConflictingWith("a")
Expect(conflicts).To(ConsistOf("b", "d"))
})
It("never lists the queried model itself", func() {
insert(ModelConfig{Name: "self", ConcurrencyGroups: []string{"heavy"}})
Expect(bcl.GetModelsConflictingWith("self")).To(BeNil())
})
It("ignores disabled conflicting models", func() {
disabled := true
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled})
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
})
It("normalizes groups so whitespace and duplicates do not break overlap", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{" heavy "}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy", "heavy"}})
Expect(bcl.GetModelsConflictingWith("a")).To(ConsistOf("b"))
})
})

View File

@@ -264,4 +264,53 @@ mcp:
Expect(err).To(BeNil())
Expect(valid).To(BeTrue())
})
Context("ConcurrencyGroups", func() {
It("returns nil when no groups are configured", func() {
cfg := &ModelConfig{Name: "no-groups"}
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
})
It("returns nil when all entries are blank", func() {
cfg := &ModelConfig{
Name: "blanks",
ConcurrencyGroups: []string{"", " ", "\t"},
}
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
})
It("trims whitespace, drops empty entries, and dedupes", func() {
cfg := &ModelConfig{
Name: "messy",
ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "},
}
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"}))
})
It("returns a defensive copy", func() {
cfg := &ModelConfig{
Name: "copy",
ConcurrencyGroups: []string{"heavy"},
}
got := cfg.GetConcurrencyGroups()
got[0] = "tampered"
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"}))
})
It("parses concurrency_groups from YAML", func() {
tmp, err := os.CreateTemp("", "concgroups.yaml")
Expect(err).To(BeNil())
defer func() { _ = os.Remove(tmp.Name()) }()
_, err = tmp.WriteString(
`name: heavy-a
backend: llama-cpp
parameters:
model: heavy-a.gguf
concurrency_groups:
- vram-heavy
- "120b"
`)
Expect(err).ToNot(HaveOccurred())
configs, err := readModelConfigsFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(configs).To(HaveLen(1))
Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"}))
Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"}))
})
})
})

View File

@@ -35,6 +35,15 @@ type ModelRouter interface {
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
}
// ConcurrencyConflictResolver returns the names of configured models that
// share at least one concurrency group with the given model. It is satisfied
// by *config.ModelConfigLoader and lets the SmartRouter make group-aware
// placement decisions without importing the config package's full surface.
type ConcurrencyConflictResolver interface {
GetModelsConflictingWith(modelName string) []string
}
// NodeHealthStore is used by HealthMonitor for node status management.

View File

@@ -115,6 +115,9 @@ func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Con
func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
return nil, nil
}
func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) {
return nil, nil
}
// Compile-time check
var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil)

View File

@@ -37,18 +37,24 @@ type SmartRouterOptions struct {
AuthToken string
ClientFactory BackendClientFactory // optional; defaults to tokenClientFactory
DB *gorm.DB // for advisory locks during routing
// ConflictResolver, when set, lets the scheduler narrow placement
// candidates by per-model concurrency_groups (#9659). When nil, group
// anti-affinity is disabled at the scheduler layer; the per-node
// watchdog still enforces the rule on arrival.
ConflictResolver ConcurrencyConflictResolver
}
// SmartRouter routes inference requests to the best available backend node.
// It uses the ModelRouter interface (backed by NodeRegistry in production) for routing decisions.
type SmartRouter struct {
registry ModelRouter
unloader NodeCommandSender // optional, for NATS-driven load/unload
fileStager FileStager // optional, for distributed file transfer
galleriesJSON string // backend gallery config for dynamic installation
clientFactory BackendClientFactory // creates gRPC backend clients
db *gorm.DB // for advisory locks during routing
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
registry ModelRouter
unloader NodeCommandSender // optional, for NATS-driven load/unload
fileStager FileStager // optional, for distributed file transfer
galleriesJSON string // backend gallery config for dynamic installation
clientFactory BackendClientFactory // creates gRPC backend clients
db *gorm.DB // for advisory locks during routing
stagingTracker *StagingTracker // tracks file staging progress for UI visibility
conflictResolver ConcurrencyConflictResolver
}
// NewSmartRouter creates a new SmartRouter backed by the given ModelRouter.
@@ -59,13 +65,14 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
factory = &tokenClientFactory{token: opts.AuthToken}
}
return &SmartRouter{
registry: registry,
unloader: opts.Unloader,
fileStager: opts.FileStager,
galleriesJSON: opts.GalleriesJSON,
clientFactory: factory,
db: opts.DB,
stagingTracker: NewStagingTracker(),
registry: registry,
unloader: opts.Unloader,
fileStager: opts.FileStager,
galleriesJSON: opts.GalleriesJSON,
clientFactory: factory,
db: opts.DB,
stagingTracker: NewStagingTracker(),
conflictResolver: opts.ConflictResolver,
}
}
@@ -382,6 +389,60 @@ func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID str
return extractNodeIDs(candidates), nil
}
// narrowByGroupAntiAffinity removes candidate nodes that already host a model
// declared as concurrent-conflicting with modelID via concurrency_groups
// (#9659). This is a soft filter: when *every* candidate would be excluded,
// the original set is returned and the per-node watchdog evicts on arrival.
//
// candidates may be nil ("any healthy node" — registry helpers treat nil as
// no filter). nil is returned unchanged: hard-narrowing the implicit "all
// nodes" set would silently exclude every node we know nothing about.
func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID string, candidates []string) ([]string, error) {
if r.conflictResolver == nil || candidates == nil {
return candidates, nil
}
conflicts := r.conflictResolver.GetModelsConflictingWith(modelID)
if len(conflicts) == 0 {
return candidates, nil
}
excluded := make(map[string]struct{})
for _, name := range conflicts {
nodes, err := r.registry.FindNodesWithModel(ctx, name)
if err != nil {
// Best-effort: a single lookup failure shouldn't fail placement.
// Log and move on — the watchdog still enforces the rule on arrival.
xlog.Warn("Group anti-affinity: lookup failed, skipping", "model", name, "error", err)
continue
}
for _, n := range nodes {
excluded[n.ID] = struct{}{}
}
}
if len(excluded) == 0 {
return candidates, nil
}
narrowed := candidates[:0:0]
for _, id := range candidates {
if _, bad := excluded[id]; bad {
continue
}
narrowed = append(narrowed, id)
}
if len(narrowed) == 0 {
// Soft fallback: every candidate has a conflict. Return the original
// set and let the per-node watchdog evict on arrival rather than
// failing the request.
xlog.Debug("Group anti-affinity: all candidates conflict, falling back to original set",
"model", modelID, "conflicts", conflicts)
return candidates, nil
}
xlog.Debug("Group anti-affinity narrowed candidates",
"model", modelID, "before", len(candidates), "after", len(narrowed))
return narrowed, nil
}
// nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model.
// Returns true if no constraints exist or the node matches all selector labels.
func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool {
@@ -438,6 +499,15 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
return nil, "", 0, err
}
// Apply concurrency-group anti-affinity (#9659): prefer nodes that don't
// already host a model declared exclusive with this one. Soft filter — if
// every candidate has a conflict, the original set is returned and the
// per-node watchdog evicts on arrival.
candidateNodeIDs, err = r.narrowByGroupAntiAffinity(ctx, modelID, candidateNodeIDs)
if err != nil {
return nil, "", 0, err
}
// Narrow candidates to nodes that still have a free replica slot for this
// model. Without this filter, the scheduler would happily pick a node
// already at capacity for this model (e.g. when MinReplicas > free

View File

@@ -106,6 +106,10 @@ type fakeModelRouter struct {
getNodeLabels []NodeLabel
getNodeLabelsErr error
// FindNodesWithModel returns (keyed by model name)
findNodesWithModelByName map[string][]BackendNode
findNodesWithModelErr error
// Track calls for assertions
decrementCalls []string // "nodeID:modelName"
incrementCalls []string
@@ -228,6 +232,25 @@ func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabe
return f.getNodeLabels, f.getNodeLabelsErr
}
func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) {
if f.findNodesWithModelErr != nil {
return nil, f.findNodesWithModelErr
}
return f.findNodesWithModelByName[modelName], nil
}
// fakeConflictResolver implements ConcurrencyConflictResolver from a static map.
type fakeConflictResolver struct {
conflicts map[string][]string
}
func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string {
if f == nil {
return nil
}
return f.conflicts[name]
}
// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------
@@ -847,4 +870,84 @@ var _ = Describe("SmartRouter", func() {
Expect(original.MMProj).To(Equal(origMMProj))
})
})
// -----------------------------------------------------------------------
// narrowByGroupAntiAffinity
// -----------------------------------------------------------------------
Describe("narrowByGroupAntiAffinity", func() {
var (
reg *fakeModelRouter
resolver *fakeConflictResolver
router *SmartRouter
ctx context.Context
)
BeforeEach(func() {
reg = &fakeModelRouter{}
resolver = &fakeConflictResolver{conflicts: map[string][]string{}}
router = NewSmartRouter(reg, SmartRouterOptions{
ConflictResolver: resolver,
})
ctx = context.Background()
})
It("returns the input set unchanged when the model has no conflicts", func() {
candidates := []string{"n1", "n2", "n3"}
out, err := router.narrowByGroupAntiAffinity(ctx, "lonely", candidates)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal(candidates))
})
It("removes nodes that already host a conflicting model", func() {
resolver.conflicts["b"] = []string{"a"}
reg.findNodesWithModelByName = map[string][]BackendNode{
"a": {{ID: "n1"}},
}
out, err := router.narrowByGroupAntiAffinity(ctx, "b", []string{"n1", "n2"})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(ConsistOf("n2"))
})
It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() {
resolver.conflicts["b"] = []string{"a"}
reg.findNodesWithModelByName = map[string][]BackendNode{
"a": {{ID: "n1"}, {ID: "n2"}},
}
candidates := []string{"n1", "n2"}
out, err := router.narrowByGroupAntiAffinity(ctx, "b", candidates)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal(candidates))
})
It("removes nodes hosting any of multiple conflicting models", func() {
resolver.conflicts["c"] = []string{"a", "b"}
reg.findNodesWithModelByName = map[string][]BackendNode{
"a": {{ID: "n1"}},
"b": {{ID: "n2"}},
}
out, err := router.narrowByGroupAntiAffinity(ctx, "c", []string{"n1", "n2", "n3"})
Expect(err).ToNot(HaveOccurred())
Expect(out).To(ConsistOf("n3"))
})
It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() {
resolver.conflicts["b"] = []string{"a"}
reg.findNodesWithModelByName = map[string][]BackendNode{
"a": {{ID: "n1"}, {ID: "n2"}},
}
out, err := router.narrowByGroupAntiAffinity(ctx, "b", nil)
Expect(err).ToNot(HaveOccurred())
// nil in → nil out: caller's "any healthy node" semantics preserved.
// Hard-narrowing nil would silently exclude every other node.
Expect(out).To(BeNil())
})
It("is a no-op when no resolver is configured", func() {
plain := NewSmartRouter(reg, SmartRouterOptions{})
candidates := []string{"n1", "n2"}
out, err := plain.narrowByGroupAntiAffinity(ctx, "b", candidates)
Expect(err).ToNot(HaveOccurred())
Expect(out).To(Equal(candidates))
})
})
})