mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 08:38:02 -04:00
fix(distributed): route per request across loaded replicas + cache probeHealth (#9968)
* refactor(distributed): extract PickBestReplica from FindAndLockNodeWithModel

Lifts the replica-selection policy (in_flight ASC, last_used ASC, available_vram DESC) out of the SQL ORDER BY into a pure Go function in the new replicapicker.go. The SQL clause keeps its FOR UPDATE atomicity and remains the production path used by SmartRouter; PickBestReplica is the canonical implementation that the future per-frontend rotating-replica cache (TODO referenced from pkg/model) will call against an in-memory snapshot without paying a DB round-trip per inference.

A new registry_test mirror spec seeds a multi-tier scenario and asserts both layers pick the same replica, so any future tweak to either side fails the test until the other side is updated.

No behavior change.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

* fix(distributed): route per inference request and cache probeHealth

Two related fixes that together restore load balancing across loaded replicas of the same model.

1. ModelLoader.Load and LoadModel bypass the local *Model cache when modelRouter is set. The cached *Model wraps an InFlightTrackingClient bound to a single (nodeID, replicaIndex). Reusing it pinned every subsequent request to whichever node won the very first pick, so FindAndLockNodeWithModel's round-robin never got a chance to run, even after the reconciler scaled the model out to a second node. In distributed mode SmartRouter.Route now runs per request, and PickBestReplica picks the least-loaded replica each time. SmartRouter has its own coalescing (advisory DB lock for first-time loads + singleflight on the backend.install RPC), so concurrent first requests for a not-yet-loaded model still produce a single worker-side install.

2. SmartRouter.probeHealth memoizes successful gRPC HealthCheck results in a new probeCache (probe_cache.go) with a 30s TTL. With per-request routing every inference call hits probeHealth, and llama.cpp-style backends serialize HealthCheck behind active Predict, so a burst of incoming requests stalled on the probe to a node already mid-stream, tripping the 2s timeout and falling through to the install path. singleflight collapses N concurrent first-time probes for the same (node, addr) into one round-trip, failed probes invalidate the entry so the staleness-recovery path still triggers, and the TTL matches pkg/model/model.go's healthCheckTTL so the single-process and distributed paths share a staleness budget. The background HealthMonitor still reaps actually-dead backends within ~45s.

The bypass introduces one short FindAndLockNodeWithModel transaction per inference. A TODO in pkg/model/loader.go documents the future per-modelID rotating-replica cache that would reuse PickBestReplica against an in-memory snapshot and skip the DB round-trip for hot paths.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-7 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
94
core/services/nodes/probe_cache.go
Normal file
@@ -0,0 +1,94 @@
+package nodes
+
+import (
+	"sync"
+	"time"
+
+	"golang.org/x/sync/singleflight"
+)
+
+// probeCache memoizes recent successful gRPC HealthCheck results for
+// (nodeID, addr) tuples so SmartRouter.probeHealth doesn't pay a round-trip
+// on every inference request.
+//
+// Why this exists: with per-request routing (see pkg/model/loader.go), every
+// inference call goes through SmartRouter.Route, which probes the backend
+// before returning a client. Many gRPC backends (notably llama.cpp's server)
+// serialize HealthCheck against active Predict on a shared goroutine, so a
+// burst of new requests can stall behind a single long-running stream —
+// exactly the "queue stalling" symptom observed in distributed clusters.
+//
+// The background HealthMonitor (perModelHealthCheck) is still the cluster-wide
+// source of truth that reaps actually-dead backends within ~45s; this cache
+// only saves the per-request hot path from re-asking when nothing has changed.
+//
+// TTL matches healthCheckTTL in pkg/model/model.go so the single-process
+// IsRecentlyHealthy path and this distributed-mode path share the same
+// staleness budget.
+type probeCache struct {
+	ttl    time.Duration
+	mu     sync.Mutex
+	seen   map[string]time.Time // key → last successful probe
+	flight singleflight.Group   // coalesces concurrent probes for the same key
+}
+
+// newProbeCache returns a probeCache with the given TTL. Zero TTL disables
+// caching: every call to DoOrCached invokes the probe.
+func newProbeCache(ttl time.Duration) *probeCache {
+	return &probeCache{
+		ttl:  ttl,
+		seen: make(map[string]time.Time),
+	}
+}
+
+// IsFresh reports whether key was successfully probed within TTL.
+func (c *probeCache) IsFresh(key string) bool {
+	if c.ttl <= 0 {
+		return false
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	last, ok := c.seen[key]
+	return ok && time.Since(last) < c.ttl
+}
+
+// markFresh records key as successfully probed at the current time.
+func (c *probeCache) markFresh(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.seen[key] = time.Now()
+}
+
+// Invalidate drops any cached freshness for key. Used after a probe failure
+// (or any other signal that the backend may not be alive) so the next call
+// will re-probe instead of trusting stale state.
+func (c *probeCache) Invalidate(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.seen, key)
+}
+
+// DoOrCached returns true if key is fresh; otherwise it runs probe (coalescing
+// concurrent callers via singleflight) and caches a successful result. Failed
+// probes are never cached: the entry is invalidated, so the next call
+// re-probes instead of trusting stale state.
+func (c *probeCache) DoOrCached(key string, probe func() bool) bool {
+	if c.IsFresh(key) {
+		return true
+	}
+	v, _, _ := c.flight.Do(key, func() (any, error) {
+		// Double-check: an earlier flight may have populated the cache
+		// between the IsFresh check above and this flight starting.
+		if c.IsFresh(key) {
+			return true, nil
+		}
+		ok := probe()
+		if ok {
+			c.markFresh(key)
+		} else {
+			c.Invalidate(key)
+		}
+		return ok, nil
+	})
+	return v.(bool)
+}
145
core/services/nodes/probe_cache_test.go
Normal file
@@ -0,0 +1,145 @@
+package nodes
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("probeCache", func() {
+	It("invokes the probe on a cold cache and caches success", func() {
+		c := newProbeCache(time.Minute)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		// Cached: probe ran once.
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)))
+	})
+
+	It("re-probes after the TTL expires", func() {
+		// 1 ms TTL means the second call is virtually guaranteed to see an
+		// expired entry without flaking on scheduler jitter.
+		c := newProbeCache(time.Millisecond)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		time.Sleep(5 * time.Millisecond)
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+
+	It("does not cache failed probes — next call re-probes", func() {
+		c := newProbeCache(time.Minute)
+		var calls int32
+		var result atomic.Bool
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return result.Load()
+		}
+
+		// First probe fails — must NOT be cached.
+		result.Store(false)
+		Expect(c.DoOrCached("k", probe)).To(BeFalse())
+		Expect(c.IsFresh("k")).To(BeFalse())
+
+		// Recover: second probe succeeds and is cached.
+		result.Store(true)
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.IsFresh("k")).To(BeTrue())
+
+		// Third call short-circuits on the fresh entry.
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+
+	It("coalesces concurrent probes via singleflight", func() {
+		// Models the "6 chat completions arrive simultaneously for a
+		// not-yet-cached backend" scenario. Without singleflight every caller
+		// would dial the backend, defeating the purpose of the cache.
+		c := newProbeCache(time.Minute)
+		var calls int32
+		start := make(chan struct{})
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			// Stall briefly so the test reliably has all goroutines parked
+			// inside flight.Do at the same time.
+			time.Sleep(50 * time.Millisecond)
+			return true
+		}
+
+		const N = 8
+		var wg sync.WaitGroup
+		results := make([]bool, N)
+		for i := 0; i < N; i++ {
+			wg.Add(1)
+			go func(i int) {
+				defer wg.Done()
+				<-start
+				results[i] = c.DoOrCached("k", probe)
+			}(i)
+		}
+
+		close(start)
+		wg.Wait()
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(1)),
+			"singleflight must collapse %d concurrent probes into one", N)
+		for i, got := range results {
+			Expect(got).To(BeTrue(), "goroutine %d saw a different result", i)
+		}
+	})
+
+	It("treats different keys independently", func() {
+		c := newProbeCache(time.Minute)
+		var aCalls, bCalls int32
+		Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
+		Expect(c.DoOrCached("b", func() bool { atomic.AddInt32(&bCalls, 1); return true })).To(BeTrue())
+		Expect(c.DoOrCached("a", func() bool { atomic.AddInt32(&aCalls, 1); return true })).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&aCalls)).To(Equal(int32(1)))
+		Expect(atomic.LoadInt32(&bCalls)).To(Equal(int32(1)))
+	})
+
+	It("disables caching when TTL is zero", func() {
+		c := newProbeCache(0)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(3)))
+	})
+
+	It("Invalidate forces the next call to re-probe", func() {
+		c := newProbeCache(time.Hour)
+		var calls int32
+		probe := func() bool {
+			atomic.AddInt32(&calls, 1)
+			return true
+		}
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		c.Invalidate("k")
+		Expect(c.DoOrCached("k", probe)).To(BeTrue())
+		Expect(atomic.LoadInt32(&calls)).To(Equal(int32(2)))
+	})
+})
@@ -668,10 +668,21 @@ func (r *NodeRegistry) FindNodesWithModel(ctx context.Context, modelName string)
 	return nodes, nil
 }
 
-// FindAndLockNodeWithModel atomically finds the least-loaded node with the given
-// model loaded and increments its in-flight counter within a single transaction.
-// The SELECT FOR UPDATE row lock prevents concurrent eviction from removing the
-// NodeModel row between the find and increment operations.
+// FindAndLockNodeWithModel atomically finds the best loaded replica of the
+// given model and increments its in-flight counter within a single
+// transaction. The SELECT FOR UPDATE row lock prevents concurrent eviction
+// from removing the NodeModel row between the find and increment operations,
+// and serializes contending routers so concurrent picks distribute across
+// replicas instead of all landing on the same row.
+//
+// **Policy:** the SQL ORDER BY below MUST mirror PickBestReplica
+// (replicapicker.go). PickBestReplica is the canonical Go implementation of
+// the same rule — the per-frontend rotating-replica cache (TODO, see
+// pkg/model/loader.go) will eventually use it against in-memory snapshots so
+// hot inference requests don't pay this DB round-trip. If you change the
+// ordering here, change both sides; the "agrees with PickBestReplica on a
+// seeded dataset (policy mirror)" spec in registry_test.go fails fast if they
+// drift.
 //
 // When candidateNodeIDs is non-empty, only nodes in that set are considered.
 // Pass nil (or empty) to consider any node. This lets callers pre-filter by
@@ -683,16 +694,16 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
 	var node BackendNode
 
 	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
-		// Order by in_flight ASC (least busy replica), then by last_used ASC
-		// (round-robin between equally-loaded replicas — oldest used wins, and
-		// every successful pick refreshes last_used below, so the "oldest" naturally
-		// rotates through the candidate set). available_vram DESC is the final
-		// tiebreaker for cold starts where last_used is identical.
+		// Mirror of PickBestReplica's policy (see replicapicker.go):
+		//   1. in_flight ASC — least busy replica.
+		//   2. last_used ASC — round-robin between equally-loaded replicas.
+		//      Every successful pick refreshes last_used below, so the
+		//      "oldest" tier naturally rotates through the candidate set.
+		//      Without this tier, in_flight ties collapsed to "fattest GPU
+		//      wins every time" and one node took nearly all the load.
+		//   3. available_vram DESC — final tiebreaker for cold starts where
+		//      last_used is identical across replicas.
 		//
-		// Without the last_used tier, a tie on in_flight (the common case at low
-		// to moderate concurrency where requests don't overlap) collapses to
-		// "biggest GPU wins every time" and one node ends up taking nearly all
-		// the load while replicas on other nodes sit idle.
 		// Filter on backend_nodes.status = healthy in the inner JOIN itself,
 		// not only in the later node-fetch step. The previous version picked
 		// a (node_id, replica) pair purely on node_models state, then bailed
@@ -3,6 +3,7 @@ package nodes
 import (
 	"context"
 	"runtime"
+	"time"
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -357,6 +358,79 @@ var _ = Describe("NodeRegistry", func() {
 			_, _, err := registry.FindAndLockNodeWithModel(context.Background(), "no-match-model", []string{emptyIncluded.ID})
 			Expect(err).To(HaveOccurred())
 		})
+
+		It("agrees with PickBestReplica on a seeded dataset (policy mirror)", func() {
+			// Guard against drift between the SQL ORDER BY in
+			// FindAndLockNodeWithModel and the canonical Go implementation in
+			// PickBestReplica. The two layers will eventually diverge in
+			// caller (DB-backed atomic pick vs in-memory snapshot pick for the
+			// per-frontend rotating cache), but the policy itself must stay
+			// the single source of truth. If this test fails, update *both*
+			// sides — never just one.
+			//
+			// Scenario exercises all three tiers:
+			//   - "loser-busy" has the most VRAM but in_flight=2 — loses tier 1.
+			//   - "loser-recent" ties at in_flight=0 but its last_used is the
+			//     newest of the in_flight=0 group — loses tier 2.
+			//   - "winner-mid" and "winner-fat" both tie at in_flight=0 and
+			//     share the oldest last_used — tier 3 decides: fattest wins.
+			loserBusy := makeNode("mirror-loser-busy", "10.0.0.70:50051", 32_000_000_000)
+			loserRecent := makeNode("mirror-loser-recent", "10.0.0.71:50051", 8_000_000_000)
+			winnerMid := makeNode("mirror-winner-mid", "10.0.0.72:50051", 16_000_000_000)
+			winnerFat := makeNode("mirror-winner-fat", "10.0.0.73:50051", 24_000_000_000)
+			for _, n := range []*BackendNode{loserBusy, loserRecent, winnerMid, winnerFat} {
+				Expect(registry.Register(context.Background(), n, true)).To(Succeed())
+				Expect(registry.SetNodeModel(context.Background(), n.ID, "mirror-model", 0, "loaded", "", 0)).To(Succeed())
+			}
+
+			// Force in_flight=2 on the "busy" node so tier 1 disqualifies it.
+			Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
+			Expect(registry.IncrementInFlight(context.Background(), loserBusy.ID, "mirror-model", 0)).To(Succeed())
+
+			// Slam last_used to known values so the test is deterministic
+			// regardless of clock resolution between the helpers above.
+			base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
+			set := func(id string, t time.Time) {
+				Expect(db.Model(&NodeModel{}).
+					Where("node_id = ? AND model_name = ?", id, "mirror-model").
+					Update("last_used", t).Error).To(Succeed())
+			}
+			set(loserBusy.ID, base) // value is irrelevant: tier 1 already disqualified this node
+			set(loserRecent.ID, base.Add(time.Hour))
+			set(winnerMid.ID, base)
+			set(winnerFat.ID, base)
+
+			// Pull the same dataset both pickers will operate on. The Go
+			// picker is a faithful representation of the policy; the SQL is
+			// the production path.
+			var rows []NodeModel
+			Expect(db.Where("model_name = ? AND state = ?", "mirror-model", "loaded").
+				Find(&rows).Error).To(Succeed())
+			candidates := make([]ReplicaCandidate, 0, len(rows))
+			for _, nm := range rows {
+				var bn BackendNode
+				Expect(db.First(&bn, "id = ? AND status = ?", nm.NodeID, StatusHealthy).Error).To(Succeed())
+				candidates = append(candidates, ReplicaCandidate{
+					NodeID:        nm.NodeID,
+					Address:       bn.Address,
+					ReplicaIndex:  nm.ReplicaIndex,
+					InFlight:      nm.InFlight,
+					LastUsed:      nm.LastUsed,
+					AvailableVRAM: bn.AvailableVRAM,
+				})
+			}
+			goPick := PickBestReplica(candidates)
+			Expect(goPick).ToNot(BeNil())
+
+			sqlNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "mirror-model", nil)
+			Expect(err).ToNot(HaveOccurred())
+
+			Expect(sqlNode.ID).To(Equal(goPick.NodeID),
+				"SQL ORDER BY picked %s; PickBestReplica picked %s — policy has drifted",
+				sqlNode.ID, goPick.NodeID)
+			// Sanity check: the policy says winner-fat wins on tier 3.
+			Expect(goPick.NodeID).To(Equal(winnerFat.ID))
+		})
 	})
 
 	Describe("MarkHealthy and MarkUnhealthy round-trip", func() {
69
core/services/nodes/replicapicker.go
Normal file
@@ -0,0 +1,69 @@
+package nodes
+
+import "time"
+
+// ReplicaCandidate is the minimum view of a loaded model replica needed to
+// apply the routing policy. It is intentionally decoupled from the gorm models
+// (BackendNode, NodeModel) so the same picker can run against fresh DB rows
+// (SmartRouter.Route → FindAndLockNodeWithModel) and against an in-memory
+// snapshot (the per-frontend rotating cache flagged in pkg/model — see TODO
+// below).
+type ReplicaCandidate struct {
+	NodeID        string
+	Address       string
+	ReplicaIndex  int
+	InFlight      int
+	LastUsed      time.Time
+	AvailableVRAM uint64
+}
+
+// PickBestReplica is the single source of truth for which loaded replica of a
+// model serves the next request.
+//
+// Policy (ordered tiers, first non-tie wins):
+//  1. Least in-flight wins — primary load-balancing signal.
+//  2. Oldest last_used wins — round-robin between equally-loaded replicas.
+//     Every successful pick refreshes last_used (in FindAndLockNodeWithModel's
+//     transaction and in TouchNodeModel on cache hits), so the "oldest" tier
+//     naturally rotates through the candidate set without a separate cursor.
+//  3. Largest available_vram wins — cold-start tiebreaker for replicas that
+//     have never been picked (identical last_used).
+//
+// Two callers must agree on this policy:
+//
+//   - SmartRouter.Route, via the SQL ORDER BY in FindAndLockNodeWithModel
+//     (registry.go). That query MUST mirror this function; the "policy mirror"
+//     spec in registry_test.go asserts both sides agree on a seeded dataset.
+//
+//   - The per-frontend rotating-replica cache (NOT YET IMPLEMENTED — see
+//     pkg/model/loader.go and pkg/model/initializers.go for the integration
+//     point). When that cache lands, it will call PickBestReplica against an
+//     in-memory snapshot using locally-tracked in-flight counters and skip the
+//     per-request DB round-trip.
+//
+// Returns nil when the candidate list is empty. Does not allocate.
+func PickBestReplica(candidates []ReplicaCandidate) *ReplicaCandidate {
+	if len(candidates) == 0 {
+		return nil
+	}
+	best := &candidates[0]
+	for i := 1; i < len(candidates); i++ {
+		c := &candidates[i]
+		if betterReplica(c, best) {
+			best = c
+		}
+	}
+	return best
+}
+
+// betterReplica reports whether candidate a is preferred over candidate b
+// under the policy documented on PickBestReplica.
+func betterReplica(a, b *ReplicaCandidate) bool {
+	if a.InFlight != b.InFlight {
+		return a.InFlight < b.InFlight
+	}
+	if !a.LastUsed.Equal(b.LastUsed) {
+		return a.LastUsed.Before(b.LastUsed)
+	}
+	return a.AvailableVRAM > b.AvailableVRAM
+}
81
core/services/nodes/replicapicker_test.go
Normal file
@@ -0,0 +1,81 @@
+package nodes
+
+import (
+	"time"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("PickBestReplica", func() {
+	// Use a single reference time so every test that wants identical
+	// last_used can share it without relying on time.Now() interleavings.
+	ref := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
+
+	It("returns nil for an empty candidate list", func() {
+		Expect(PickBestReplica(nil)).To(BeNil())
+		Expect(PickBestReplica([]ReplicaCandidate{})).To(BeNil())
+	})
+
+	It("returns the only candidate when there is just one", func() {
+		only := ReplicaCandidate{NodeID: "only", InFlight: 99, LastUsed: ref, AvailableVRAM: 1}
+		pick := PickBestReplica([]ReplicaCandidate{only})
+		Expect(pick).ToNot(BeNil())
+		Expect(pick.NodeID).To(Equal("only"))
+	})
+
+	It("prefers the replica with the lowest in_flight", func() {
+		// Without the in-flight tier, the larger-VRAM node would win.
+		cs := []ReplicaCandidate{
+			{NodeID: "busy-big", InFlight: 3, LastUsed: ref, AvailableVRAM: 24_000_000_000},
+			{NodeID: "idle-small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+			{NodeID: "mid", InFlight: 1, LastUsed: ref, AvailableVRAM: 16_000_000_000},
+		}
+		Expect(PickBestReplica(cs).NodeID).To(Equal("idle-small"))
+	})
+
+	It("uses oldest last_used as the tiebreaker when in_flight ties", func() {
+		// All three tied on in_flight=0. Without last_used, available_vram
+		// would pin every pick to the fattest node: the exact bug that
+		// "fix(distributed): round-robin replicas of the same model" addressed.
+		cs := []ReplicaCandidate{
+			{NodeID: "fat-recent", InFlight: 0, LastUsed: ref.Add(2 * time.Second), AvailableVRAM: 24_000_000_000},
+			{NodeID: "small-oldest", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+			{NodeID: "mid-middle", InFlight: 0, LastUsed: ref.Add(1 * time.Second), AvailableVRAM: 16_000_000_000},
+		}
+		Expect(PickBestReplica(cs).NodeID).To(Equal("small-oldest"))
+	})
+
+	It("uses largest available_vram as the final tiebreaker", func() {
+		// in_flight tied AND last_used tied — pick the largest GPU.
+		cs := []ReplicaCandidate{
+			{NodeID: "small", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+			{NodeID: "fat", InFlight: 0, LastUsed: ref, AvailableVRAM: 24_000_000_000},
+			{NodeID: "mid", InFlight: 0, LastUsed: ref, AvailableVRAM: 16_000_000_000},
+		}
+		Expect(PickBestReplica(cs).NodeID).To(Equal("fat"))
+	})
+
+	It("respects tier precedence: in_flight beats last_used beats available_vram", func() {
+		// "fat-busy-oldest" would win tiers 2 and 3 but loses tier 1; the
+		// "small-idle-recent" replica has in_flight=0 and should beat it
+		// despite being more recently used and smaller.
+		cs := []ReplicaCandidate{
+			{NodeID: "fat-busy-oldest", InFlight: 5, LastUsed: ref, AvailableVRAM: 80_000_000_000},
+			{NodeID: "small-idle-recent", InFlight: 0, LastUsed: ref.Add(time.Hour), AvailableVRAM: 4_000_000_000},
+		}
+		Expect(PickBestReplica(cs).NodeID).To(Equal("small-idle-recent"))
+	})
+
+	It("is stable: returns the first candidate when every field ties", func() {
+		// betterReplica returns false on a full tie, so the leading element
+		// remains best. Callers shouldn't depend on this for correctness,
+		// but pinning the behavior here catches accidental reorderings.
+		cs := []ReplicaCandidate{
+			{NodeID: "first", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+			{NodeID: "second", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+			{NodeID: "third", InFlight: 0, LastUsed: ref, AvailableVRAM: 8_000_000_000},
+		}
+		Expect(PickBestReplica(cs).NodeID).To(Equal("first"))
+	})
+})
@@ -61,8 +61,19 @@ type SmartRouter struct {
 	// completions for one not-yet-loaded model produce ONE round-trip, not
 	// six. Avoids amplifying head-of-line blocking on the worker side.
 	installFlight singleflight.Group
+	// probeCache memoizes recent successful gRPC HealthCheck results so
+	// per-request routing doesn't stall behind a busy backend's serialized
+	// HealthCheck/Predict. See probe_cache.go for the rationale.
+	probeCache *probeCache
 }
 
+// probeCacheTTL is how long a successful gRPC HealthCheck on a backend is
+// trusted before the next request re-probes. Matches healthCheckTTL in
+// pkg/model/model.go so the single-process and distributed paths share a
+// staleness budget. The background HealthMonitor still reaps dead backends
+// independently within ~45s (see perModelMissThreshold).
+const probeCacheTTL = 30 * time.Second
+
 // NewSmartRouter creates a new SmartRouter backed by the given ModelRouter.
 // All optional dependencies are passed via SmartRouterOptions to avoid post-creation races.
 func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter {
@@ -79,6 +90,7 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter
 		db:               opts.DB,
 		stagingTracker:   NewStagingTracker(),
 		conflictResolver: opts.ConflictResolver,
+		probeCache:       newProbeCache(probeCacheTTL),
 	}
 }
 
@@ -961,14 +973,26 @@ func (r *SmartRouter) stageGenericOptions(ctx context.Context, node *BackendNode
 }
 
 // probeHealth checks whether a backend process on the given node/addr is alive
-// via a gRPC health check with a 2-second timeout. The client is closed after the check.
+// via a gRPC health check with a 2-second timeout. The client is closed after
+// the check.
+//
+// A success is memoized in r.probeCache for probeCacheTTL. With per-request
+// routing every inference call lands here, and unbounded re-probing can stall
+// behind a busy backend that serializes HealthCheck against active Predict.
+// Concurrent probes for the same (node, addr) coalesce via singleflight so a
+// burst of N requests for a cold cache costs at most one round-trip, not N.
+// Failed probes invalidate the cache so the staleness recovery path
+// (DecrementInFlight + RemoveNodeModel) still triggers on the next request.
 func (r *SmartRouter) probeHealth(ctx context.Context, node *BackendNode, addr string) bool {
-	client := r.buildClientForAddr(node, addr, false)
-	defer closeClient(client)
-	checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
-	defer cancel()
-	ok, _ := client.HealthCheck(checkCtx)
-	return ok
+	key := node.ID + "|" + addr
+	return r.probeCache.DoOrCached(key, func() bool {
+		client := r.buildClientForAddr(node, addr, false)
+		defer closeClient(client)
+		checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
+		defer cancel()
+		ok, _ := client.HealthCheck(checkCtx)
+		return ok
+	})
 }
 
 // closeClient closes a gRPC backend client if it implements io.Closer.
@@ -276,6 +276,37 @@ func (ml *ModelLoader) updateModelLastUsed(m *Model) {
 func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
 	o := NewOptions(opts...)
 
+	ml.mu.Lock()
+	distributed := ml.modelRouter != nil
+	ml.mu.Unlock()
+
+	// In distributed mode, SmartRouter must run per inference request so
+	// PickBestReplica (core/services/nodes/replicapicker.go) picks the
+	// least-loaded replica each time. Bypass the local cache and the local
+	// LRU / concurrency-group watchdog enforcement: both are scoped to the
+	// in-process Model store, which in distributed mode only holds stubs for
+	// remote replicas. SmartRouter handles cluster-wide eviction
+	// (evictLRUAndFreeNode) and concurrency-group anti-affinity
+	// (narrowByGroupAntiAffinity) at the scheduler layer.
+	//
+	// TODO(distributed-cache): see LoadModel for the rotating-replica-cache
+	// integration point that would let hot paths skip the per-request DB
+	// round-trip without giving up the shared PickBestReplica policy.
+	if distributed {
+		client, err := ml.backendLoader(opts...)
+		if err != nil {
+			return nil, err
+		}
+		if m := ml.CheckIsLoaded(o.modelID); m != nil && m.Process() == nil {
+			client = newConnectionEvictingClient(client, o.modelID, func() {
+				if err := ml.ShutdownModel(o.modelID); err != nil {
+					xlog.Warn("Failed to shut down remote model after connection error", "model", o.modelID, "error", err)
+				}
+			})
+		}
+		return client, nil
+	}
+
 	// Return earlier if we have a model already loaded
 	// (avoid looping through all the backends)
 	if m := ml.CheckIsLoaded(o.modelID); m != nil {
@@ -250,6 +250,49 @@ func (ml *ModelLoader) ListLoadedModels() []*Model {
 }
 
 func (ml *ModelLoader) LoadModel(modelID, modelName string, loader func(string, string, string) (*Model, error)) (*Model, error) {
+	ml.mu.Lock()
+	distributed := ml.modelRouter != nil
+	ml.mu.Unlock()
+
+	if distributed {
+		// Distributed mode: SmartRouter must run per inference request so
+		// PickBestReplica (core/services/nodes/replicapicker.go) picks the
+		// least-loaded replica each time. The cached *Model returned from a
+		// previous call holds a client wrapper bound to one (nodeID,
+		// replicaIndex), so reusing it pins every subsequent request to the
+		// node that won the very first pick — defeating per-replica load
+		// balancing. Bypass the cache and the loading-coalesce map; the
+		// router does its own coalescing for first-time loads (advisory DB
+		// lock + singleflight on backend.install RPC), so concurrent first
+		// requests still produce a single worker-side install.
+		//
+		// TODO(distributed-cache): if profiling shows the per-request
+		// FindAndLockNodeWithModel SELECT FOR UPDATE becomes a hot path
+		// under burst load, replace this branch with a per-modelID cache
+		// that holds a *list* of replicas (refreshed every ~5s in
+		// background) and picks per call via PickBestReplica against
+		// locally-tracked in-flight counters. Same policy, no DB round-trip
+		// per inference. Trade-off: cross-frontend in-flight visibility
+		// becomes eventually consistent, acceptable for 1-3 frontend
+		// deployments.
+		modelFile := filepath.Join(ml.ModelPath, modelName)
+		model, err := loader(modelID, modelName, modelFile)
+		if err != nil {
+			return nil, fmt.Errorf("failed to route model with internal loader: %s", err)
+		}
+		if model == nil {
+			return nil, fmt.Errorf("loader didn't return a model")
+		}
+		// Record the latest mapping so DistributedModelStore.Range, shutdown,
+		// and listing endpoints see a representative entry. The DB is the
+		// source of truth for cluster-wide state; the local store is just a
+		// stub for in-process callers.
+		ml.mu.Lock()
+		ml.store.Set(modelID, model)
+		ml.mu.Unlock()
+		return model, nil
+	}
+
 	ml.mu.Lock()
 
 	// Check if we already have a loaded model