package nodes import ( "context" "runtime" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/mudler/LocalAI/core/services/testutil" "gorm.io/gorm" ) var _ = Describe("NodeRegistry", func() { var ( db *gorm.DB registry *NodeRegistry ) BeforeEach(func() { if runtime.GOOS == "darwin" { Skip("testcontainers requires Docker, not available on macOS CI") } db = testutil.SetupTestDB() var err error registry, err = NewNodeRegistry(db) Expect(err).ToNot(HaveOccurred()) }) // Helper to build a minimal BackendNode. makeNode := func(name, address string, vram uint64) *BackendNode { return &BackendNode{ Name: name, NodeType: NodeTypeBackend, Address: address, TotalVRAM: vram, AvailableVRAM: vram, } } Describe("Register", func() { It("sets StatusPending when autoApprove is false", func() { node := makeNode("worker-1", "10.0.0.1:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, false)).To(Succeed()) Expect(node.Status).To(Equal(StatusPending)) fetched, err := registry.GetByName(context.Background(), "worker-1") Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusPending)) }) It("sets StatusHealthy when autoApprove is true", func() { node := makeNode("worker-2", "10.0.0.2:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(node.Status).To(Equal(StatusHealthy)) }) }) Describe("Re-registration", func() { It("keeps a pending node pending on re-register with autoApprove=false", func() { node := makeNode("re-pending", "10.0.0.3:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, false)).To(Succeed()) Expect(node.Status).To(Equal(StatusPending)) // Re-register same name, still no auto-approve node2 := makeNode("re-pending", "10.0.0.3:50052", 4_000_000_000) Expect(registry.Register(context.Background(), node2, false)).To(Succeed()) Expect(node2.Status).To(Equal(StatusPending)) // ID is preserved from original registration 
Expect(node2.ID).To(Equal(node.ID)) }) It("restores a previously approved node to healthy on re-register with autoApprove=false", func() { node := makeNode("re-approved", "10.0.0.4:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(node.Status).To(Equal(StatusHealthy)) // Simulate the node becoming unhealthy Expect(registry.MarkUnhealthy(context.Background(), node.ID)).To(Succeed()) fetched, err := registry.GetByName(context.Background(), "re-approved") Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusUnhealthy)) // Re-register with autoApprove=false — should restore to healthy // because the node was previously approved (status != pending) node2 := makeNode("re-approved", "10.0.0.4:50052", 8_000_000_000) Expect(registry.Register(context.Background(), node2, false)).To(Succeed()) Expect(node2.Status).To(Equal(StatusHealthy)) }) }) Describe("ApproveNode", func() { It("transitions a pending node to healthy", func() { node := makeNode("approve-me", "10.0.0.5:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, false)).To(Succeed()) Expect(node.Status).To(Equal(StatusPending)) Expect(registry.ApproveNode(context.Background(), node.ID)).To(Succeed()) fetched, err := registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusHealthy)) }) It("returns error for non-existent node ID", func() { err := registry.ApproveNode(context.Background(), "non-existent-id") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found or not in pending status")) }) It("returns error for an already-healthy node", func() { node := makeNode("already-healthy", "10.0.0.6:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) err := registry.ApproveNode(context.Background(), node.ID) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found or not in 
pending status")) }) }) Describe("MarkOffline", func() { It("sets status to offline and clears model records", func() { node := makeNode("offline-test", "10.0.0.7:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) // Load a model on the node Expect(registry.SetNodeModel(context.Background(), node.ID, "llama-7b", 0, "loaded", "10.0.0.7:50052", 0)).To(Succeed()) models, err := registry.GetNodeModels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(models).To(HaveLen(1)) // Mark offline Expect(registry.MarkOffline(context.Background(), node.ID)).To(Succeed()) fetched, err := registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusOffline)) // Model records should be cleared models, err = registry.GetNodeModels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(models).To(BeEmpty()) }) It("returns error for non-existent node", func() { err := registry.MarkOffline(context.Background(), "does-not-exist") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found")) }) }) Describe("SetNodeModel ID stability", func() { It("preserves the ID when called twice for the same node+model", func() { node := makeNode("stable-id-node", "10.0.0.99:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50052", 0)).To(Succeed()) nm1, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0) Expect(err).ToNot(HaveOccurred()) // Call again with different state/address Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50053", 0)).To(Succeed()) nm2, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0) Expect(err).ToNot(HaveOccurred()) Expect(nm2.ID).To(Equal(nm1.ID), "ID should remain stable across 
SetNodeModel calls") Expect(nm2.Address).To(Equal("10.0.0.99:50053"), "Address should be updated") }) }) Describe("FindNodeWithVRAM", func() { It("selects the node with sufficient VRAM", func() { small := makeNode("small-gpu", "10.0.0.10:50051", 4_000_000_000) big := makeNode("big-gpu", "10.0.0.11:50051", 16_000_000_000) Expect(registry.Register(context.Background(), small, true)).To(Succeed()) Expect(registry.Register(context.Background(), big, true)).To(Succeed()) // Request 8 GB — only big-gpu qualifies found, err := registry.FindNodeWithVRAM(context.Background(), 8_000_000_000) Expect(err).ToNot(HaveOccurred()) Expect(found.Name).To(Equal("big-gpu")) }) It("returns error when no node has enough VRAM", func() { small := makeNode("tiny-gpu", "10.0.0.12:50051", 2_000_000_000) Expect(registry.Register(context.Background(), small, true)).To(Succeed()) _, err := registry.FindNodeWithVRAM(context.Background(), 32_000_000_000) Expect(err).To(HaveOccurred()) }) }) Describe("FindIdleNode", func() { It("returns the node with no loaded models", func() { busy := makeNode("busy-node", "10.0.0.20:50051", 8_000_000_000) idle := makeNode("idle-node", "10.0.0.21:50051", 8_000_000_000) Expect(registry.Register(context.Background(), busy, true)).To(Succeed()) Expect(registry.Register(context.Background(), idle, true)).To(Succeed()) // Load a model on the busy node Expect(registry.SetNodeModel(context.Background(), busy.ID, "model-a", 0, "loaded", "", 0)).To(Succeed()) found, err := registry.FindIdleNode(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(found.Name).To(Equal("idle-node")) }) It("returns error when all nodes have models loaded", func() { n := makeNode("all-busy", "10.0.0.22:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n.ID, "model-x", 0, "loaded", "", 0)).To(Succeed()) _, err := registry.FindIdleNode(context.Background()) Expect(err).To(HaveOccurred()) }) 
}) Describe("FindLeastLoadedNode", func() { It("returns the node with fewer in-flight requests", func() { heavy := makeNode("heavy-node", "10.0.0.30:50051", 8_000_000_000) light := makeNode("light-node", "10.0.0.31:50051", 8_000_000_000) Expect(registry.Register(context.Background(), heavy, true)).To(Succeed()) Expect(registry.Register(context.Background(), light, true)).To(Succeed()) // Set up models with different in-flight counts Expect(registry.SetNodeModel(context.Background(), heavy.ID, "model-a", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), light.ID, "model-b", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), light.ID, "model-b", 0)).To(Succeed()) found, err := registry.FindLeastLoadedNode(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(found.Name).To(Equal("light-node")) }) }) Describe("FindAndLockNodeWithModel", func() { It("returns the correct node and increments in-flight", func() { node := makeNode("lock-node", "10.0.0.40:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed()) foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model", nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.ID).To(Equal(node.ID)) Expect(foundNM.ModelName).To(Equal("my-model")) // Verify in-flight was incremented nm, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0) Expect(err).ToNot(HaveOccurred()) Expect(nm.InFlight).To(Equal(1)) }) It("returns error when model is not loaded 
anywhere", func() { _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "nonexistent-model", nil) Expect(err).To(HaveOccurred()) }) It("selects the node with fewer in-flight when multiple exist", func() { n1 := makeNode("lock-heavy", "10.0.0.41:50051", 8_000_000_000) n2 := makeNode("lock-light", "10.0.0.42:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n1.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n2.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed()) // Add in-flight to n1 Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed()) foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model", nil) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.Name).To(Equal("lock-light")) }) It("filters by candidateNodeIDs even when an excluded node has lower in_flight", func() { // Reproduces the selector-mismatch loop: a model loaded on a node // the selector now excludes (excluded) and on a node it includes // (included). Without the filter the excluded node wins on // in_flight ASC; with the filter the included node is returned // directly so Route() can serve from its existing replica. 
excluded := makeNode("excluded-node", "10.0.0.43:50051", 8_000_000_000) included := makeNode("included-node", "10.0.0.44:50051", 8_000_000_000) Expect(registry.Register(context.Background(), excluded, true)).To(Succeed()) Expect(registry.Register(context.Background(), included, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), excluded.ID, "filtered-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), included.ID, "filtered-model", 0, "loaded", "", 0)).To(Succeed()) // Make `included` strictly busier than `excluded` so the unfiltered // query would prefer the excluded one — proving the filter is // what's steering the result, not the in_flight ordering. Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), included.ID, "filtered-model", 0)).To(Succeed()) foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "filtered-model", []string{included.ID}) Expect(err).ToNot(HaveOccurred()) Expect(foundNode.ID).To(Equal(included.ID)) Expect(foundNM.NodeID).To(Equal(included.ID)) }) It("round-robins between replicas when in_flight ties (last_used tiebreaker)", func() { // Three replicas of the same model on three nodes, all with in_flight=0. // Without the last_used tiebreaker, the node with the largest available_vram // would win every pick and one node would take ~all the load. With it, // each successful pick refreshes last_used so the next pick rotates to // the oldest-used replica. 
fat := makeNode("rr-fat", "10.0.0.50:50051", 24_000_000_000) mid := makeNode("rr-mid", "10.0.0.51:50051", 16_000_000_000) small := makeNode("rr-small", "10.0.0.52:50051", 8_000_000_000) Expect(registry.Register(context.Background(), fat, true)).To(Succeed()) Expect(registry.Register(context.Background(), mid, true)).To(Succeed()) Expect(registry.Register(context.Background(), small, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), fat.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), mid.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), small.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) // Decrement back to 0 after each pick so the next call sees a tie. // (FindAndLockNodeWithModel atomically increments to lock the row.) picks := make([]string, 0, 9) for i := 0; i < 9; i++ { n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil) Expect(err).ToNot(HaveOccurred()) picks = append(picks, n.Name) Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed()) } // Each replica should have been picked at least twice across 9 ties — // proves we're rotating, not pinning to the largest-VRAM node. 
counts := map[string]int{} for _, p := range picks { counts[p]++ } Expect(counts["rr-fat"]).To(BeNumerically(">=", 2), "fat node was picked %d times across 9 ties: %v", counts["rr-fat"], picks) Expect(counts["rr-mid"]).To(BeNumerically(">=", 2), "mid node was picked %d times across 9 ties: %v", counts["rr-mid"], picks) Expect(counts["rr-small"]).To(BeNumerically(">=", 2), "small node was picked %d times across 9 ties: %v", counts["rr-small"], picks) }) It("returns not-found when the model is loaded only on excluded nodes", func() { loadedExcluded := makeNode("excl-only-node", "10.0.0.45:50051", 8_000_000_000) emptyIncluded := makeNode("empty-included-node", "10.0.0.46:50051", 8_000_000_000) Expect(registry.Register(context.Background(), loadedExcluded, true)).To(Succeed()) Expect(registry.Register(context.Background(), emptyIncluded, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), loadedExcluded.ID, "no-match-model", 0, "loaded", "", 0)).To(Succeed()) // Filter restricts to a node that does not have the model — the // query must return an error so Route() falls through to schedule // a fresh load on a matching node instead of reusing the excluded // replica. 
_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}) Expect(err).To(HaveOccurred()) }) }) Describe("MarkHealthy and MarkUnhealthy round-trip", func() { It("transitions healthy -> unhealthy -> healthy", func() { node := makeNode("roundtrip-node", "10.0.0.60:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(node.Status).To(Equal(StatusHealthy)) // Mark unhealthy Expect(registry.MarkUnhealthy(context.Background(), node.ID)).To(Succeed()) fetched, err := registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusUnhealthy)) // Mark healthy again Expect(registry.MarkHealthy(context.Background(), node.ID)).To(Succeed()) fetched, err = registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.Status).To(Equal(StatusHealthy)) }) It("returns error for non-existent node", func() { err := registry.MarkHealthy(context.Background(), "does-not-exist") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found")) }) }) Describe("NodeLabel CRUD", func() { It("sets and retrieves labels for a node", func() { node := makeNode("label-node", "10.0.0.70:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed()) labels, err := registry.GetNodeLabels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(labels).To(HaveLen(2)) labelMap := make(map[string]string) for _, l := range labels { labelMap[l.Key] = l.Value } Expect(labelMap["env"]).To(Equal("prod")) Expect(labelMap["region"]).To(Equal("us-east")) }) It("overwrites existing label with same key", func() { node := makeNode("label-overwrite", "10.0.0.71:50051", 
8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "dev")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) labels, err := registry.GetNodeLabels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(labels).To(HaveLen(1)) Expect(labels[0].Value).To(Equal("prod")) }) It("removes a single label by key", func() { node := makeNode("label-remove", "10.0.0.72:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "env", "prod")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "region", "us-east")).To(Succeed()) Expect(registry.RemoveNodeLabel(context.Background(), node.ID, "env")).To(Succeed()) labels, err := registry.GetNodeLabels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(labels).To(HaveLen(1)) Expect(labels[0].Key).To(Equal("region")) }) It("SetNodeLabels replaces all labels", func() { node := makeNode("label-replace", "10.0.0.73:50051", 8_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), node.ID, "old-key", "old-val")).To(Succeed()) newLabels := map[string]string{"new-a": "val-a", "new-b": "val-b"} Expect(registry.SetNodeLabels(context.Background(), node.ID, newLabels)).To(Succeed()) labels, err := registry.GetNodeLabels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(labels).To(HaveLen(2)) labelMap := make(map[string]string) for _, l := range labels { labelMap[l.Key] = l.Value } Expect(labelMap).To(Equal(newLabels)) }) }) Describe("FindNodesBySelector", func() { It("returns nodes matching all labels in selector", func() { n1 := makeNode("sel-match", "10.0.0.80:50051", 8_000_000_000) n2 := makeNode("sel-nomatch", 
"10.0.0.81:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n1.ID, "env", "prod")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n1.ID, "region", "us-east")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n2.ID, "env", "dev")).To(Succeed()) nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod", "region": "us-east"}) Expect(err).ToNot(HaveOccurred()) Expect(nodes).To(HaveLen(1)) Expect(nodes[0].Name).To(Equal("sel-match")) }) It("returns empty when no nodes match", func() { n := makeNode("sel-empty", "10.0.0.82:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "dev")).To(Succeed()) nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) Expect(err).ToNot(HaveOccurred()) Expect(nodes).To(BeEmpty()) }) It("ignores unhealthy nodes", func() { n := makeNode("sel-unhealthy", "10.0.0.83:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed()) Expect(registry.MarkUnhealthy(context.Background(), n.ID)).To(Succeed()) nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) Expect(err).ToNot(HaveOccurred()) Expect(nodes).To(BeEmpty()) }) It("matches nodes with more labels than selector requires", func() { n := makeNode("sel-superset", "10.0.0.84:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n, true)).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n.ID, "env", "prod")).To(Succeed()) Expect(registry.SetNodeLabel(context.Background(), n.ID, "region", "us-east")).To(Succeed()) 
Expect(registry.SetNodeLabel(context.Background(), n.ID, "tier", "gpu")).To(Succeed()) nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{"env": "prod"}) Expect(err).ToNot(HaveOccurred()) Expect(nodes).To(HaveLen(1)) Expect(nodes[0].Name).To(Equal("sel-superset")) }) It("returns all healthy nodes for empty selector", func() { n1 := makeNode("sel-all-1", "10.0.0.85:50051", 8_000_000_000) n2 := makeNode("sel-all-2", "10.0.0.86:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) nodes, err := registry.FindNodesBySelector(context.Background(), map[string]string{}) Expect(err).ToNot(HaveOccurred()) Expect(len(nodes)).To(BeNumerically(">=", 2)) }) }) Describe("ModelSchedulingConfig CRUD", func() { It("creates and retrieves a scheduling config", func() { config := &ModelSchedulingConfig{ ModelName: "llama-7b", NodeSelector: `{"gpu.vendor":"nvidia"}`, MinReplicas: 1, MaxReplicas: 3, } Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed()) Expect(config.ID).ToNot(BeEmpty()) fetched, err := registry.GetModelScheduling(context.Background(), "llama-7b") Expect(err).ToNot(HaveOccurred()) Expect(fetched).ToNot(BeNil()) Expect(fetched.ModelName).To(Equal("llama-7b")) Expect(fetched.NodeSelector).To(Equal(`{"gpu.vendor":"nvidia"}`)) Expect(fetched.MinReplicas).To(Equal(1)) Expect(fetched.MaxReplicas).To(Equal(3)) }) It("updates existing config via SetModelScheduling", func() { config := &ModelSchedulingConfig{ ModelName: "update-model", MinReplicas: 1, MaxReplicas: 2, } Expect(registry.SetModelScheduling(context.Background(), config)).To(Succeed()) config2 := &ModelSchedulingConfig{ ModelName: "update-model", MinReplicas: 2, MaxReplicas: 5, } Expect(registry.SetModelScheduling(context.Background(), config2)).To(Succeed()) fetched, err := registry.GetModelScheduling(context.Background(), "update-model") 
Expect(err).ToNot(HaveOccurred()) Expect(fetched.MinReplicas).To(Equal(2)) Expect(fetched.MaxReplicas).To(Equal(5)) }) It("lists all configs", func() { Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-a", MinReplicas: 1})).To(Succeed()) Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "list-b", MaxReplicas: 2})).To(Succeed()) configs, err := registry.ListModelSchedulings(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(len(configs)).To(BeNumerically(">=", 2)) }) It("lists only auto-scaling configs", func() { Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-a", MinReplicas: 2})).To(Succeed()) Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "auto-b", MaxReplicas: 3})).To(Succeed()) Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "no-auto", NodeSelector: `{"env":"prod"}`})).To(Succeed()) configs, err := registry.ListAutoScalingConfigs(context.Background()) Expect(err).ToNot(HaveOccurred()) names := make([]string, len(configs)) for i, c := range configs { names[i] = c.ModelName } Expect(names).To(ContainElement("auto-a")) Expect(names).To(ContainElement("auto-b")) Expect(names).ToNot(ContainElement("no-auto")) }) It("deletes a config", func() { Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{ModelName: "delete-me", MinReplicas: 1})).To(Succeed()) Expect(registry.DeleteModelScheduling(context.Background(), "delete-me")).To(Succeed()) fetched, err := registry.GetModelScheduling(context.Background(), "delete-me") Expect(err).ToNot(HaveOccurred()) Expect(fetched).To(BeNil()) }) It("returns nil for non-existent model", func() { fetched, err := registry.GetModelScheduling(context.Background(), "does-not-exist") Expect(err).ToNot(HaveOccurred()) Expect(fetched).To(BeNil()) }) }) Describe("CountLoadedReplicas", 
func() { It("returns correct count of loaded replicas", func() { n1 := makeNode("replica-node-1", "10.0.0.90:50051", 8_000_000_000) n2 := makeNode("replica-node-2", "10.0.0.91:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", 0, "loaded", "", 0)).To(Succeed()) count, err := registry.CountLoadedReplicas(context.Background(), "counted-model") Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(2))) }) It("excludes non-loaded states", func() { n1 := makeNode("replica-loaded", "10.0.0.92:50051", 8_000_000_000) n2 := makeNode("replica-loading", "10.0.0.93:50051", 8_000_000_000) Expect(registry.Register(context.Background(), n1, true)).To(Succeed()) Expect(registry.Register(context.Background(), n2, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", 0, "loading", "", 0)).To(Succeed()) count, err := registry.CountLoadedReplicas(context.Background(), "state-model") Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) }) }) Describe("DecrementInFlight", func() { It("does not go below zero", func() { node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model", 0, "loaded", "", 0)).To(Succeed()) // in_flight starts at 0 — decrement should be a no-op Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model", 0)).To(Succeed()) nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model", 0) Expect(err).ToNot(HaveOccurred()) 
Expect(nm.InFlight).To(Equal(0)) }) It("decrements correctly from a positive value", func() { node := makeNode("dec-node-2", "10.0.0.51:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model-2", 0, "loaded", "", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed()) nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0) Expect(err).ToNot(HaveOccurred()) Expect(nm.InFlight).To(Equal(2)) Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed()) nm, err = registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0) Expect(err).ToNot(HaveOccurred()) Expect(nm.InFlight).To(Equal(1)) }) }) Describe("Schema defaults", func() { // These tests pin the GORM defaults that the multi-replica refactor // relies on. If a future migration changes a default, the // reconciler/router will silently misbehave (e.g. capacity 0 instead // of 1) — these assertions catch that at the migration boundary. 
It("BackendNode.MaxReplicasPerModel defaults to 1", func() { node := makeNode("schema-default-mrpm", "10.0.0.200:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) fetched, err := registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.MaxReplicasPerModel).To(Equal(1), "old workers don't send the field; default must preserve single-replica behavior") }) It("BackendNode.ReservedVRAM defaults to 0", func() { node := makeNode("schema-default-reserved", "10.0.0.201:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) fetched, err := registry.Get(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(fetched.ReservedVRAM).To(Equal(uint64(0))) }) It("NodeModel.ReplicaIndex defaults to 0", func() { node := makeNode("schema-default-replica", "10.0.0.202:50051", 4_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "default-replica-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed()) nm, err := registry.GetNodeModel(context.Background(), node.ID, "default-replica-model", 0) Expect(err).ToNot(HaveOccurred()) Expect(nm).ToNot(BeNil()) Expect(nm.ReplicaIndex).To(Equal(0)) }) It("ModelSchedulingConfig.UnsatisfiableUntil is nullable and defaults to nil", func() { cfg := &ModelSchedulingConfig{ ModelName: "schema-default-unsat", MinReplicas: 1, } Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed()) fetched, err := registry.GetModelScheduling(context.Background(), "schema-default-unsat") Expect(err).ToNot(HaveOccurred()) Expect(fetched).ToNot(BeNil()) Expect(fetched.UnsatisfiableUntil).To(BeNil()) Expect(fetched.UnsatisfiableTicks).To(Equal(0)) }) }) Describe("Multi-replica registry", func() { // PR2 tests: SetNodeModel with distinct replica indexes creates distinct // rows; per-row mutations (Remove, Increment, 
Decrement, Touch) target // only their indexed row so siblings are not orphaned. It("SetNodeModel(replicaIndex=0) then SetNodeModel(replicaIndex=1) creates two distinct rows", func() { node := makeNode("multi-1", "10.0.0.210:50051", 16_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "127.0.0.1:50101", 0)).To(Succeed()) models, err := registry.GetNodeModels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(models).To(HaveLen(2)) byIdx := map[int]NodeModel{} for _, m := range models { byIdx[m.ReplicaIndex] = m } Expect(byIdx[0].Address).To(Equal("127.0.0.1:50100")) Expect(byIdx[1].Address).To(Equal("127.0.0.1:50101")) Expect(byIdx[0].ID).ToNot(Equal(byIdx[1].ID)) }) It("RemoveNodeModel(replicaIndex=0) leaves replica 1 intact", func() { node := makeNode("multi-2", "10.0.0.211:50051", 16_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 0, "loaded", "127.0.0.1:50110", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 1, "loaded", "127.0.0.1:50111", 0)).To(Succeed()) Expect(registry.RemoveNodeModel(context.Background(), node.ID, "kept-model", 0)).To(Succeed()) // Sibling replica must still exist — this was the latent bug pre-PR2: // the WHERE clause matched both rows and orphaned the healthy sibling. 
survivor, err := registry.GetNodeModel(context.Background(), node.ID, "kept-model", 1) Expect(err).ToNot(HaveOccurred()) Expect(survivor).ToNot(BeNil()) Expect(survivor.Address).To(Equal("127.0.0.1:50111")) // Replica 0 is gone _, err = registry.GetNodeModel(context.Background(), node.ID, "kept-model", 0) Expect(err).To(HaveOccurred()) }) It("RemoveAllNodeModelReplicas deletes every replica of the model on the node", func() { node := makeNode("multi-3", "10.0.0.212:50051", 16_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 0, "loaded", "a", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 1, "loaded", "b", 0)).To(Succeed()) Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "purge-model")).To(Succeed()) models, err := registry.GetNodeModels(context.Background(), node.ID) Expect(err).ToNot(HaveOccurred()) Expect(models).To(BeEmpty()) }) It("IncrementInFlight only updates the targeted replica row", func() { node := makeNode("multi-4", "10.0.0.213:50051", 16_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 0, "loaded", "a", 0)).To(Succeed()) Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 1, "loaded", "b", 0)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed()) Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed()) r0, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 0) Expect(err).ToNot(HaveOccurred()) Expect(r0.InFlight).To(Equal(0), "replica 0 must not have been incremented") r1, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 1) Expect(err).ToNot(HaveOccurred()) Expect(r1.InFlight).To(Equal(2)) }) 
It("CountReplicasOnNode returns the per-(node, model) row count", func() {
	ctx := context.Background()
	node := makeNode("multi-5", "10.0.0.214:50051", 16_000_000_000)
	Expect(registry.Register(ctx, node, true)).To(Succeed())
	Expect(registry.SetNodeModel(ctx, node.ID, "count-model", 0, "loaded", "a", 0)).To(Succeed())
	Expect(registry.SetNodeModel(ctx, node.ID, "count-model", 1, "loaded", "b", 0)).To(Succeed())

	n, err := registry.CountReplicasOnNode(ctx, node.ID, "count-model")
	Expect(err).ToNot(HaveOccurred())
	Expect(n).To(Equal(2))
})

It("NextFreeReplicaIndex returns the lowest unused index < maxSlots", func() {
	ctx := context.Background()
	node := makeNode("multi-6", "10.0.0.215:50051", 16_000_000_000)
	Expect(registry.Register(ctx, node, true)).To(Succeed())

	// Slot 0 is free initially.
	idx, err := registry.NextFreeReplicaIndex(ctx, node.ID, "slot-model", 4)
	Expect(err).ToNot(HaveOccurred())
	Expect(idx).To(Equal(0))

	// Occupy 0 and 2 — the next free index is 1 (the lowest gap).
	Expect(registry.SetNodeModel(ctx, node.ID, "slot-model", 0, "loaded", "a", 0)).To(Succeed())
	Expect(registry.SetNodeModel(ctx, node.ID, "slot-model", 2, "loaded", "c", 0)).To(Succeed())
	idx, err = registry.NextFreeReplicaIndex(ctx, node.ID, "slot-model", 4)
	Expect(err).ToNot(HaveOccurred())
	Expect(idx).To(Equal(1), "must allocate the lowest free index for compactness")

	// Fill all 4 slots — the call must return ErrNoFreeSlot.
	Expect(registry.SetNodeModel(ctx, node.ID, "slot-model", 1, "loaded", "b", 0)).To(Succeed())
	Expect(registry.SetNodeModel(ctx, node.ID, "slot-model", 3, "loaded", "d", 0)).To(Succeed())
	_, err = registry.NextFreeReplicaIndex(ctx, node.ID, "slot-model", 4)
	Expect(err).To(MatchError(ErrNoFreeSlot))

	// maxSlots=0 always returns ErrNoFreeSlot.
	_, err = registry.NextFreeReplicaIndex(ctx, node.ID, "no-slots-model", 0)
	Expect(err).To(MatchError(ErrNoFreeSlot))
})
})

Describe("ApplyAutoLabels", func() {
It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() {
	ctx := context.Background()
	node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000)
	node.MaxReplicasPerModel = 4
	node.GPUVendor = "nvidia"
	Expect(registry.Register(ctx, node, true)).To(Succeed())

	registry.ApplyAutoLabels(ctx, node.ID, node)

	// Check the lookup error explicitly so a failing query is reported as
	// such, not as a confusing "missing key" mismatch below.
	labels, err := registry.GetNodeLabels(ctx, node.ID)
	Expect(err).ToNot(HaveOccurred())
	byKey := map[string]string{}
	for _, l := range labels {
		byKey[l.Key] = l.Value
	}
	Expect(byKey).To(HaveKeyWithValue("node.replica-slots", "4"), "selectors targeting fat nodes need this auto-label")
	Expect(byKey).To(HaveKeyWithValue("gpu.vendor", "nvidia"))
})

It("defaults node.replica-slots to 1 when MaxReplicasPerModel is unset", func() {
	ctx := context.Background()
	node := makeNode("auto-label-default", "10.0.0.221:50051", 4_000_000_000)
	Expect(registry.Register(ctx, node, true)).To(Succeed())

	// Fetch the node back; the default should be 1 (PR1 schema test).
	// The error is checked so a failed lookup doesn't surface as a nil
	// dereference on fetched.
	fetched, err := registry.Get(ctx, node.ID)
	Expect(err).ToNot(HaveOccurred())
	Expect(fetched.MaxReplicasPerModel).To(Equal(1))

	registry.ApplyAutoLabels(ctx, node.ID, fetched)

	labels, err := registry.GetNodeLabels(ctx, node.ID)
	Expect(err).ToNot(HaveOccurred())
	byKey := map[string]string{}
	for _, l := range labels {
		byKey[l.Key] = l.Value
	}
	Expect(byKey).To(HaveKeyWithValue("node.replica-slots", "1"))
})
})

Describe("VRAM soft-reservation (PR5)", func() {
// These tests pin the soft-reservation contract: ReserveVRAM is the
// admission gate that prevents two concurrent scheduling decisions
// from over-committing the same node within one heartbeat window.
It("ReserveVRAM atomically deducts from effectively-free VRAM", func() { node := makeNode("reserve-1", "10.0.0.230:50051", 10_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.ReserveVRAM(context.Background(), node.ID, 3_000_000_000)).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000))) }) It("ReserveVRAM rejects when effectively-free VRAM is insufficient", func() { node := makeNode("reserve-2", "10.0.0.231:50051", 5_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) // First reservation fits. Expect(registry.ReserveVRAM(context.Background(), node.ID, 4_000_000_000)).To(Succeed()) // Second is too big — only 1 GB effectively free. err := registry.ReserveVRAM(context.Background(), node.ID, 2_000_000_000) Expect(err).To(MatchError(ErrInsufficientVRAM)) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.ReservedVRAM).To(Equal(uint64(4_000_000_000)), "failed reservation must not bump the column") }) It("ReserveVRAM with bytes=0 is a no-op", func() { node := makeNode("reserve-3", "10.0.0.232:50051", 1_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.ReserveVRAM(context.Background(), node.ID, 0)).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.ReservedVRAM).To(Equal(uint64(0))) }) It("ReleaseVRAM returns reserved bytes to the pool", func() { node := makeNode("release-1", "10.0.0.233:50051", 10_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.ReserveVRAM(context.Background(), node.ID, 4_000_000_000)).To(Succeed()) Expect(registry.ReleaseVRAM(context.Background(), node.ID, 1_000_000_000)).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000))) }) It("ReleaseVRAM 
cannot underflow past zero", func() { node := makeNode("release-underflow", "10.0.0.234:50051", 1_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) // No reservation; release is a guarded no-op rather than wrapping // uint64 to a huge positive number. Expect(registry.ReleaseVRAM(context.Background(), node.ID, 5_000_000_000)).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.ReservedVRAM).To(Equal(uint64(0))) }) It("Heartbeat with available_vram resets reserved_vram to 0", func() { node := makeNode("heartbeat-reset", "10.0.0.235:50051", 10_000_000_000) Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.ReserveVRAM(context.Background(), node.ID, 5_000_000_000)).To(Succeed()) fresh := uint64(8_000_000_000) Expect(registry.Heartbeat(context.Background(), node.ID, &HeartbeatUpdate{AvailableVRAM: &fresh})).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.AvailableVRAM).To(Equal(fresh), "heartbeat must overwrite available_vram with the worker's reading") Expect(fetched.ReservedVRAM).To(Equal(uint64(0)), "heartbeat must clear the soft reservation — worker is the source of truth") }) It("UpdateMaxReplicasPerModel marks the value as a sticky override", func() { // The original UX bug: workers default the flag to 1, so every // re-registration silently reverted the admin's UI value. This // test pins the fix. node := &BackendNode{ Name: "override-survives", NodeType: NodeTypeBackend, Address: "10.0.0.240:50051", MaxReplicasPerModel: 1, } Expect(registry.Register(context.Background(), node, true)).To(Succeed()) // Admin sets capacity to 4 via the UI. 
Expect(registry.UpdateMaxReplicasPerModel(context.Background(), node.ID, 4)).To(Succeed()) fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.MaxReplicasPerModel).To(Equal(4)) Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue()) // Worker re-registers with its default of 1 (operator never set the flag). restart := &BackendNode{ Name: "override-survives", NodeType: NodeTypeBackend, Address: "10.0.0.240:50051", MaxReplicasPerModel: 1, } Expect(registry.Register(context.Background(), restart, true)).To(Succeed()) // Override must have survived. fetched, _ = registry.Get(context.Background(), node.ID) Expect(fetched.MaxReplicasPerModel).To(Equal(4), "admin override must not be overwritten by worker re-registration") Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue()) }) It("ResetMaxReplicasPerModel hands control back to the worker", func() { node := &BackendNode{ Name: "override-reset", NodeType: NodeTypeBackend, Address: "10.0.0.241:50051", MaxReplicasPerModel: 1, } Expect(registry.Register(context.Background(), node, true)).To(Succeed()) Expect(registry.UpdateMaxReplicasPerModel(context.Background(), node.ID, 4)).To(Succeed()) Expect(registry.ResetMaxReplicasPerModel(context.Background(), node.ID)).To(Succeed()) // Reset only flips the flag; the value stays until the worker // re-registers (we don't presume to know what the worker wants). fetched, _ := registry.Get(context.Background(), node.ID) Expect(fetched.MaxReplicasPerModelManuallySet).To(BeFalse()) // Now worker re-registers with 8. 
restart := &BackendNode{ Name: "override-reset", NodeType: NodeTypeBackend, Address: "10.0.0.241:50051", MaxReplicasPerModel: 8, } Expect(registry.Register(context.Background(), restart, true)).To(Succeed()) fetched, _ = registry.Get(context.Background(), node.ID) Expect(fetched.MaxReplicasPerModel).To(Equal(8), "after reset, the worker's value should apply") }) It("FindNodeWithVRAM honors the reservation", func() { small := makeNode("find-vram-small", "10.0.0.236:50051", 5_000_000_000) big := makeNode("find-vram-big", "10.0.0.237:50051", 20_000_000_000) Expect(registry.Register(context.Background(), small, true)).To(Succeed()) Expect(registry.Register(context.Background(), big, true)).To(Succeed()) // Reserve almost all of the big node so its effective free // drops below the request — small isn't big enough either — // the call must return an error. Expect(registry.ReserveVRAM(context.Background(), big.ID, 18_000_000_000)).To(Succeed()) _, err := registry.FindNodeWithVRAM(context.Background(), 8_000_000_000) Expect(err).To(HaveOccurred(), "reserved capacity must remove a node from VRAM-aware candidates") }) }) })