From 8bbe89a537e47e78b5f06abe7f64fac9554d2aaf Mon Sep 17 00:00:00 2001 From: "LocalAI [bot]" <139863280+localai-bot@users.noreply.github.com> Date: Sun, 24 May 2026 10:15:27 +0200 Subject: [PATCH] fix(distributed): route per request across loaded replicas + cache probeHealth (#9968) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(distributed): extract PickBestReplica from FindAndLockNodeWithModel Lifts the replica-selection policy (in_flight ASC, last_used ASC, available_vram DESC) out of the SQL ORDER BY into a pure Go function in the new replicapicker.go. The SQL clause keeps its FOR UPDATE atomicity and remains the production path used by SmartRouter; PickBestReplica is the canonical implementation that the future per-frontend rotating replica cache (TODO referenced from pkg/model) will call against an in-memory snapshot without paying a DB round-trip per inference. A new registry_test mirror spec seeds a multi-tier scenario and asserts both layers pick the same replica, so any future tweak to either side fails the test until the other side is updated. No behavior change. Signed-off-by: Ettore Di Giacinto Assisted-by: Claude:claude-opus-4-7 [Claude Code] * fix(distributed): route per inference request and cache probeHealth Two related fixes that together restore load balancing across loaded replicas of the same model. 1. ModelLoader.Load and LoadModel bypass the local *Model cache when modelRouter is set. The cached *Model wraps an InFlightTrackingClient bound to a single (nodeID, replicaIndex) — reusing it pinned every subsequent request to whichever node won the very first pick, so FindAndLockNodeWithModel's round-robin never got a chance to run even after the reconciler scaled the model out to a second node. In distributed mode SmartRouter.Route now runs per request, and PickBestReplica picks the least-loaded replica each time. 
SmartRouter has its own coalescing (advisory DB lock for first-time loads + singleflight on backend.install RPC) so concurrent first requests for a not-yet-loaded model still produce a single worker-side install. 2. SmartRouter.probeHealth memoizes successful gRPC HealthCheck results in a new probeCache (probe_cache.go) with a 30s TTL. With per-request routing every inference call hits probeHealth, and llama.cpp-style backends serialize HealthCheck behind active Predict — so a burst of incoming requests stalled on the probe to a node already mid-stream, tripping the 2s timeout and falling through to the install path. singleflight collapses N concurrent first-time probes for the same (node, addr) into one round-trip, failed probes invalidate the entry so the staleness-recovery path still triggers, and the TTL matches pkg/model/model.go's healthCheckTTL so the single-process and distributed paths share a staleness budget. The background HealthMonitor still reaps actually-dead backends within ~45s. The bypass introduces one short FindAndLockNodeWithModel transaction per inference. A TODO in pkg/model/loader.go documents the future per-modelID rotating-replica cache that would reuse PickBestReplica against an in-memory snapshot and skip the DB round-trip for hot paths.
Signed-off-by: Ettore Di Giacinto Assisted-by: Claude:claude-opus-4-7 [Claude Code] --------- Signed-off-by: Ettore Di Giacinto Co-authored-by: Ettore Di Giacinto --- core/services/nodes/probe_cache.go | 94 ++++++++++++++ core/services/nodes/probe_cache_test.go | 145 ++++++++++++++++++++++ core/services/nodes/registry.go | 37 ++++-- core/services/nodes/registry_test.go | 74 +++++++++++ core/services/nodes/replicapicker.go | 69 ++++++++++ core/services/nodes/replicapicker_test.go | 81 ++++++++++++ core/services/nodes/router.go | 38 ++++-- pkg/model/initializers.go | 31 +++++ pkg/model/loader.go | 43 +++++++ 9 files changed, 592 insertions(+), 20 deletions(-) create mode 100644 core/services/nodes/probe_cache.go create mode 100644 core/services/nodes/probe_cache_test.go create mode 100644 core/services/nodes/replicapicker.go create mode 100644 core/services/nodes/replicapicker_test.go diff --git a/core/services/nodes/probe_cache.go b/core/services/nodes/probe_cache.go new file mode 100644 index 000000000..422e36ede --- /dev/null +++ b/core/services/nodes/probe_cache.go @@ -0,0 +1,94 @@ +package nodes + +import ( + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +// probeCache memoizes recent successful gRPC HealthCheck results for +// (nodeID, addr) tuples so SmartRouter.probeHealth doesn't pay a round-trip +// on every inference request. +// +// Why this exists: with per-request routing (see pkg/model/loader.go), every +// inference call goes through SmartRouter.Route, which probes the backend +// before returning a client. Many gRPC backends (notably llama.cpp's server) +// serialize HealthCheck against active Predict on a shared goroutine, so a +// burst of new requests can stall behind a single long-running stream — +// exactly the "queue stalling" symptom observed in distributed clusters. 
+// +// The background HealthMonitor (perModelHealthCheck) is still the cluster-wide +// source of truth that reaps actually-dead backends within ~45s; this cache +// only saves the per-request hot path from re-asking when nothing has changed. +// +// TTL matches healthCheckTTL in pkg/model/model.go so the single-process +// IsRecentlyHealthy path and this distributed-mode path share the same +// staleness budget. +type probeCache struct { + ttl time.Duration + mu sync.Mutex + seen map[string]time.Time // key → last successful probe + flight singleflight.Group // coalesces concurrent probes for the same key +} + +// newProbeCache returns a probeCache with the given TTL. Zero TTL disables +// caching: every call to DoOrCached invokes the probe. +func newProbeCache(ttl time.Duration) *probeCache { + return &probeCache{ + ttl: ttl, + seen: make(map[string]time.Time), + } +} + +// IsFresh reports whether key was successfully probed within TTL. +func (c *probeCache) IsFresh(key string) bool { + if c.ttl <= 0 { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + last, ok := c.seen[key] + return ok && time.Since(last) < c.ttl +} + +// markFresh records key as successfully probed at the current time. +func (c *probeCache) markFresh(key string) { + c.mu.Lock() + defer c.mu.Unlock() + c.seen[key] = time.Now() +} + +// Invalidate drops any cached freshness for key. Used after a probe failure +// (or any other signal that the backend may not be alive) so the next call +// will re-probe instead of trusting stale state. +func (c *probeCache) Invalidate(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.seen, key) +} + +// DoOrCached returns true if key is fresh; otherwise it runs probe (coalescing +// concurrent callers via singleflight) and caches a successful result. Failed +// probes are never cached and drop any existing entry for key, so the next +// call re-probes instead of trusting stale state.
+func (c *probeCache) DoOrCached(key string, probe func() bool) bool { + if c.IsFresh(key) { + return true + } + v, _, _ := c.flight.Do(key, func() (any, error) { + // Double-check after potentially waiting: another caller in this + // flight may have just populated the cache. + if c.IsFresh(key) { + return true, nil + } + ok := probe() + if ok { + c.markFresh(key) + } else { + c.Invalidate(key) + } + return ok, nil + }) + return v.(bool) +} diff --git a/core/services/nodes/probe_cache_test.go b/core/services/nodes/probe_cache_test.go new file mode 100644 index 000000000..58e6fa111 --- /dev/null +++ b/core/services/nodes/probe_cache_test.go @@ -0,0 +1,145 @@ +package nodes + +import ( + "sync" + "sync/atomic" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("probeCache", func() { + It("invokes the probe on a cold cache and caches success", func() { + c := newProbeCache(time.Minute) + var calls int32 + probe := func() bool { + atomic.AddInt32(&calls, 1) + return true + } + + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + + // Cached: probe ran once. + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1))) + }) + + It("re-probes after the TTL expires", func() { + // 1 ms TTL means the second call is virtually guaranteed to see an + // expired entry without flaking on scheduler jitter. 
+ c := newProbeCache(time.Millisecond) + var calls int32 + probe := func() bool { + atomic.AddInt32(&calls, 1) + return true + } + + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + time.Sleep(5 * time.Millisecond) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2))) + }) + + It("does not cache failed probes — next call re-probes", func() { + c := newProbeCache(time.Minute) + var calls int32 + var result atomic.Bool + probe := func() bool { + atomic.AddInt32(&calls, 1) + return result.Load() + } + + // First probe fails — must NOT be cached. + result.Store(false) + Expect(c.DoOrCached("k", probe)).To(BeFalse()) + Expect(c.IsFresh("k")).To(BeFalse()) + + // Recover: second probe succeeds and is cached. + result.Store(true) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(c.IsFresh("k")).To(BeTrue()) + + // Third call short-circuits on the fresh entry. + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2))) + }) + + It("coalesces concurrent probes via singleflight", func() { + // Models the "6 chat completions arrive simultaneously for a + // not-yet-cached backend" scenario. Without singleflight every caller + // would dial the backend, defeating the purpose of the cache. + c := newProbeCache(time.Minute) + var calls int32 + start := make(chan struct{}) + probe := func() bool { + atomic.AddInt32(&calls, 1) + // Stall briefly so the test reliably has all goroutines parked + // inside flight.Do at the same time. 
+ time.Sleep(50 * time.Millisecond) + return true + } + + const N = 8 + var wg sync.WaitGroup + results := make([]bool, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + <-start + results[i] = c.DoOrCached("k", probe) + }(i) + } + + close(start) + wg.Wait() + + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)), + "singleflight must collapse %d concurrent probes into one", N) + for i, got := range results { + Expect(got).To(BeTrue(), "goroutine %d saw a different result", i) + } + }) + + It("treats different keys independently", func() { + c := newProbeCache(time.Minute) + var aCalls, bCalls int32 + Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue()) + Expect(c.DoOrCached("b", func() bool { atomic.AddInt32(&bCalls, 1); return true })).To(BeTrue()) + Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue()) + + Expect(atomic.LoadInt32(&aCalls)).To(Equal(int32(1))) + Expect(atomic.LoadInt32(&bCalls)).To(Equal(int32(1))) + }) + + It("disables caching when TTL is zero", func() { + c := newProbeCache(0) + var calls int32 + probe := func() bool { + atomic.AddInt32(&calls, 1) + return true + } + + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(3))) + }) + + It("Invalidate forces the next call to re-probe", func() { + c := newProbeCache(time.Hour) + var calls int32 + probe := func() bool { + atomic.AddInt32(&calls, 1) + return true + } + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + c.Invalidate("k") + Expect(c.DoOrCached("k", probe)).To(BeTrue()) + Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2))) + }) +}) diff --git a/core/services/nodes/registry.go b/core/services/nodes/registry.go index ed742d599..5dcb48a5c 100644 --- a/core/services/nodes/registry.go +++ b/core/services/nodes/registry.go @@ -668,10 
+668,21 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string) return nodes, nil } -// FindAndLockNodeWithModel atomically finds the least-loaded node with the given -// model loaded and increments its in-flight counter within a single transaction. -// The SELECT FOR UPDATE row lock prevents concurrent eviction from removing the -// NodeModel row between the find and increment operations. +// FindAndLockNodeWithModel atomically finds the best loaded replica of the +// given model and increments its in-flight counter within a single +// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction +// from removing the NodeModel row between the find and increment operations, +// and serializes contending routers so concurrent picks distribute across +// replicas instead of all landing on the same row. +// +// **Policy:** the SQL ORDER BY below MUST mirror PickBestReplica +// (replicapicker.go). PickBestReplica is the canonical Go implementation of +// the same rule — the per-frontend rotating-replica cache (TODO, see +// pkg/model/loader.go) will eventually use it against in-memory snapshots so +// hot inference requests don't pay this DB round-trip. If you change the +// ordering here, change both sides; the registry_test.go spec "agrees with +// PickBestReplica on a seeded dataset (policy mirror)" fails fast if they +// drift. // // When candidateNodeIDs is non-empty, only nodes in that set are considered. // Pass nil (or empty) to consider any node.
This lets callers pre-filter by @@ -683,16 +694,16 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s var node BackendNode err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - // Order by in_flight ASC (least busy replica), then by last_used ASC - // (round-robin between equally-loaded replicas — oldest used wins, and - // every successful pick refreshes last_used below, so the "oldest" naturally - // rotates through the candidate set). available_vram DESC is the final - // tiebreaker for cold starts where last_used is identical. + // Mirror of PickBestReplica's policy (see replicapicker.go): + // 1. in_flight ASC — least busy replica. + // 2. last_used ASC — round-robin between equally-loaded replicas. + // Every successful pick refreshes last_used below, so the + // "oldest" tier naturally rotates through the candidate set. + // Without this tier, in_flight ties collapsed to "fattest GPU + // wins every time" and one node took nearly all the load. + // 3. available_vram DESC — final tiebreaker for cold starts where + // last_used is identical across replicas. // - // Without the last_used tier, a tie on in_flight (the common case at low - // to moderate concurrency where requests don't overlap) collapses to - // "biggest GPU wins every time" and one node ends up taking nearly all - // the load while replicas on other nodes sit idle. // Filter on backend_nodes.status = healthy in the inner JOIN itself, // not only in the later node-fetch step. The previous version picked // a (node_id, replica) pair purely on node_models state, then bailed diff --git a/core/services/nodes/registry_test.go b/core/services/nodes/registry_test.go index 7f45362b5..f57ca194e 100644 --- a/core/services/nodes/registry_test.go +++ b/core/services/nodes/registry_test.go @@ -3,6 +3,7 @@ package nodes import ( "context" "runtime" + "time" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -357,6 +358,79 @@ var _ = Describe("NodeRegistry", func() { _, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID}) Expect(err).To(HaveOccurred()) }) + + It("agrees with PickBestReplica on a seeded dataset (policy mirror)", func() { + // Guard against drift between the SQL ORDER BY in + // FindAndLockNodeWithModel and the canonical Go implementation in + // PickBestReplica. The two layers will eventually diverge in + // caller (DB-backed atomic pick vs in-memory snapshot pick for the + // per-frontend rotating cache), but the policy itself must stay + // the single source of truth. If this test fails, update *both* + // sides — never just one. + // + // Scenario exercises all three tiers: + // - "loser-busy" has the most VRAM but in_flight=2 — loses tier 1. + // - "loser-recent" ties at in_flight=0 but its last_used is the + // newest of the in_flight=0 group — loses tier 2. + // - "winner-mid" and "winner-fat" both tie at in_flight=0 and + // share the oldest last_used — tier 3 decides: fattest wins. + loserBusy := makeNode("mirror-loser-busy", "10.0.0.70:50051", 32_000_000_000) + loserRecent := makeNode("mirror-loser-recent", "10.0.0.71:50051", 8_000_000_000) + winnerMid := makeNode("mirror-winner-mid", "10.0.0.72:50051", 16_000_000_000) + winnerFat := makeNode("mirror-winner-fat", "10.0.0.73:50051", 24_000_000_000) + for _, n := range []*BackendNode{loserBusy, loserRecent, winnerMid, winnerFat} { + Expect(registry.Register(context.Background(), n, true)).To(Succeed()) + Expect(registry.SetNodeModel(context.Background(), n.ID, "mirror-model", 0, "loaded", "", 0)).To(Succeed()) + } + + // Force in_flight=2 on the "busy" node so tier 1 disqualifies it. 
+ Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed()) + Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed()) + + // Slam last_used to known values so the test is deterministic + // regardless of clock resolution between the helpers above. + base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + set := func(id string, t time.Time) { + Expect(db.Model(&NodeModel{}). + Where("node_id = ? AND model_name = ?", id, "mirror-model"). + Update("last_used", t).Error).To(Succeed()) + } + set(loserBusy.ID, base) // newest doesn't matter — already disqualified by tier 1 + set(loserRecent.ID, base.Add(time.Hour)) + set(winnerMid.ID, base) + set(winnerFat.ID, base) + + // Pull the same dataset both pickers will operate on. The Go + // picker is a faithful representation of the policy; the SQL is + // the production path. + var rows []NodeModel + Expect(db.Where("model_name = ? AND state = ?", "mirror-model", "loaded"). + Find(&rows).Error).To(Succeed()) + candidates := make([]ReplicaCandidate, 0, len(rows)) + for _, nm := range rows { + var bn BackendNode + Expect(db.First(&bn, "id = ? AND status = ?", nm.NodeID, StatusHealthy).Error).To(Succeed()) + candidates = append(candidates, ReplicaCandidate{ + NodeID: nm.NodeID, + Address: bn.Address, + ReplicaIndex: nm.ReplicaIndex, + InFlight: nm.InFlight, + LastUsed: nm.LastUsed, + AvailableVRAM: bn.AvailableVRAM, + }) + } + goPick := PickBestReplica(candidates) + Expect(goPick).ToNot(BeNil()) + + sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil) + Expect(err).ToNot(HaveOccurred()) + + Expect(sqlNode.ID).To(Equal(goPick.NodeID), + "SQL ORDER BY picked %s; PickBestReplica picked %s — policy has drifted", + sqlNode.ID, goPick.NodeID) + // Sanity check: the policy says winner-fat wins on tier 3. 
+ Expect(goPick.NodeID).To(Equal(winnerFat.ID)) + }) }) Describe("MarkHealthy and MarkUnhealthy round-trip", func() { diff --git a/core/services/nodes/replicapicker.go b/core/services/nodes/replicapicker.go new file mode 100644 index 000000000..56d383e61 --- /dev/null +++ b/core/services/nodes/replicapicker.go @@ -0,0 +1,69 @@ +package nodes + +import "time" + +// ReplicaCandidate is the minimum view of a loaded model replica needed to +// apply the routing policy. It is intentionally decoupled from the gorm models +// (BackendNode, NodeModel) so the same picker can run against fresh DB rows +// (SmartRouter.Route → FindAndLockNodeWithModel) and against an in-memory +// snapshot (the per-frontend rotating cache flagged in pkg/model — see TODO +// below). +type ReplicaCandidate struct { + NodeID string + Address string + ReplicaIndex int + InFlight int + LastUsed time.Time + AvailableVRAM uint64 +} + +// PickBestReplica is the single source of truth for which loaded replica of a +// model serves the next request. +// +// Policy (ordered tiers, first non-tie wins): +// 1. Least in-flight wins — primary load-balancing signal. +// 2. Oldest last_used wins — round-robin between equally-loaded replicas. +// Every successful pick refreshes last_used (in FindAndLockNodeWithModel's +// transaction and in TouchNodeModel on cache hits), so the "oldest" tier +// naturally rotates through the candidate set without a separate cursor. +// 3. Largest available_vram wins — cold-start tiebreaker for replicas that +// have never been picked (identical last_used). +// +// Two callers must agree on this policy: +// +// - SmartRouter.Route, via the SQL ORDER BY in FindAndLockNodeWithModel +// (registry.go). That query MUST mirror this function — the "policy mirror" +// spec in registry_test.go asserts both sides agree on a seeded dataset. +// +// - The per-frontend rotating-replica cache (NOT YET IMPLEMENTED — see +// pkg/model/loader.go and pkg/model/initializers.go for the integration +// point).
When that cache lands, it will call PickBestReplica against an +// in-memory snapshot using locally-tracked in-flight counters and skip the +// per-request DB round-trip. +// +// Returns nil when the candidate list is empty. Does not allocate. +func PickBestReplica(candidates []ReplicaCandidate) *ReplicaCandidate { + if len(candidates) == 0 { + return nil + } + best := &candidates[0] + for i := 1; i < len(candidates); i++ { + c := &candidates[i] + if betterReplica(c, best) { + best = c + } + } + return best +} + +// betterReplica reports whether candidate a is preferred over candidate b +// under the policy documented on PickBestReplica. +func betterReplica(a, b *ReplicaCandidate) bool { + if a.InFlight != b.InFlight { + return a.InFlight < b.InFlight + } + if !a.LastUsed.Equal(b.LastUsed) { + return a.LastUsed.Before(b.LastUsed) + } + return a.AvailableVRAM > b.AvailableVRAM +} diff --git a/core/services/nodes/replicapicker_test.go b/core/services/nodes/replicapicker_test.go new file mode 100644 index 000000000..d71b83808 --- /dev/null +++ b/core/services/nodes/replicapicker_test.go @@ -0,0 +1,81 @@ +package nodes + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("PickBestReplica", func() { + // Use a single reference time so every test that wants identical + // last_used can share it without relying on time.Now() interleavings. 
+ ref := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + + It("returns nil for an empty candidate list", func() { + Expect(PickBestReplica(nil)).To(BeNil()) + Expect(PickBestReplica([]ReplicaCandidate{})).To(BeNil()) + }) + + It("returns the only candidate when there is just one", func() { + only := ReplicaCandidate{NodeID: "only", InFlight: 99, LastUsed: ref, AvailableVRAM: 1} + pick := PickBestReplica([]ReplicaCandidate{only}) + Expect(pick).ToNot(BeNil()) + Expect(pick.NodeID).To(Equal("only")) + }) + + It("prefers the replica with the lowest in_flight", func() { + // Without the in-flight tier, the larger-VRAM node would win. + cs := []ReplicaCandidate{ + {NodeID: "busy-big", InFlight: 3, LastUsed: ref, AvailableVRAM: 24_000_000_000}, + {NodeID: "idle-small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + {NodeID: "mid", InFlight: 1, LastUsed: ref, AvailableVRAM: 16_000_000_000}, + } + Expect(PickBestReplica(cs).NodeID).To(Equal("idle-small")) + }) + + It("uses oldest last_used as the tiebreaker when in_flight ties", func() { + // All three tied on in_flight=0. Without last_used, available_vram + // would pin every pick to the fattest node — the exact bug + // fix(distributed): round-robin replicas of the same model addressed. + cs := []ReplicaCandidate{ + {NodeID: "fat-recent", InFlight: 0, LastUsed: ref.Add(2 * time.Second), AvailableVRAM: 24_000_000_000}, + {NodeID: "small-oldest", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + {NodeID: "mid-middle", InFlight: 0, LastUsed: ref.Add(1 * time.Second), AvailableVRAM: 16_000_000_000}, + } + Expect(PickBestReplica(cs).NodeID).To(Equal("small-oldest")) + }) + + It("uses largest available_vram as the final tiebreaker", func() { + // in_flight tied AND last_used tied — pick the largest GPU. 
+ cs := []ReplicaCandidate{ + {NodeID: "small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + {NodeID: "fat", InFlight: 0, LastUsed: ref, AvailableVRAM: 24_000_000_000}, + {NodeID: "mid", InFlight: 0, LastUsed: ref, AvailableVRAM: 16_000_000_000}, + } + Expect(PickBestReplica(cs).NodeID).To(Equal("fat")) + }) + + It("respects tier precedence: in_flight beats last_used beats available_vram", func() { + // "fat-busy-oldest" wins tiers 2 and 3 (oldest last_used, most VRAM) + // but loses tier 1; the "small-idle-recent" replica has in_flight=0 and + // must beat it despite being newer and smaller. + cs := []ReplicaCandidate{ + {NodeID: "fat-busy-oldest", InFlight: 5, LastUsed: ref, AvailableVRAM: 80_000_000_000}, + {NodeID: "small-idle-recent", InFlight: 0, LastUsed: ref.Add(time.Hour), AvailableVRAM: 4_000_000_000}, + } + Expect(PickBestReplica(cs).NodeID).To(Equal("small-idle-recent")) + }) + + It("is stable: returns the first candidate when every field ties", func() { + // betterReplica returns false on a full tie, so the leading element + // remains best. Callers shouldn't depend on this for correctness, + // but pinning the behavior here catches accidental reorderings. + cs := []ReplicaCandidate{ + {NodeID: "first", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + {NodeID: "second", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + {NodeID: "third", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000}, + } + Expect(PickBestReplica(cs).NodeID).To(Equal("first")) + }) +}) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index ca95b1653..c29108846 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -61,8 +61,19 @@ type SmartRouter struct { // completions for one not-yet-loaded model produce ONE round-trip, not // six. Avoids amplifying head-of-line blocking on the worker side.
installFlight singleflight.Group + // probeCache memoizes recent successful gRPC HealthCheck results so + // per-request routing doesn't stall behind a busy backend's serialized + // HealthCheck/Predict. See probe_cache.go for the rationale. + probeCache *probeCache } +// probeCacheTTL is how long a successful gRPC HealthCheck on a backend is +// trusted before the next request re-probes. Matches healthCheckTTL in +// pkg/model/model.go so the single-process and distributed paths share a +// staleness budget. The background HealthMonitor still reaps dead backends +// independently within ~45s (see perModelMissThreshold). +const probeCacheTTL = 30 * time.Second + // NewSmartRouter creates a new SmartRouter backed by the given ModelRouter. // All optional dependencies are passed via SmartRouterOptions to avoid post-creation races. func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter { @@ -79,6 +90,7 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter db: opts.DB, stagingTracker: NewStagingTracker(), conflictResolver: opts.ConflictResolver, + probeCache: newProbeCache(probeCacheTTL), } } @@ -961,14 +973,26 @@ func (r *SmartRouter) stageGenericOptions(ctx context.Context, node *BackendNode } // probeHealth checks whether a backend process on the given node/addr is alive -// via a gRPC health check with a 2-second timeout. The client is closed after the check. +// via a gRPC health check with a 2-second timeout. The client is closed after +// the check. +// +// The result is memoized in r.probeCache for probeCacheTTL. With per-request +// routing every inference call lands here, and unbounded re-probing can stall +// behind a busy backend that serializes HealthCheck against active Predict. +// Concurrent probes for the same (node, addr) coalesce via singleflight so a +// burst of N requests for a cold cache costs at most one round-trip, not N. 
+// Failed probes invalidate the cache so the staleness recovery path +// (DecrementInFlight + RemoveNodeModel) still triggers on the next request. func (r *SmartRouter) probeHealth(ctx context.Context, node *BackendNode, addr string) bool { - client := r.buildClientForAddr(node, addr, false) - defer closeClient(client) - checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() - ok, _ := client.HealthCheck(checkCtx) - return ok + key := node.ID + "|" + addr + return r.probeCache.DoOrCached(key, func() bool { + client := r.buildClientForAddr(node, addr, false) + defer closeClient(client) + checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + ok, _ := client.HealthCheck(checkCtx) + return ok + }) } // closeClient closes a gRPC backend client if it implements io.Closer. diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 6d291961b..d7719ca13 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -276,6 +276,37 @@ func (ml *ModelLoader) updateModelLastUsed(m *Model) { func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) + ml.mu.Lock() + distributed := ml.modelRouter != nil + ml.mu.Unlock() + + // In distributed mode, SmartRouter must run per inference request so + // PickBestReplica (core/services/nodes/replicapicker.go) picks the + // least-loaded replica each time. Bypass the local cache and the local + // LRU / concurrency-group watchdog enforcement: both are scoped to the + // in-process Model store, which in distributed mode only holds stubs for + // remote replicas. SmartRouter handles cluster-wide eviction + // (evictLRUAndFreeNode) and concurrency-group anti-affinity + // (narrowByGroupAntiAffinity) at the scheduler layer. 
+ // + // TODO(distributed-cache): see LoadModel for the rotating-replica-cache + // integration point that would let hot paths skip the per-request DB + // round-trip without giving up the shared PickBestReplica policy. + if distributed { + client, err := ml.backendLoader(opts...) + if err != nil { + return nil, err + } + if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil { + client = newConnectionEvictingClient(client, o.modelID, func() { + if err := ml.ShutdownModel(o.modelID); err != nil { + xlog.Warn("Failed to shut down remote model after connection error", "model", o.modelID, "error", err) + } + }) + } + return client, nil + } + // Return earlier if we have a model already loaded // (avoid looping through all the backends) if m := ml.CheckIsLoaded(o.modelID); m != nil { diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 75d477dc3..7947dfc06 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -250,6 +250,49 @@ func (ml *ModelLoader) ListLoadedModels() []*Model { } func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, string, string) (*Model, error)) (*Model, error) { + ml.mu.Lock() + distributed := ml.modelRouter != nil + ml.mu.Unlock() + + if distributed { + // Distributed mode: SmartRouter must run per inference request so + // PickBestReplica (core/services/nodes/replicapicker.go) picks the + // least-loaded replica each time. The cached *Model returned from a + // previous call holds a client wrapper bound to one (nodeID, + // replicaIndex), so reusing it pins every subsequent request to the + // node that won the very first pick — defeating per-replica load + // balancing. Bypass the cache and the loading-coalesce map; the + // router does its own coalescing for first-time loads (advisory DB + // lock + singleflight on backend.install RPC), so concurrent first + // requests still produce a single worker-side install. 
+ // + // TODO(distributed-cache): if profiling shows the per-request + // FindAndLockNodeWithModel SELECT FOR UPDATE becomes a hot path + // under burst load, replace this branch with a per-modelID cache + // that holds a *list* of replicas (refreshed every ~5s in + // background) and picks per call via PickBestReplica against + // locally-tracked in-flight counters. Same policy, no DB round-trip + // per inference. Trade-off: cross-frontend in-flight visibility + // becomes eventually consistent, acceptable for 1-3 frontend + // deployments. + modelFile := filepath.Join(ml.ModelPath, modelName) + model, err := loader(modelID, modelName, modelFile) + if err != nil { + return nil, fmt.Errorf("failed to route model with internal loader: %s", err) + } + if model == nil { + return nil, fmt.Errorf("loader didn't return a model") + } + // Record the latest mapping so DistributedModelStore.Range, shutdown, + // and listing endpoints see a representative entry. The DB is the + // source of truth for cluster-wide state; the local store is just a + // stub for in-process callers. + ml.mu.Lock() + ml.store.Set(modelID, model) + ml.mu.Unlock() + return model, nil + } + ml.mu.Lock() // Check if we already have a loaded model