diff --git a/core/services/nodes/registry.go b/core/services/nodes/registry.go index 3d4111217..6f73c84b3 100644 --- a/core/services/nodes/registry.go +++ b/core/services/nodes/registry.go @@ -663,8 +663,16 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s var node BackendNode err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - // Order by in_flight ASC (least busy replica), then by available_vram DESC - // (prefer nodes with more free VRAM to spread load across the cluster). + // Order by in_flight ASC (least busy replica), then by last_used ASC + // (round-robin between equally-loaded replicas — oldest used wins, and + // every successful pick refreshes last_used below, so the "oldest" naturally + // rotates through the candidate set). available_vram DESC is the final + // tiebreaker for cold starts where last_used is identical. + // + // Without the last_used tier, a tie on in_flight (the common case at low + // to moderate concurrency where requests don't overlap) collapses to + // "biggest GPU wins every time" and one node ends up taking nearly all + // the load while replicas on other nodes sit idle. q := tx.Clauses(clause.Locking{Strength: "UPDATE"}). Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id"). Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded") @@ -672,7 +680,7 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s q = q.Where("node_models.node_id IN ?", candidateNodeIDs) } if err := q. - Order("node_models.in_flight ASC, backend_nodes.available_vram DESC"). + Order("node_models.in_flight ASC, node_models.last_used ASC, backend_nodes.available_vram DESC"). First(&nm).Error; err != nil { return err } diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go index dce310ff0..7f45362b5 100644 --- a/core/services/nodes/registry_test.go +++ b/core/services/nodes/registry_test.go @@ -304,6 +304,44 @@ var _ = Describe("NodeRegistry", func() { Expect(foundNM.NodeID).To(Equal(included.ID)) }) + It("round-robins between replicas when in_flight ties (last_used tiebreaker)", func() { + // Three replicas of the same model on three nodes, all with in_flight=0. + // Without the last_used tiebreaker, the node with the largest available_vram + // would win every pick and one node would take ~all the load. With it, + // each successful pick refreshes last_used so the next pick rotates to + // the oldest-used replica. + fat := makeNode("rr-fat", "10.0.0.50:50051", 24_000_000_000) + mid := makeNode("rr-mid", "10.0.0.51:50051", 16_000_000_000) + small := makeNode("rr-small", "10.0.0.52:50051", 8_000_000_000) + Expect(registry.Register(context.Background(), fat, true)).To(Succeed()) + Expect(registry.Register(context.Background(), mid, true)).To(Succeed()) + Expect(registry.Register(context.Background(), small, true)).To(Succeed()) + + Expect(registry.SetNodeModel(context.Background(), fat.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), mid.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), small.ID, "rr-model", 0, "loaded", "", 0)).To(Succeed()) + + // Decrement back to 0 after each pick so the next call sees a tie. + // (FindAndLockNodeWithModel atomically increments to lock the row.) + picks := make([]string, 0, 9) + for i := 0; i < 9; i++ { + n, nm, err := registry.FindAndLockNodeWithModel(context.Background(), "rr-model", nil) + Expect(err).ToNot(HaveOccurred()) + picks = append(picks, n.Name) + Expect(registry.DecrementInFlight(context.Background(), n.ID, "rr-model", nm.ReplicaIndex)).To(Succeed()) + } + + // Each replica should have been picked at least twice across 9 ties — + // proves we're rotating, not pinning to the largest-VRAM node. + counts := map[string]int{} + for _, p := range picks { + counts[p]++ + } + Expect(counts["rr-fat"]).To(BeNumerically(">=", 2), "fat node was picked %d times across 9 ties: %v", counts["rr-fat"], picks) + Expect(counts["rr-mid"]).To(BeNumerically(">=", 2), "mid node was picked %d times across 9 ties: %v", counts["rr-mid"], picks) + Expect(counts["rr-small"]).To(BeNumerically(">=", 2), "small node was picked %d times across 9 ties: %v", counts["rr-small"], picks) + }) + It("returns not-found when the model is loaded only on excluded nodes", func() { loadedExcluded := makeNode("excl-only-node", "10.0.0.45:50051", 8_000_000_000) emptyIncluded := makeNode("empty-included-node", "10.0.0.46:50051", 8_000_000_000)