Files
LocalAI/core/services/nodes/registry_test.go
LocalAI [bot] 7637f8cf1b feat(distributed): declarative per-model scheduling via env/args (#10308)
* feat(distributed): add SpreadAll column and authoritative scheduling seeding

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): parse declarative model scheduling config (env/file)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): reconcile spread_all to one replica per matching node

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): wire LOCALAI_MODEL_SCHEDULING env/args and startup seeding

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): expose spread_all on the scheduling API endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): add spread-to-all-nodes mode to the scheduling UI

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): document LOCALAI_MODEL_SCHEDULING env/args

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): clarify replica modes and all-nodes spread in scheduling config

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-13 18:31:06 +02:00

1548 lines
70 KiB
Go

package nodes
import (
"context"
"runtime"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/testutil"
"gorm.io/gorm"
)
var _ = Describe("NodeRegistry", func() {
var (
db *gorm.DB
registry *NodeRegistry
)
BeforeEach(func() {
if runtime.GOOS == "darwin" {
Skip("testcontainers requires Docker, not available on macOS CI")
}
db = testutil.SetupTestDB()
var err error
registry, err = NewNodeRegistry(db)
Expect(err).ToNot(HaveOccurred())
})
// makeNode builds an unregistered backend-type node whose available VRAM
// starts equal to its total VRAM. Status and ID are left zero-valued; the
// specs rely on Register to populate them.
makeNode := func(name, address string, vram uint64) *BackendNode {
	n := &BackendNode{}
	n.Name = name
	n.NodeType = NodeTypeBackend
	n.Address = address
	n.TotalVRAM = vram
	n.AvailableVRAM = vram
	return n
}
Describe("Register", func() {
	// Without auto-approval a new node waits in StatusPending, both on the
	// struct handed to Register and in the stored row.
	It("sets StatusPending when autoApprove is false", func() {
		ctx := context.Background()
		pending := makeNode("worker-1", "10.0.0.1:50051", 8_000_000_000)
		Expect(registry.Register(ctx, pending, false)).To(Succeed())
		Expect(pending.Status).To(Equal(StatusPending))

		stored, err := registry.GetByName(ctx, "worker-1")
		Expect(err).ToNot(HaveOccurred())
		Expect(stored.Status).To(Equal(StatusPending))
	})

	// With auto-approval the node is reported healthy straight away.
	It("sets StatusHealthy when autoApprove is true", func() {
		ctx := context.Background()
		approved := makeNode("worker-2", "10.0.0.2:50051", 4_000_000_000)
		Expect(registry.Register(ctx, approved, true)).To(Succeed())
		Expect(approved.Status).To(Equal(StatusHealthy))
	})
})
Describe("Re-registration", func() {
	It("keeps a pending node pending on re-register with autoApprove=false", func() {
		node := makeNode("re-pending", "10.0.0.3:50051", 4_000_000_000)
		Expect(registry.Register(context.Background(), node, false)).To(Succeed())
		Expect(node.Status).To(Equal(StatusPending))
		// Re-register same name, still no auto-approve
		node2 := makeNode("re-pending", "10.0.0.3:50052", 4_000_000_000)
		Expect(registry.Register(context.Background(), node2, false)).To(Succeed())
		Expect(node2.Status).To(Equal(StatusPending))
		// ID is preserved from original registration
		Expect(node2.ID).To(Equal(node.ID))
	})
	It("restores a previously approved node to healthy on re-register with autoApprove=false", func() {
		node := makeNode("re-approved", "10.0.0.4:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(node.Status).To(Equal(StatusHealthy))
		// Simulate the node becoming unhealthy
		Expect(registry.MarkUnhealthy(context.Background(), node.ID)).To(Succeed())
		fetched, err := registry.GetByName(context.Background(), "re-approved")
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.Status).To(Equal(StatusUnhealthy))
		// Re-register with autoApprove=false — should restore to healthy
		// because the node was previously approved (status != pending)
		node2 := makeNode("re-approved", "10.0.0.4:50052", 8_000_000_000)
		Expect(registry.Register(context.Background(), node2, false)).To(Succeed())
		Expect(node2.Status).To(Equal(StatusHealthy))
		// The re-registration must reuse the existing row (same ID, as in
		// the pending case above) and persist the restored status. Checking
		// only node2.Status would pass even if the stored row stayed
		// unhealthy.
		Expect(node2.ID).To(Equal(node.ID))
		fetched, err = registry.GetByName(context.Background(), "re-approved")
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.Status).To(Equal(StatusHealthy))
	})
})
Describe("ApproveNode", func() {
It("transitions a pending node to healthy", func() {
node := makeNode("approve-me", "10.0.0.5:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, false)).To(Succeed())
Expect(node.Status).To(Equal(StatusPending))
Expect(registry.ApproveNode(context.Background(), node.ID)).To(Succeed())
fetched, err := registry.Get(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(fetched.Status).To(Equal(StatusHealthy))
})
It("returns error for non-existent node ID", func() {
err := registry.ApproveNode(context.Background(), "non-existent-id")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("not found or not in pending status"))
})
It("returns error for an already-healthy node", func() {
node := makeNode("already-healthy", "10.0.0.6:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
err := registry.ApproveNode(context.Background(), node.ID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("not found or not in pending status"))
})
})
Describe("MarkOffline", func() {
	// Taking a node offline must flip its status and drop every model
	// record attached to it.
	It("sets status to offline and clears model records", func() {
		ctx := context.Background()
		worker := makeNode("offline-test", "10.0.0.7:50051", 8_000_000_000)
		Expect(registry.Register(ctx, worker, true)).To(Succeed())

		// Seed one loaded model so there is something to clear.
		Expect(registry.SetNodeModel(ctx, worker.ID, "llama-7b", 0, "loaded", "10.0.0.7:50052", 0)).To(Succeed())
		before, err := registry.GetNodeModels(ctx, worker.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(before).To(HaveLen(1))

		Expect(registry.MarkOffline(ctx, worker.ID)).To(Succeed())

		stored, err := registry.Get(ctx, worker.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(stored.Status).To(Equal(StatusOffline))

		after, err := registry.GetNodeModels(ctx, worker.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(after).To(BeEmpty())
	})

	It("returns error for non-existent node", func() {
		offlineErr := registry.MarkOffline(context.Background(), "does-not-exist")
		Expect(offlineErr).To(HaveOccurred())
		Expect(offlineErr.Error()).To(ContainSubstring("not found"))
	})
})
Describe("SetNodeModel ID stability", func() {
	// SetNodeModel on an existing (node, model, replica) must update the
	// row in place rather than replace it, so its primary key never changes.
	It("preserves the ID when called twice for the same node+model", func() {
		node := makeNode("stable-id-node", "10.0.0.99:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50052", 0)).To(Succeed())
		nm1, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
		Expect(err).ToNot(HaveOccurred())
		// Call again with the same state but a different address: the
		// existing row must be updated (new address) while keeping its ID.
		Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50053", 0)).To(Succeed())
		nm2, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
		Expect(err).ToNot(HaveOccurred())
		Expect(nm2.ID).To(Equal(nm1.ID), "ID should remain stable across SetNodeModel calls")
		Expect(nm2.Address).To(Equal("10.0.0.99:50053"), "Address should be updated")
	})
})
Describe("FindNodeWithVRAM", func() {
	It("selects the node with sufficient VRAM", func() {
		ctx := context.Background()
		for _, n := range []*BackendNode{
			makeNode("small-gpu", "10.0.0.10:50051", 4_000_000_000),
			makeNode("big-gpu", "10.0.0.11:50051", 16_000_000_000),
		} {
			Expect(registry.Register(ctx, n, true)).To(Succeed())
		}

		// 8 GB fits only on the 16 GB node.
		picked, err := registry.FindNodeWithVRAM(ctx, 8_000_000_000)
		Expect(err).ToNot(HaveOccurred())
		Expect(picked.Name).To(Equal("big-gpu"))
	})

	It("returns error when no node has enough VRAM", func() {
		ctx := context.Background()
		tiny := makeNode("tiny-gpu", "10.0.0.12:50051", 2_000_000_000)
		Expect(registry.Register(ctx, tiny, true)).To(Succeed())

		// 32 GB exceeds every registered node's capacity.
		_, findErr := registry.FindNodeWithVRAM(ctx, 32_000_000_000)
		Expect(findErr).To(HaveOccurred())
	})
})
Describe("FindIdleNode", func() {
	// A node is idle when no model record is attached to it.
	It("returns the node with no loaded models", func() {
		ctx := context.Background()
		busyNode := makeNode("busy-node", "10.0.0.20:50051", 8_000_000_000)
		idleNode := makeNode("idle-node", "10.0.0.21:50051", 8_000_000_000)
		Expect(registry.Register(ctx, busyNode, true)).To(Succeed())
		Expect(registry.Register(ctx, idleNode, true)).To(Succeed())

		Expect(registry.SetNodeModel(ctx, busyNode.ID, "model-a", 0, "loaded", "", 0)).To(Succeed())

		picked, err := registry.FindIdleNode(ctx)
		Expect(err).ToNot(HaveOccurred())
		Expect(picked.Name).To(Equal("idle-node"))
	})

	It("returns error when all nodes have models loaded", func() {
		ctx := context.Background()
		only := makeNode("all-busy", "10.0.0.22:50051", 8_000_000_000)
		Expect(registry.Register(ctx, only, true)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, only.ID, "model-x", 0, "loaded", "", 0)).To(Succeed())

		_, findErr := registry.FindIdleNode(ctx)
		Expect(findErr).To(HaveOccurred())
	})
})
Describe("FindLeastLoadedNode", func() {
	It("returns the node with fewer in-flight requests", func() {
		ctx := context.Background()
		heavy := makeNode("heavy-node", "10.0.0.30:50051", 8_000_000_000)
		light := makeNode("light-node", "10.0.0.31:50051", 8_000_000_000)
		Expect(registry.Register(ctx, heavy, true)).To(Succeed())
		Expect(registry.Register(ctx, light, true)).To(Succeed())

		// heavy ends up with 3 in-flight requests, light with 1.
		Expect(registry.SetNodeModel(ctx, heavy.ID, "model-a", 0, "loaded", "", 0)).To(Succeed())
		for i := 0; i < 3; i++ {
			Expect(registry.IncrementInFlight(ctx, heavy.ID, "model-a", 0)).To(Succeed())
		}
		Expect(registry.SetNodeModel(ctx, light.ID, "model-b", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(ctx, light.ID, "model-b", 0)).To(Succeed())

		picked, err := registry.FindLeastLoadedNode(ctx)
		Expect(err).ToNot(HaveOccurred())
		Expect(picked.Name).To(Equal("light-node"))
	})
})
Describe("FindAndLockNodeWithModel", func() {
	It("returns the correct node and increments in-flight", func() {
		node := makeNode("lock-node", "10.0.0.40:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed())
		foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(foundNode.ID).To(Equal(node.ID))
		Expect(foundNM.ModelName).To(Equal("my-model"))
		// Verify in-flight was incremented
		nm, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
		Expect(err).ToNot(HaveOccurred())
		Expect(nm.InFlight).To(Equal(1))
	})
	It("returns error when model is not loaded anywhere", func() {
		_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil, nil)
		Expect(err).To(HaveOccurred())
	})
	It("selects the node with fewer in-flight when multiple exist", func() {
		n1 := makeNode("lock-heavy", "10.0.0.41:50051", 8_000_000_000)
		n2 := makeNode("lock-light", "10.0.0.42:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
		Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), n1.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), n2.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed())
		// Add in-flight to n1
		Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
		foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(foundNode.Name).To(Equal("lock-light"))
	})
	It("filters by candidateNodeIDs even when an excluded node has lower in_flight", func() {
		// Reproduces the selector-mismatch loop: a model loaded on a node
		// the selector now excludes (excluded) and on a node it includes
		// (included). Without the filter the excluded node wins on
		// in_flight ASC; with the filter the included node is returned
		// directly so Route() can serve from its existing replica.
		excluded := makeNode("excluded-node", "10.0.0.43:50051", 8_000_000_000)
		included := makeNode("included-node", "10.0.0.44:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), excluded, true)).To(Succeed())
		Expect(registry.Register(context.Background(), included, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), excluded.ID, "filtered-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), included.ID, "filtered-model", 0, "loaded", "", 0)).To(Succeed())
		// Make `included` strictly busier than `excluded` so the unfiltered
		// query would prefer the excluded one — proving the filter is
		// what's steering the result, not the in_flight ordering.
		Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed())
		foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(foundNode.ID).To(Equal(included.ID))
		Expect(foundNM.NodeID).To(Equal(included.ID))
	})
	It("round-robins between replicas when in_flight ties (last_used tiebreaker)", func() {
		// Three replicas of the same model on three nodes, all with in_flight=0.
		// Without the last_used tiebreaker, the node with the largest available_vram
		// would win every pick and one node would take ~all the load. With it,
		// each successful pick refreshes last_used so the next pick rotates to
		// the oldest-used replica.
		fat := makeNode("rr-fat", "10.0.0.50:50051", 24_000_000_000)
		mid := makeNode("rr-mid", "10.0.0.51:50051", 16_000_000_000)
		small := makeNode("rr-small", "10.0.0.52:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), fat, true)).To(Succeed())
		Expect(registry.Register(context.Background(), mid, true)).To(Succeed())
		Expect(registry.Register(context.Background(), small, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), fat.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), mid.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), small.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed())
		// Decrement back to 0 after each pick so the next call sees a tie.
		// (FindAndLockNodeWithModel atomically increments to lock the row.)
		picks := make([]string, 0, 9)
		for i := 0; i < 9; i++ {
			n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil, nil)
			Expect(err).ToNot(HaveOccurred())
			picks = append(picks, n.Name)
			Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed())
		}
		// Each replica should have been picked at least twice across 9 ties —
		// proves we're rotating, not pinning to the largest-VRAM node.
		counts := map[string]int{}
		for _, p := range picks {
			counts[p]++
		}
		Expect(counts["rr-fat"]).To(BeNumerically(">=", 2), "fat node was picked %d times across 9 ties: %v", counts["rr-fat"], picks)
		Expect(counts["rr-mid"]).To(BeNumerically(">=", 2), "mid node was picked %d times across 9 ties: %v", counts["rr-mid"], picks)
		Expect(counts["rr-small"]).To(BeNumerically(">=", 2), "small node was picked %d times across 9 ties: %v", counts["rr-small"], picks)
	})
	It("returns not-found when the model is loaded only on excluded nodes", func() {
		loadedExcluded := makeNode("excl-only-node", "10.0.0.45:50051", 8_000_000_000)
		emptyIncluded := makeNode("empty-included-node", "10.0.0.46:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), loadedExcluded, true)).To(Succeed())
		Expect(registry.Register(context.Background(), emptyIncluded, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), loadedExcluded.ID, "no-match-model", 0, "loaded", "", 0)).To(Succeed())
		// Filter restricts to a node that does not have the model — the
		// query must return an error so Route() falls through to schedule
		// a fresh load on a matching node instead of reusing the excluded
		// replica.
		_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}, nil)
		Expect(err).To(HaveOccurred())
	})
	It("agrees with PickBestReplica on a seeded dataset (policy mirror)", func() {
		// Guard against drift between the SQL ORDER BY in
		// FindAndLockNodeWithModel and the canonical Go implementation in
		// PickBestReplica. The two layers will eventually diverge in
		// caller (DB-backed atomic pick vs in-memory snapshot pick for the
		// per-frontend rotating cache), but the policy itself must stay
		// the single source of truth. If this test fails, update *both*
		// sides — never just one.
		//
		// Scenario exercises all three tiers. Each loser is built so that
		// it would WIN if the tier meant to eliminate it were missing, so
		// the final sanity check pins every tier, not just tier 3:
		//  - "loser-busy" has the most VRAM and ties for the oldest
		//    last_used, but in_flight=2 — only tier 1 eliminates it.
		//  - "loser-recent" ties at in_flight=0 and has more VRAM than
		//    either winner, but its last_used is the newest of the
		//    in_flight=0 group — only tier 2 eliminates it.
		//  - "winner-mid" and "winner-fat" both tie at in_flight=0 and
		//    share the oldest last_used — tier 3 decides: fattest wins.
		loserBusy := makeNode("mirror-loser-busy", "10.0.0.70:50051", 32_000_000_000)
		// 28 GB > winner-fat's 24 GB: if the last_used tier were dropped
		// from both pickers, this node would be chosen and the sanity
		// check below would fail.
		loserRecent := makeNode("mirror-loser-recent", "10.0.0.71:50051", 28_000_000_000)
		winnerMid := makeNode("mirror-winner-mid", "10.0.0.72:50051", 16_000_000_000)
		winnerFat := makeNode("mirror-winner-fat", "10.0.0.73:50051", 24_000_000_000)
		for _, n := range []*BackendNode{loserBusy, loserRecent, winnerMid, winnerFat} {
			Expect(registry.Register(context.Background(), n, true)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), n.ID, "mirror-model", 0, "loaded", "", 0)).To(Succeed())
		}
		// Force in_flight=2 on the "busy" node so tier 1 disqualifies it.
		Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
		// Slam last_used to known values so the test is deterministic
		// regardless of clock resolution between the helpers above.
		base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
		set := func(id string, t time.Time) {
			Expect(db.Model(&NodeModel{}).
				Where("node_id = ? AND model_name = ?", id, "mirror-model").
				Update("last_used", t).Error).To(Succeed())
		}
		set(loserBusy.ID, base) // oldest on purpose: with the most VRAM it would win tiers 2 and 3, so only tier 1 can disqualify it
		set(loserRecent.ID, base.Add(time.Hour))
		set(winnerMid.ID, base)
		set(winnerFat.ID, base)
		// Pull the same dataset both pickers will operate on. The Go
		// picker is a faithful representation of the policy; the SQL is
		// the production path.
		var rows []NodeModel
		Expect(db.Where("model_name = ? AND state = ?", "mirror-model", "loaded").
			Find(&rows).Error).To(Succeed())
		candidates := make([]ReplicaCandidate, 0, len(rows))
		for _, nm := range rows {
			var bn BackendNode
			Expect(db.First(&bn, "id = ? AND status = ?", nm.NodeID, StatusHealthy).Error).To(Succeed())
			candidates = append(candidates, ReplicaCandidate{
				NodeID:        nm.NodeID,
				Address:       bn.Address,
				ReplicaIndex:  nm.ReplicaIndex,
				InFlight:      nm.InFlight,
				LastUsed:      nm.LastUsed,
				AvailableVRAM: bn.AvailableVRAM,
			})
		}
		goPick := PickBestReplica(candidates)
		Expect(goPick).ToNot(BeNil())
		sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(sqlNode.ID).To(Equal(goPick.NodeID),
			"SQL ORDER BY picked %s; PickBestReplica picked %s — policy has drifted",
			sqlNode.ID, goPick.NodeID)
		// Sanity check: the policy says winner-fat wins on tier 3. Because
		// each loser would win without its eliminating tier, this also
		// pins tiers 1 and 2.
		Expect(goPick.NodeID).To(Equal(winnerFat.ID))
	})
})
Describe("FindAndLockNodeWithModel preference", func() {
	var nodeA, nodeB *BackendNode
	BeforeEach(func() {
		nodeA = makeNode("pref-a", "10.0.0.70:50051", 8_000_000_000)
		nodeB = makeNode("pref-b", "10.0.0.71:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), nodeA, true)).To(Succeed())
		Expect(registry.Register(context.Background(), nodeB, true)).To(Succeed())
		// Both loaded+healthy for model "pref-model", in_flight 0.
		Expect(registry.SetNodeModel(context.Background(), nodeA.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), nodeB.ID, "pref-model", 0, "loaded", "", 0)).To(Succeed())
	})
	It("locks the preferred node when eligible", func() {
		node, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: nodeB.ID})
		Expect(err).ToNot(HaveOccurred())
		Expect(node.ID).To(Equal(nodeB.ID))
		Expect(nm.NodeID).To(Equal(nodeB.ID))
		// in_flight is incremented atomically via gorm.Expr, so verify the
		// persisted value through a re-fetch (the returned struct mirrors
		// the pre-increment read, like the default-pick path).
		persisted, err := registry.GetNodeModel(context.Background(), nodeB.ID, "pref-model", 0)
		Expect(err).ToNot(HaveOccurred())
		Expect(persisted.InFlight).To(Equal(1))
	})
	It("falls back to default order when preferred not loaded", func() {
		node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, &RoutePreference{PreferredNodeID: "ZZZ"})
		Expect(err).ToNot(HaveOccurred())
		Expect(node.ID).To(BeElementOf(nodeA.ID, nodeB.ID))
	})
	It("nil preference behaves like before", func() {
		node, _, err := registry.FindAndLockNodeWithModel(context.Background(), "pref-model", nil, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(node).ToNot(BeNil())
		// Only nodeA and nodeB host pref-model, so the default pick must be
		// one of them, not merely non-nil.
		Expect(node.ID).To(BeElementOf(nodeA.ID, nodeB.ID))
	})
	It("locks the EXACT preferred replica when the node hosts two replicas", func() {
		// A single node hosts replica 0 and replica 1 of a model, both
		// loaded+healthy. The preference must lock the SPECIFIC replica
		// requested, not the least-loaded replica on the node.
		node := makeNode("pref-multi", "10.0.0.72:50051", 16_000_000_000)
		node.MaxReplicasPerModel = 2
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "addr0", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "addr1", 0)).To(Succeed())
		// pref={node, 1} must lock replica 1 specifically.
		gotNode, nm1, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
			&RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 1})
		Expect(err).ToNot(HaveOccurred())
		Expect(gotNode.ID).To(Equal(node.ID))
		Expect(nm1.ReplicaIndex).To(Equal(1))
		// pref={node, 0} must lock replica 0 specifically.
		gotNode0, nm0, err := registry.FindAndLockNodeWithModel(context.Background(), "multi-model", nil,
			&RoutePreference{PreferredNodeID: node.ID, PreferredReplica: 0})
		Expect(err).ToNot(HaveOccurred())
		Expect(gotNode0.ID).To(Equal(node.ID))
		Expect(nm0.ReplicaIndex).To(Equal(0))
		// Each lock must have incremented only its own replica's row.
		// An increment that ignored the replica index would bump both
		// rows on each call and leave them at 2 instead of 1.
		persisted1, err := registry.GetNodeModel(context.Background(), node.ID, "multi-model", 1)
		Expect(err).ToNot(HaveOccurred())
		Expect(persisted1.InFlight).To(Equal(1), "replica 1 in_flight")
		persisted0, err := registry.GetNodeModel(context.Background(), node.ID, "multi-model", 0)
		Expect(err).ToNot(HaveOccurred())
		Expect(persisted0.InFlight).To(Equal(1), "replica 0 in_flight")
	})
})
Describe("LoadedReplicaStats", func() {
	var n1, n2, n3 *BackendNode

	// Fixture: n1 hosts stats-model with 2 in-flight requests, n2 hosts it
	// idle, and n3 only hosts an unrelated model.
	BeforeEach(func() {
		ctx := context.Background()
		n1 = makeNode("stats-1", "10.0.0.80:50051", 8_000_000_000)
		n2 = makeNode("stats-2", "10.0.0.81:50051", 8_000_000_000)
		n3 = makeNode("stats-3", "10.0.0.82:50051", 8_000_000_000)
		for _, n := range []*BackendNode{n1, n2, n3} {
			Expect(registry.Register(ctx, n, true)).To(Succeed())
		}

		Expect(registry.SetNodeModel(ctx, n1.ID, "stats-model", 0, "loaded", "10.0.0.80:6000", 0)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, n2.ID, "stats-model", 0, "loaded", "10.0.0.81:6000", 0)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, n3.ID, "other-model", 0, "loaded", "", 0)).To(Succeed())
		for i := 0; i < 2; i++ {
			Expect(registry.IncrementInFlight(ctx, n1.ID, "stats-model", 0)).To(Succeed())
		}
	})

	It("returns loaded healthy replicas with in-flight counts", func() {
		replicas, err := registry.LoadedReplicaStats(context.Background(), "stats-model", nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(replicas).To(HaveLen(2))

		indexed := make(map[string]ReplicaCandidate, len(replicas))
		for _, r := range replicas {
			indexed[r.NodeID] = r
		}
		Expect(indexed).To(HaveKey(n1.ID))
		Expect(indexed).To(HaveKey(n2.ID))
		Expect(indexed[n1.ID].InFlight).To(Equal(2))
		Expect(indexed[n2.ID].InFlight).To(Equal(0))
	})

	It("filters to the candidate node set when provided", func() {
		replicas, err := registry.LoadedReplicaStats(context.Background(), "stats-model", []string{n2.ID})
		Expect(err).ToNot(HaveOccurred())
		Expect(replicas).To(HaveLen(1))
		Expect(replicas[0].NodeID).To(Equal(n2.ID))
	})

	It("excludes unhealthy nodes", func() {
		ctx := context.Background()
		Expect(registry.MarkUnhealthy(ctx, n1.ID)).To(Succeed())

		replicas, err := registry.LoadedReplicaStats(ctx, "stats-model", nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(replicas).To(HaveLen(1))
		Expect(replicas[0].NodeID).To(Equal(n2.ID))
	})

	It("returns empty for a model with no loaded replicas", func() {
		replicas, err := registry.LoadedReplicaStats(context.Background(), "no-such-model", nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(replicas).To(BeEmpty())
	})
})
Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
	It("transitions healthy -> unhealthy -> healthy", func() {
		ctx := context.Background()
		rt := makeNode("roundtrip-node", "10.0.0.60:50051", 8_000_000_000)
		Expect(registry.Register(ctx, rt, true)).To(Succeed())
		Expect(rt.Status).To(Equal(StatusHealthy))

		// Step 1: healthy -> unhealthy, confirmed from storage.
		Expect(registry.MarkUnhealthy(ctx, rt.ID)).To(Succeed())
		current, err := registry.Get(ctx, rt.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(current.Status).To(Equal(StatusUnhealthy))

		// Step 2: unhealthy -> healthy, confirmed from storage.
		Expect(registry.MarkHealthy(ctx, rt.ID)).To(Succeed())
		current, err = registry.Get(ctx, rt.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(current.Status).To(Equal(StatusHealthy))
	})

	It("returns error for non-existent node", func() {
		markErr := registry.MarkHealthy(context.Background(), "does-not-exist")
		Expect(markErr).To(HaveOccurred())
		Expect(markErr.Error()).To(ContainSubstring("not found"))
	})
})
Describe("NodeLabel CRUD", func() {
	It("sets and retrieves labels for a node", func() {
		ctx := context.Background()
		labelled := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000)
		Expect(registry.Register(ctx, labelled, true)).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "env", "prod")).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "region", "us-east")).To(Succeed())

		got, err := registry.GetNodeLabels(ctx, labelled.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(HaveLen(2))
		byKey := map[string]string{}
		for _, lbl := range got {
			byKey[lbl.Key] = lbl.Value
		}
		Expect(byKey["env"]).To(Equal("prod"))
		Expect(byKey["region"]).To(Equal("us-east"))
	})

	// Setting a key that already exists replaces its value instead of
	// adding a second label.
	It("overwrites existing label with same key", func() {
		ctx := context.Background()
		labelled := makeNode("label-overwrite", "10.0.0.71:50051", 8_000_000_000)
		Expect(registry.Register(ctx, labelled, true)).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "env", "dev")).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "env", "prod")).To(Succeed())

		got, err := registry.GetNodeLabels(ctx, labelled.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(HaveLen(1))
		Expect(got[0].Value).To(Equal("prod"))
	})

	It("removes a single label by key", func() {
		ctx := context.Background()
		labelled := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000)
		Expect(registry.Register(ctx, labelled, true)).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "env", "prod")).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "region", "us-east")).To(Succeed())

		Expect(registry.RemoveNodeLabel(ctx, labelled.ID, "env")).To(Succeed())

		got, err := registry.GetNodeLabels(ctx, labelled.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(HaveLen(1))
		Expect(got[0].Key).To(Equal("region"))
	})

	// SetNodeLabels is a full replacement: labels absent from the new map
	// are dropped.
	It("SetNodeLabels replaces all labels", func() {
		ctx := context.Background()
		labelled := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000)
		Expect(registry.Register(ctx, labelled, true)).To(Succeed())
		Expect(registry.SetNodeLabel(ctx, labelled.ID, "old-key", "old-val")).To(Succeed())

		replacement := map[string]string{"new-a": "val-a", "new-b": "val-b"}
		Expect(registry.SetNodeLabels(ctx, labelled.ID, replacement)).To(Succeed())

		got, err := registry.GetNodeLabels(ctx, labelled.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(HaveLen(2))
		byKey := map[string]string{}
		for _, lbl := range got {
			byKey[lbl.Key] = lbl.Value
		}
		Expect(byKey).To(Equal(replacement))
	})
})
Describe("FindNodesBySelector", func() {
	It("returns nodes matching all labels in selector", func() {
		n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000)
		n2 := makeNode("sel-nomatch", "10.0.0.81:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
		Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed())
		nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"})
		Expect(err).ToNot(HaveOccurred())
		Expect(nodes).To(HaveLen(1))
		Expect(nodes[0].Name).To(Equal("sel-match"))
	})
	It("returns empty when no nodes match", func() {
		n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n, true)).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed())
		nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
		Expect(err).ToNot(HaveOccurred())
		Expect(nodes).To(BeEmpty())
	})
	It("ignores unhealthy nodes", func() {
		n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n, true)).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
		Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed())
		nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
		Expect(err).ToNot(HaveOccurred())
		Expect(nodes).To(BeEmpty())
	})
	It("matches nodes with more labels than selector requires", func() {
		n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n, true)).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed())
		Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed())
		nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"})
		Expect(err).ToNot(HaveOccurred())
		Expect(nodes).To(HaveLen(1))
		Expect(nodes[0].Name).To(Equal("sel-superset"))
	})
	It("returns all healthy nodes for empty selector", func() {
		n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000)
		n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000)
		Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
		Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
		nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{})
		Expect(err).ToNot(HaveOccurred())
		Expect(len(nodes)).To(BeNumerically(">=", 2))
		// A count alone would also pass if unrelated nodes were returned;
		// the two nodes registered here must both be present.
		names := make([]string, 0, len(nodes))
		for _, n := range nodes {
			names = append(names, n.Name)
		}
		Expect(names).To(ContainElements("sel-all-1", "sel-all-2"))
	})
})
Describe("ModelSchedulingConfig CRUD", func() {
It("creates and retrieves a scheduling config", func() {
	ctx := context.Background()
	cfg := &ModelSchedulingConfig{
		ModelName:    "llama-7b",
		NodeSelector: `{"gpu.vendor":"nvidia"}`,
		MinReplicas:  1,
		MaxReplicas:  3,
	}
	Expect(registry.SetModelScheduling(ctx, cfg)).To(Succeed())
	// Saving assigns an ID to the config passed in.
	Expect(cfg.ID).ToNot(BeEmpty())

	stored, err := registry.GetModelScheduling(ctx, "llama-7b")
	Expect(err).ToNot(HaveOccurred())
	Expect(stored).ToNot(BeNil())
	Expect(stored.ModelName).To(Equal("llama-7b"))
	Expect(stored.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`))
	Expect(stored.MinReplicas).To(Equal(1))
	Expect(stored.MaxReplicas).To(Equal(3))
})
It("updates existing config via SetModelScheduling", func() {
	ctx := context.Background()
	initial := &ModelSchedulingConfig{ModelName: "update-model", MinReplicas: 1, MaxReplicas: 2}
	Expect(registry.SetModelScheduling(ctx, initial)).To(Succeed())

	// A second write keyed by the same model name replaces the bounds.
	revised := &ModelSchedulingConfig{ModelName: "update-model", MinReplicas: 2, MaxReplicas: 5}
	Expect(registry.SetModelScheduling(ctx, revised)).To(Succeed())

	stored, err := registry.GetModelScheduling(ctx, "update-model")
	Expect(err).ToNot(HaveOccurred())
	Expect(stored.MinReplicas).To(Equal(2))
	Expect(stored.MaxReplicas).To(Equal(5))
})
It("persists and updates route policy and thresholds", func() {
err := registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "prefix-cache-model", RoutePolicy: "prefix_cache",
BalanceAbsThreshold: 3, BalanceRelThreshold: 2.0, MinPrefixMatch: 0.4,
})
Expect(err).ToNot(HaveOccurred())
got, err := registry.GetModelScheduling(context.Background(), "prefix-cache-model")
Expect(err).ToNot(HaveOccurred())
Expect(got.RoutePolicy).To(Equal("prefix_cache"))
Expect(got.BalanceAbsThreshold).To(Equal(3))
Expect(got.BalanceRelThreshold).To(BeNumerically("==", 2.0))
Expect(got.MinPrefixMatch).To(BeNumerically("==", 0.4))
// Update must not be dropped on conflict.
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "prefix-cache-model", RoutePolicy: "round_robin",
})).ToNot(HaveOccurred())
got, err = registry.GetModelScheduling(context.Background(), "prefix-cache-model")
Expect(err).ToNot(HaveOccurred())
Expect(got.RoutePolicy).To(Equal("round_robin"))
})
It("lists all configs", func() {
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed())
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed())
configs, err := registry.ListModelSchedulings(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(len(configs)).To(BeNumerically(">=", 2))
})
It("lists only auto-scaling configs", func() {
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed())
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed())
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed())
configs, err := registry.ListAutoScalingConfigs(context.Background())
Expect(err).ToNot(HaveOccurred())
names := make([]string, len(configs))
for i, c := range configs {
names[i] = c.ModelName
}
Expect(names).To(ContainElement("auto-a"))
Expect(names).To(ContainElement("auto-b"))
Expect(names).ToNot(ContainElement("no-auto"))
})
It("deletes a config", func() {
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed())
Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed())
fetched, err := registry.GetModelScheduling(context.Background(), "delete-me")
Expect(err).ToNot(HaveOccurred())
Expect(fetched).To(BeNil())
})
It("returns nil for non-existent model", func() {
fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist")
Expect(err).ToNot(HaveOccurred())
Expect(fetched).To(BeNil())
})
})
Describe("CountLoadedReplicas", func() {
	// CountLoadedReplicas counts only rows whose state is "loaded";
	// a row still in "loading" must not contribute to the total.
	It("returns correct count of loaded replicas", func() {
		ctx := context.Background()
		first := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000)
		second := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000)
		pair := []*BackendNode{first, second}
		for _, bn := range pair {
			Expect(registry.Register(ctx, bn, true)).To(Succeed())
		}
		for _, bn := range pair {
			Expect(registry.SetNodeModel(ctx, bn.ID, "counted-model", 0, "loaded", "", 0)).To(Succeed())
		}
		total, err := registry.CountLoadedReplicas(ctx, "counted-model")
		Expect(err).NotTo(HaveOccurred())
		Expect(total).To(Equal(int64(2)))
	})
	It("excludes non-loaded states", func() {
		ctx := context.Background()
		ready := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000)
		pending := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000)
		Expect(registry.Register(ctx, ready, true)).To(Succeed())
		Expect(registry.Register(ctx, pending, true)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, ready.ID, "state-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, pending.ID, "state-model", 0, "loading", "", 0)).To(Succeed())
		total, err := registry.CountLoadedReplicas(ctx, "state-model")
		Expect(err).NotTo(HaveOccurred())
		Expect(total).To(Equal(int64(1)))
	})
})
Describe("DecrementInFlight", func() {
It("does not go below zero", func() {
node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model", 0, "loaded", "", 0)).To(Succeed())
// in_flight starts at 0 — decrement should be a no-op
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(0))
})
It("decrements correctly from a positive value", func() {
node := makeNode("dec-node-2", "10.0.0.51:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model-2", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(2))
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
nm, err = registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(1))
})
})
Describe("Schema defaults", func() {
// These tests pin the GORM defaults that the multi-replica refactor
// relies on. If a future migration changes a default, the
// reconciler/router will silently misbehave (e.g. capacity 0 instead
// of 1) — these assertions catch that at the migration boundary.
It("BackendNode.MaxReplicasPerModel defaults to 1", func() {
node := makeNode("schema-default-mrpm", "10.0.0.200:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
fetched, err := registry.Get(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(fetched.MaxReplicasPerModel).To(Equal(1),
"old workers don't send the field; default must preserve single-replica behavior")
})
It("BackendNode.ReservedVRAM defaults to 0", func() {
node := makeNode("schema-default-reserved", "10.0.0.201:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
fetched, err := registry.Get(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
})
It("NodeModel.ReplicaIndex defaults to 0", func() {
node := makeNode("schema-default-replica", "10.0.0.202:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "default-replica-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "default-replica-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm).ToNot(BeNil())
Expect(nm.ReplicaIndex).To(Equal(0))
})
It("ModelSchedulingConfig.UnsatisfiableUntil is nullable and defaults to nil", func() {
cfg := &ModelSchedulingConfig{
ModelName: "schema-default-unsat",
MinReplicas: 1,
}
Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
fetched, err := registry.GetModelScheduling(context.Background(), "schema-default-unsat")
Expect(err).ToNot(HaveOccurred())
Expect(fetched).ToNot(BeNil())
Expect(fetched.UnsatisfiableUntil).To(BeNil())
Expect(fetched.UnsatisfiableTicks).To(Equal(0))
})
})
Describe("Multi-replica registry", func() {
// PR2 tests: SetNodeModel with distinct replica indexes creates distinct
// rows; per-row mutations (Remove, Increment, Decrement, Touch) target
// only their indexed row so siblings are not orphaned.
It("SetNodeModel(replicaIndex=0) then SetNodeModel(replicaIndex=1) creates two distinct rows", func() {
node := makeNode("multi-1", "10.0.0.210:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "127.0.0.1:50101", 0)).To(Succeed())
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(2))
byIdx := map[int]NodeModel{}
for _, m := range models {
byIdx[m.ReplicaIndex] = m
}
Expect(byIdx[0].Address).To(Equal("127.0.0.1:50100"))
Expect(byIdx[1].Address).To(Equal("127.0.0.1:50101"))
Expect(byIdx[0].ID).ToNot(Equal(byIdx[1].ID))
})
It("RemoveNodeModel(replicaIndex=0) leaves replica 1 intact", func() {
node := makeNode("multi-2", "10.0.0.211:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 0, "loaded", "127.0.0.1:50110", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 1, "loaded", "127.0.0.1:50111", 0)).To(Succeed())
Expect(registry.RemoveNodeModel(context.Background(), node.ID, "kept-model", 0)).To(Succeed())
// Sibling replica must still exist — this was the latent bug pre-PR2:
// the WHERE clause matched both rows and orphaned the healthy sibling.
survivor, err := registry.GetNodeModel(context.Background(), node.ID, "kept-model", 1)
Expect(err).ToNot(HaveOccurred())
Expect(survivor).ToNot(BeNil())
Expect(survivor.Address).To(Equal("127.0.0.1:50111"))
// Replica 0 is gone
_, err = registry.GetNodeModel(context.Background(), node.ID, "kept-model", 0)
Expect(err).To(HaveOccurred())
})
It("RemoveAllNodeModelReplicas deletes every replica of the model on the node", func() {
node := makeNode("multi-3", "10.0.0.212:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "purge-model")).To(Succeed())
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(BeEmpty())
})
It("IncrementInFlight only updates the targeted replica row", func() {
node := makeNode("multi-4", "10.0.0.213:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed())
r0, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(r0.InFlight).To(Equal(0), "replica 0 must not have been incremented")
r1, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 1)
Expect(err).ToNot(HaveOccurred())
Expect(r1.InFlight).To(Equal(2))
})
It("CountReplicasOnNode returns the per-(node, model) row count", func() {
node := makeNode("multi-5", "10.0.0.214:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "count-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "count-model", 1, "loaded", "b", 0)).To(Succeed())
n, err := registry.CountReplicasOnNode(context.Background(), node.ID, "count-model")
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
})
It("NextFreeReplicaIndex returns the lowest unused index < maxSlots", func() {
node := makeNode("multi-6", "10.0.0.215:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Slot 0 free initially
idx, err := registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(Equal(0))
// Occupy 0 and 2 — next free is 1 (lowest gap)
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 2, "loaded", "c", 0)).To(Succeed())
idx, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(Equal(1), "must allocate the lowest free index for compactness")
// Fill all 4 — must return ErrNoFreeSlot
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 3, "loaded", "d", 0)).To(Succeed())
_, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).To(MatchError(ErrNoFreeSlot))
// maxSlots=0 always returns ErrNoFreeSlot
_, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "no-slots-model", 0)
Expect(err).To(MatchError(ErrNoFreeSlot))
})
})
Describe("SetReplicaRemovedHook", func() {
	// removed records one observed hook invocation.
	type removed struct {
		model, node string
		replica     int
	}
	It("fires once with the specific replica after RemoveNodeModel", func() {
		ctx := context.Background()
		node := makeNode("hook-remove-one", "10.0.0.230:50051", 8_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, node.ID, "hook-model", 1, "loaded", "a", 0)).To(Succeed())
		var fired []removed
		registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
			fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
		})
		// RemoveNodeModel(replica 1) must fire with the SPECIFIC replica index.
		Expect(registry.RemoveNodeModel(ctx, node.ID, "hook-model", 1)).To(Succeed())
		Expect(fired).To(HaveLen(1))
		Expect(fired[0]).To(Equal(removed{model: "hook-model", node: node.ID, replica: 1}))
	})
	It("fires once with replica<0 after RemoveAllNodeModelReplicas", func() {
		ctx := context.Background()
		node := makeNode("hook-remove-all", "10.0.0.231:50051", 16_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, node.ID, "hook-all-model", 0, "loaded", "a", 0)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, node.ID, "hook-all-model", 1, "loaded", "b", 0)).To(Succeed())
		var fired []removed
		registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
			fired = append(fired, removed{model: modelName, node: nodeID, replica: replicaIndex})
		})
		// A single call covers every replica of the model on the node. The
		// negative replica index means "all replicas", and the consumer's
		// InvalidateNode drops every entry for the (model, node) pair.
		Expect(registry.RemoveAllNodeModelReplicas(ctx, node.ID, "hook-all-model")).To(Succeed())
		Expect(fired).To(HaveLen(1))
		Expect(fired[0].model).To(Equal("hook-all-model"))
		Expect(fired[0].node).To(Equal(node.ID))
		Expect(fired[0].replica).To(BeNumerically("<", 0))
	})
	It("does not panic when no hook is set", func() {
		ctx := context.Background()
		node := makeNode("hook-unset", "10.0.0.232:50051", 8_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.SetNodeModel(ctx, node.ID, "no-hook-model", 0, "loaded", "a", 0)).To(Succeed())
		Expect(func() {
			Expect(registry.RemoveNodeModel(ctx, node.ID, "no-hook-model", 0)).To(Succeed())
			Expect(registry.RemoveAllNodeModelReplicas(ctx, node.ID, "no-hook-model")).To(Succeed())
		}).ToNot(Panic())
	})
	// seedTwoModels registers node and gives it three node_models rows that
	// span two distinct models: model-a (replicas 0 and 1) and model-b
	// (replica 0). The node-scoped bulk deletes below remove every replica of
	// every model on the node in one statement, so the chokepoint must fire
	// the hook once per distinct model name. The consumer's Invalidate drops
	// all entries for that (model, node) pair.
	seedTwoModels := func(node *BackendNode) {
		ctx := context.Background()
		ExpectWithOffset(1, registry.Register(ctx, node, true)).To(Succeed())
		ExpectWithOffset(1, registry.SetNodeModel(ctx, node.ID, "model-a", 0, "loaded", "a0", 0)).To(Succeed())
		ExpectWithOffset(1, registry.SetNodeModel(ctx, node.ID, "model-a", 1, "loaded", "a1", 0)).To(Succeed())
		ExpectWithOffset(1, registry.SetNodeModel(ctx, node.ID, "model-b", 0, "loaded", "b0", 0)).To(Succeed())
	}
	// recordBulkRemovals installs a hook that counts invocations per
	// (model, node) pair and returns the live count map. Bulk node-scoped
	// deletes signal "all replicas" with a negative replica index, and the
	// hook asserts that. The replica field of each key is left at its zero
	// value so counts aggregate per model.
	recordBulkRemovals := func() map[removed]int {
		counts := map[removed]int{}
		registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
			Expect(replicaIndex).To(BeNumerically("<", 0))
			counts[removed{model: modelName, node: nodeID}]++
		})
		return counts
	}
	// expectOncePerSeededModel asserts that the hook fired exactly once for
	// each model created by seedTwoModels, and for nothing else.
	expectOncePerSeededModel := func(fired map[removed]int, nodeID string) {
		ExpectWithOffset(1, fired).To(HaveLen(2))
		ExpectWithOffset(1, fired[removed{model: "model-a", node: nodeID}]).To(Equal(1))
		ExpectWithOffset(1, fired[removed{model: "model-b", node: nodeID}]).To(Equal(1))
	}
	It("fires once per distinct model after MarkOffline", func() {
		node := makeNode("hook-offline", "10.0.0.240:50051", 8_000_000_000)
		seedTwoModels(node)
		fired := recordBulkRemovals()
		Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed())
		expectOncePerSeededModel(fired, node.ID)
	})
	It("fires once per distinct model after MarkDraining", func() {
		node := makeNode("hook-draining", "10.0.0.241:50051", 8_000_000_000)
		seedTwoModels(node)
		fired := recordBulkRemovals()
		Expect(registry.MarkDraining(context.Background(), node.ID)).To(Succeed())
		expectOncePerSeededModel(fired, node.ID)
	})
	It("fires once per distinct model after Deregister", func() {
		node := makeNode("hook-deregister", "10.0.0.242:50051", 8_000_000_000)
		seedTwoModels(node)
		fired := recordBulkRemovals()
		Expect(registry.Deregister(context.Background(), node.ID)).To(Succeed())
		expectOncePerSeededModel(fired, node.ID)
	})
	It("fires once per distinct model when re-registration clears stale rows", func() {
		node := makeNode("hook-reregister", "10.0.0.243:50051", 8_000_000_000)
		seedTwoModels(node)
		fired := recordBulkRemovals()
		// Re-registering the same node name clears its stale model rows,
		// and that cleanup must fire the hook.
		again := makeNode("hook-reregister", "10.0.0.243:50052", 8_000_000_000)
		Expect(registry.Register(context.Background(), again, true)).To(Succeed())
		expectOncePerSeededModel(fired, node.ID)
	})
	// Atomicity: the bulk node-scoped delete in MarkOffline/MarkDraining/
	// re-register captures the model names and deletes the rows inside a
	// single transaction. A real SetNodeModel-between-capture-and-delete race
	// cannot be forced deterministically here. Instead, this spec asserts the
	// post-condition the transaction guarantees: the set of fired hooks
	// equals exactly the set of node_models rows the operation removed, and
	// nothing is left behind. If capture and delete ever saw inconsistent
	// snapshots, either a surviving row (the delete missed it) or a missing
	// hook (the capture missed it) would break one of these assertions.
	It("MarkOffline fires hooks for exactly the rows it deletes (consistent snapshot)", func() {
		ctx := context.Background()
		node := makeNode("hook-atomic-offline", "10.0.0.244:50051", 8_000_000_000)
		seedTwoModels(node)
		// Read what the transaction should remove straight from the DB,
		// before running the operation.
		before, err := registry.GetNodeModels(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		expectedModels := map[string]struct{}{}
		for _, nm := range before {
			expectedModels[nm.ModelName] = struct{}{}
		}
		Expect(expectedModels).To(HaveLen(2), "seed should create two distinct models")
		fired := map[string]struct{}{}
		registry.SetReplicaRemovedHook(func(modelName, nodeID string, replicaIndex int) {
			Expect(nodeID).To(Equal(node.ID))
			Expect(replicaIndex).To(BeNumerically("<", 0))
			fired[modelName] = struct{}{}
		})
		Expect(registry.MarkOffline(ctx, node.ID)).To(Succeed())
		// Hooks fired for exactly the distinct models that existed.
		Expect(fired).To(Equal(expectedModels),
			"hooks must fire for exactly the set of models the transaction deleted")
		// The delete also emptied the node's node_models rows, so no row
		// survives without a matching hook.
		after, err := registry.GetNodeModels(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(after).To(BeEmpty(), "no node_models row should survive the bulk delete")
	})
})
Describe("ApplyAutoLabels", func() {
	// labelsByKey fetches the node's labels and flattens them into a
	// key -> value map. A lookup error fails the calling spec instead of
	// being silently discarded.
	labelsByKey := func(nodeID string) map[string]string {
		labels, err := registry.GetNodeLabels(context.Background(), nodeID)
		ExpectWithOffset(1, err).ToNot(HaveOccurred())
		byKey := make(map[string]string, len(labels))
		for _, l := range labels {
			byKey[l.Key] = l.Value
		}
		return byKey
	}
	It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() {
		ctx := context.Background()
		node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000)
		node.MaxReplicasPerModel = 4
		node.GPUVendor = "nvidia"
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		registry.ApplyAutoLabels(ctx, node.ID, node)
		byKey := labelsByKey(node.ID)
		Expect(byKey).To(HaveKeyWithValue("node.replica-slots", "4"),
			"selectors targeting fat nodes need this auto-label")
		Expect(byKey).To(HaveKeyWithValue("gpu.vendor", "nvidia"))
	})
	It("defaults node.replica-slots to 1 when MaxReplicasPerModel is unset", func() {
		ctx := context.Background()
		node := makeNode("auto-label-default", "10.0.0.221:50051", 4_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		// Fetch the node back; the schema default should be 1 (PR1 schema
		// test). Check the error first so a failed lookup is reported as
		// such rather than as a nil dereference.
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.MaxReplicasPerModel).To(Equal(1))
		registry.ApplyAutoLabels(ctx, node.ID, fetched)
		Expect(labelsByKey(node.ID)).To(HaveKeyWithValue("node.replica-slots", "1"))
	})
})
Describe("VRAM soft-reservation (PR5)", func() {
	// These specs pin the soft-reservation contract: ReserveVRAM is the
	// admission gate that stops two concurrent scheduling decisions from
	// over-committing the same node within one heartbeat window. The
	// MaxReplicasPerModel override specs also live in this container; they
	// cover the admin override of per-node replica capacity, not VRAM
	// accounting.
	//
	// Every registry.Get error is checked before the result's fields are
	// read, so a failed lookup shows up as an assertion failure.
	It("ReserveVRAM atomically deducts from effectively-free VRAM", func() {
		ctx := context.Background()
		node := makeNode("reserve-1", "10.0.0.230:50051", 10_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.ReserveVRAM(ctx, node.ID, 3_000_000_000)).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000)))
	})
	It("ReserveVRAM rejects when effectively-free VRAM is insufficient", func() {
		ctx := context.Background()
		node := makeNode("reserve-2", "10.0.0.231:50051", 5_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		// The first reservation fits.
		Expect(registry.ReserveVRAM(ctx, node.ID, 4_000_000_000)).To(Succeed())
		// The second is too big: only 1 GB is effectively free.
		err := registry.ReserveVRAM(ctx, node.ID, 2_000_000_000)
		Expect(err).To(MatchError(ErrInsufficientVRAM))
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.ReservedVRAM).To(Equal(uint64(4_000_000_000)),
			"failed reservation must not bump the column")
	})
	It("ReserveVRAM with bytes=0 is a no-op", func() {
		ctx := context.Background()
		node := makeNode("reserve-3", "10.0.0.232:50051", 1_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.ReserveVRAM(ctx, node.ID, 0)).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
	})
	It("ReleaseVRAM returns reserved bytes to the pool", func() {
		ctx := context.Background()
		node := makeNode("release-1", "10.0.0.233:50051", 10_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.ReserveVRAM(ctx, node.ID, 4_000_000_000)).To(Succeed())
		Expect(registry.ReleaseVRAM(ctx, node.ID, 1_000_000_000)).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000)))
	})
	It("ReleaseVRAM cannot underflow past zero", func() {
		ctx := context.Background()
		node := makeNode("release-underflow", "10.0.0.234:50051", 1_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		// With no reservation, release is a guarded no-op rather than
		// wrapping uint64 around to a huge positive number.
		Expect(registry.ReleaseVRAM(ctx, node.ID, 5_000_000_000)).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
	})
	It("Heartbeat with available_vram resets reserved_vram to 0", func() {
		ctx := context.Background()
		node := makeNode("heartbeat-reset", "10.0.0.235:50051", 10_000_000_000)
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.ReserveVRAM(ctx, node.ID, 5_000_000_000)).To(Succeed())
		fresh := uint64(8_000_000_000)
		Expect(registry.Heartbeat(ctx, node.ID, &HeartbeatUpdate{AvailableVRAM: &fresh})).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.AvailableVRAM).To(Equal(fresh),
			"heartbeat must overwrite available_vram with the worker's reading")
		Expect(fetched.ReservedVRAM).To(Equal(uint64(0)),
			"heartbeat must clear the soft reservation — worker is the source of truth")
	})
	It("UpdateMaxReplicasPerModel marks the value as a sticky override", func() {
		// The original UX bug: workers default the flag to 1, so every
		// re-registration silently reverted the admin's UI value. This
		// spec pins the fix.
		ctx := context.Background()
		node := &BackendNode{
			Name:                "override-survives",
			NodeType:            NodeTypeBackend,
			Address:             "10.0.0.240:50051",
			MaxReplicasPerModel: 1,
		}
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		// The admin sets capacity to 4 via the UI.
		Expect(registry.UpdateMaxReplicasPerModel(ctx, node.ID, 4)).To(Succeed())
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.MaxReplicasPerModel).To(Equal(4))
		Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue())
		// The worker re-registers with its default of 1 (the operator never
		// set the flag).
		restart := &BackendNode{
			Name:                "override-survives",
			NodeType:            NodeTypeBackend,
			Address:             "10.0.0.240:50051",
			MaxReplicasPerModel: 1,
		}
		Expect(registry.Register(ctx, restart, true)).To(Succeed())
		// The override must have survived.
		fetched, err = registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.MaxReplicasPerModel).To(Equal(4),
			"admin override must not be overwritten by worker re-registration")
		Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue())
	})
	It("ResetMaxReplicasPerModel hands control back to the worker", func() {
		ctx := context.Background()
		node := &BackendNode{
			Name:                "override-reset",
			NodeType:            NodeTypeBackend,
			Address:             "10.0.0.241:50051",
			MaxReplicasPerModel: 1,
		}
		Expect(registry.Register(ctx, node, true)).To(Succeed())
		Expect(registry.UpdateMaxReplicasPerModel(ctx, node.ID, 4)).To(Succeed())
		Expect(registry.ResetMaxReplicasPerModel(ctx, node.ID)).To(Succeed())
		// Reset only flips the flag. The value stays until the worker
		// re-registers, because the registry does not presume to know what
		// the worker wants.
		fetched, err := registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.MaxReplicasPerModelManuallySet).To(BeFalse())
		// The worker now re-registers with 8.
		restart := &BackendNode{
			Name:                "override-reset",
			NodeType:            NodeTypeBackend,
			Address:             "10.0.0.241:50051",
			MaxReplicasPerModel: 8,
		}
		Expect(registry.Register(ctx, restart, true)).To(Succeed())
		fetched, err = registry.Get(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(fetched.MaxReplicasPerModel).To(Equal(8),
			"after reset, the worker's value should apply")
	})
	It("FindNodeWithVRAM honors the reservation", func() {
		ctx := context.Background()
		small := makeNode("find-vram-small", "10.0.0.236:50051", 5_000_000_000)
		big := makeNode("find-vram-big", "10.0.0.237:50051", 20_000_000_000)
		Expect(registry.Register(ctx, small, true)).To(Succeed())
		Expect(registry.Register(ctx, big, true)).To(Succeed())
		// Reserve almost all of the big node so its effective free VRAM
		// drops below the request. The small node is not big enough either,
		// so the call must return an error.
		Expect(registry.ReserveVRAM(ctx, big.ID, 18_000_000_000)).To(Succeed())
		_, err := registry.FindNodeWithVRAM(ctx, 8_000_000_000)
		Expect(err).To(HaveOccurred(),
			"reserved capacity must remove a node from VRAM-aware candidates")
	})
})
Describe("ModelLoadInfo persistence (Bug-1)", func() {
It("survives every NodeModel row being removed", func() {
ctx := context.Background()
// One node with one loaded replica + per-replica blob (the legacy path).
node := makeNode("li-1", "10.0.1.1:50051", 8_000_000_000)
Expect(registry.Register(ctx, node, true)).To(Succeed())
Expect(registry.SetNodeModel(ctx, node.ID, "load-info-model", 0, "loaded", node.Address, 0)).To(Succeed())
Expect(registry.SetNodeModelLoadInfo(ctx, node.ID, "load-info-model", 0, "llama-cpp", []byte("opts-v1"))).To(Succeed())
// Persist per-model via the new path (the dispatch hook does this).
Expect(registry.UpsertModelLoadInfo(ctx, "load-info-model", "llama-cpp", []byte("opts-v1"))).To(Succeed())
// Simulate worker death + MarkOffline reaping: every NodeModel row gone.
Expect(registry.RemoveAllNodeModelReplicas(ctx, node.ID, "load-info-model")).To(Succeed())
bt, blob, err := registry.GetModelLoadInfo(ctx, "load-info-model")
Expect(err).ToNot(HaveOccurred(),
"per-model load info must survive every NodeModel row going away")
Expect(bt).To(Equal("llama-cpp"))
Expect(blob).To(Equal([]byte("opts-v1")))
})
It("ON CONFLICT updates backend type and opts (last-write-wins)", func() {
ctx := context.Background()
Expect(registry.UpsertModelLoadInfo(ctx, "lww", "llama-cpp", []byte("v1"))).To(Succeed())
Expect(registry.UpsertModelLoadInfo(ctx, "lww", "vllm", []byte("v2"))).To(Succeed())
bt, blob, err := registry.GetModelLoadInfo(ctx, "lww")
Expect(err).ToNot(HaveOccurred())
Expect(bt).To(Equal("vllm"))
Expect(blob).To(Equal([]byte("v2")))
})
It("falls back to legacy NodeModel blob when no per-model row exists", func() {
// Pre-fix rolling-upgrade path: a frontend that ran before the new
// table existed only wrote the per-replica blob. The new
// GetModelLoadInfo must still find it so an upgrade doesn't
// regress the reconciler for already-loaded models.
ctx := context.Background()
node := makeNode("li-legacy", "10.0.1.2:50051", 8_000_000_000)
Expect(registry.Register(ctx, node, true)).To(Succeed())
Expect(registry.SetNodeModel(ctx, node.ID, "legacy-model", 0, "loaded", node.Address, 0)).To(Succeed())
Expect(registry.SetNodeModelLoadInfo(ctx, node.ID, "legacy-model", 0, "llama-cpp", []byte("legacy-opts"))).To(Succeed())
bt, blob, err := registry.GetModelLoadInfo(ctx, "legacy-model")
Expect(err).ToNot(HaveOccurred())
Expect(bt).To(Equal("llama-cpp"))
Expect(blob).To(Equal([]byte("legacy-opts")))
})
It("returns ErrRecordNotFound when neither source has the model", func() {
ctx := context.Background()
_, _, err := registry.GetModelLoadInfo(ctx, "never-loaded")
Expect(err).To(MatchError(gorm.ErrRecordNotFound))
})
It("rejects empty model names", func() {
err := registry.UpsertModelLoadInfo(context.Background(), "", "llama-cpp", []byte("x"))
Expect(err).To(HaveOccurred())
})
})
})
var _ = Describe("ModelScheduling spread + seeding", func() {
	// Specs for the SpreadAll flag and for SeedModelScheduling, which writes
	// declarative configs over whatever rows already exist. Every lookup is
	// error- and nil-checked before fields are read, so a regression fails
	// with an assertion message rather than a nil-pointer panic.
	var (
		db       *gorm.DB
		registry *NodeRegistry
	)
	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()
		var err error
		registry, err = NewNodeRegistry(db)
		Expect(err).ToNot(HaveOccurred())
	})
	It("persists and round-trips SpreadAll", func() {
		ctx := context.Background()
		Expect(registry.SetModelScheduling(ctx, &ModelSchedulingConfig{
			ModelName: "m", SpreadAll: true,
		})).To(Succeed())
		got, err := registry.GetModelScheduling(ctx, "m")
		Expect(err).ToNot(HaveOccurred())
		Expect(got).ToNot(BeNil())
		Expect(got.SpreadAll).To(BeTrue())
	})
	It("includes SpreadAll configs in ListAutoScalingConfigs", func() {
		ctx := context.Background()
		Expect(registry.SetModelScheduling(ctx, &ModelSchedulingConfig{
			ModelName: "m", SpreadAll: true,
		})).To(Succeed())
		configs, err := registry.ListAutoScalingConfigs(ctx)
		Expect(err).ToNot(HaveOccurred())
		Expect(configs).To(HaveLen(1))
		Expect(configs[0].ModelName).To(Equal("m"))
	})
	It("seeds configs with authoritative upsert", func() {
		ctx := context.Background()
		// Start from a pre-existing row whose MinReplicas the seed must replace.
		Expect(registry.SetModelScheduling(ctx, &ModelSchedulingConfig{
			ModelName: "m", MinReplicas: 9,
		})).To(Succeed())
		err := registry.SeedModelScheduling(ctx, []ModelSchedulingConfig{
			{ModelName: "m", MinReplicas: 1, MaxReplicas: 2},
			{ModelName: "n", SpreadAll: true},
		})
		Expect(err).ToNot(HaveOccurred())
		m, err := registry.GetModelScheduling(ctx, "m")
		Expect(err).ToNot(HaveOccurred())
		Expect(m).ToNot(BeNil())
		Expect(m.MinReplicas).To(Equal(1))
		Expect(m.MaxReplicas).To(Equal(2))
		Expect(m.SpreadAll).To(BeFalse())
		n, err := registry.GetModelScheduling(ctx, "n")
		Expect(err).ToNot(HaveOccurred())
		Expect(n).ToNot(BeNil())
		Expect(n.SpreadAll).To(BeTrue())
	})
})