Files
LocalAI/core/services/nodes/router_test.go
LocalAI [bot] 29001a88c1 fix(distributed): don't let a dead worker pin the model-load advisory lock (#10600)
* fix(distributed): don't let a dead worker pin the model-load advisory lock

In distributed mode a chat request could fail with:

  failed to route model with internal loader: routing model ...:
  loading model ...: advisorylock: acquiring lock <id>:
  ERROR: canceling statement due to lock timeout (SQLSTATE 55P03)

The root cause was two independent defects in the cross-replica model-load path:

1. SmartRouter.Route holds a per-model PostgreSQL advisory lock for the whole
   cold-load sequence, which includes installBackendOnNode -> InstallBackend,
   a NATS request-reply with a 15m deadline (DefaultBackendInstallTimeout) that
   ignored ctx. When the chosen worker died mid-install, the holder sat on the
   lock for up to 15m. The detached loadCtx (WithoutCancel) had no deadline, so
   nothing capped the hold.

2. The acquiring statement, pg_advisory_lock(), is subject to any
   deployment-wide lock_timeout. A common operator setting (e.g. 10s) aborts
   the wait with SQLSTATE 55P03, so every other replica's request for that
   model hard-errored instead of waiting for the in-progress load and reusing
   it. For that whole ~15m window the model was effectively unroutable.

Fixes:

- advisorylock.WithLockCtx (postgres): SET lock_timeout = 0 on its dedicated
  connection (RESET before it returns to the pool) so the Go context, not a
  deployment-wide GUC, governs how long we wait. Waiters now block and then
  re-check, reusing the model another replica just loaded.

- SmartRouter: bound the detached loadCtx with a single ModelLoadCeiling so the
  lock is always released in bounded time even if a sub-step wedges. Default is
  the configured backend.install deadline + 10m (staging + LoadModel margin),
  so a legitimately slow load is never cut.

- installBackendOnNode: use singleflight.DoChan + select on ctx.Done() so the
  install wait honors cancellation; the ceiling can then actually free a caller
  pinned behind a dead worker. The shared install still coalesces via
  singleflight.

Reproduced both defects as failing tests first (a real 55P03 against a
testcontainer with a short lock_timeout; a wedged install that blocks Route)
and confirmed green.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

* fix(distributed): bound advisory-lock wait instead of disabling lock_timeout

Setting lock_timeout = 0 to override a deployment's short global lock_timeout
meant "wait forever" server-side. Safe for SmartRouter.Route (its loadCtx now
carries the model-load ceiling) but unsafe for the schema-migration callers
that pass context.Background(): a holder whose session never releases would
hang them indefinitely.

Derive the server-side lock_timeout from the caller's context instead: its
remaining budget plus a margin (so the Go context's cancellation still wins
with a clean error and the server bound is only a backstop), or a finite
30m backstop when the context has no deadline. Never zero - "wait forever"
is no longer possible, while a deployment's hostile short lock_timeout is
still overridden so legitimate cross-replica waits don't fail with 55P03.

Added a spec proving a deadline-less waiter gives up at the (shrunk) backstop
rather than hanging.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-07-02 09:52:51 +02:00

1466 lines
52 KiB
Go

package nodes
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
"github.com/mudler/LocalAI/core/services/testutil"
"github.com/mudler/LocalAI/pkg/distributedhdr"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
"gorm.io/gorm"
)
// ---------------------------------------------------------------------------
// Fake FileStager (pre-existing)
// ---------------------------------------------------------------------------
// fakeFileStager is a FileStager test double that never touches the
// filesystem or the network. EnsureRemote records its arguments and maps
// every key to the deterministic path "/remote/<key>"; every other method is
// a successful no-op with a fixed result.
type fakeFileStager struct {
	// ensureCalls holds one entry per EnsureRemote call, oldest first.
	ensureCalls []ensureCall
}

// ensureCall is the argument tuple captured for one EnsureRemote call.
type ensureCall struct {
	nodeID, localPath, key string
}

// EnsureRemote logs the call and reports the file as staged at
// "/remote/" + key. It never fails.
func (f *fakeFileStager) EnsureRemote(_ context.Context, nodeID, localPath, key string) (string, error) {
	f.ensureCalls = append(f.ensureCalls, ensureCall{
		nodeID:    nodeID,
		localPath: localPath,
		key:       key,
	})
	remotePath := "/remote/" + key
	return remotePath, nil
}

// FetchRemote is a no-op that always succeeds.
func (f *fakeFileStager) FetchRemote(_ context.Context, _, _, _ string) error {
	return nil
}

// FetchRemoteByKey is a no-op that always succeeds.
func (f *fakeFileStager) FetchRemoteByKey(_ context.Context, _, _, _ string) error {
	return nil
}

// AllocRemoteTemp always hands out the same fixed temp path.
func (f *fakeFileStager) AllocRemoteTemp(_ context.Context, _ string) (string, error) {
	const tmpPath = "/remote/tmp"
	return tmpPath, nil
}

// StageRemoteToStore is a no-op that always succeeds.
func (f *fakeFileStager) StageRemoteToStore(_ context.Context, _, _, _ string) error {
	return nil
}

// ListRemoteDir always reports an empty (nil) listing and no error.
func (f *fakeFileStager) ListRemoteDir(_ context.Context, _, _ string) ([]string, error) {
	return nil, nil
}
// ---------------------------------------------------------------------------
// Fake ModelRouter
// ---------------------------------------------------------------------------
// fakeModelRouter is a ModelRouter test double. Each lookup returns the
// canned value/error pair stored in its matching fields, and each mutating
// call appends a colon-joined record to a per-method string slice so specs
// can assert on what the router did. The slices are plain, unsynchronised
// fields: read them only after the router calls have returned.
type fakeModelRouter struct {
	// Canned answer for FindAndLockNodeWithModel.
	findAndLockNode *BackendNode
	findAndLockNM   *NodeModel
	findAndLockErr  error
	// Canned answer for FindNodeWithVRAM.
	findVRAMNode *BackendNode
	findVRAMErr  error
	// Canned answer for FindIdleNode.
	findIdleNode *BackendNode
	findIdleErr  error
	// Canned answer for FindLeastLoadedNode.
	findLeastLoadedNode *BackendNode
	findLeastLoadedErr  error
	// Canned answer for FindGlobalLRUModelWithZeroInFlight.
	findGlobalLRUModel *NodeModel
	findGlobalLRUErr   error
	// Canned answer for FindLRUModel.
	findLRUModel *NodeModel
	findLRUErr   error
	// Canned answer for Get.
	getNode *BackendNode
	getErr  error
	// Canned answer for GetModelScheduling.
	getModelScheduling *ModelSchedulingConfig
	getModelSchedErr   error
	// Canned answer for FindNodesBySelector (and FindNodesWithFreeSlot).
	findBySelectorNodes []BackendNode
	findBySelectorErr   error
	// Canned answers for the *FromSet lookups.
	findVRAMFromSetNode        *BackendNode
	findVRAMFromSetErr         error
	findIdleFromSetNode        *BackendNode
	findIdleFromSetErr         error
	findLeastLoadedFromSetNode *BackendNode
	findLeastLoadedFromSetErr  error
	// Canned answer for GetNodeLabels.
	getNodeLabels    []NodeLabel
	getNodeLabelsErr error
	// Canned answer for FindNodesWithModel, looked up by model name.
	findNodesWithModelByName map[string][]BackendNode
	findNodesWithModelErr    error
	// Canned answer for LoadedReplicaStats, looked up by model name.
	loadedReplicaStatsByName map[string][]ReplicaCandidate
	loadedReplicaStatsErr    error
	// Call logs. Entries are "nodeID:modelName", except setCalls which is
	// "nodeID:modelName:state:address".
	decrementCalls []string
	incrementCalls []string
	removeCalls    []string
	setCalls       []string
	touchCalls     []string
	// Every pref handed to FindAndLockNodeWithModel, in call order. nil
	// prefs are logged as well so a spec can assert "no preference given".
	findAndLockPrefs []*RoutePreference
}

// logCall appends the "nodeID:modelName" record to the log pointed to by dst.
func (f *fakeModelRouter) logCall(dst *[]string, nodeID, modelName string) {
	*dst = append(*dst, nodeID+":"+modelName)
}

// FindAndLockNodeWithModel logs pref and returns the canned triple; the model
// name and candidate set are ignored.
func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
	f.findAndLockPrefs = append(f.findAndLockPrefs, pref)
	node, nm, err := f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
	return node, nm, err
}

// LoadedReplicaStats returns the configured error if set, otherwise the
// candidates stored for modelName (nil when absent).
func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) {
	if err := f.loadedReplicaStatsErr; err != nil {
		return nil, err
	}
	candidates := f.loadedReplicaStatsByName[modelName]
	return candidates, nil
}

// DecrementInFlight logs the call and succeeds.
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
	f.logCall(&f.decrementCalls, nodeID, modelName)
	return nil
}

// IncrementInFlight logs the call and succeeds.
func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
	f.logCall(&f.incrementCalls, nodeID, modelName)
	return nil
}

// RemoveNodeModel logs the call and succeeds.
func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
	f.logCall(&f.removeCalls, nodeID, modelName)
	return nil
}

// RemoveAllNodeModelReplicas logs into removeCalls with the same record shape
// as RemoveNodeModel, so "model was removed" assertions hold no matter which
// of the two variants the production code chose.
func (f *fakeModelRouter) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
	f.logCall(&f.removeCalls, nodeID, modelName)
	return nil
}

// TouchNodeModel logs the call.
func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string, _ int) {
	f.logCall(&f.touchCalls, nodeID, modelName)
}

// SetNodeModel logs "nodeID:modelName:state:address" and succeeds.
func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName string, _ int, state, address string, _ int) error {
	record := nodeID + ":" + modelName + ":" + state + ":" + address
	f.setCalls = append(f.setCalls, record)
	return nil
}

// SetNodeModelLoadInfo is a no-op that always succeeds.
func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
	return nil
}

// UpsertModelLoadInfo is a no-op that always succeeds.
func (f *fakeModelRouter) UpsertModelLoadInfo(_ context.Context, _, _ string, _ []byte) error {
	return nil
}

// GetModelLoadInfo never has stored load info and always fails with
// "not found".
func (f *fakeModelRouter) GetModelLoadInfo(_ context.Context, _ string) (string, []byte, error) {
	notFound := fmt.Errorf("not found")
	return "", nil, notFound
}

// NextFreeReplicaIndex always reports replica slot 0.
func (f *fakeModelRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
	return 0, nil
}

// CountReplicasOnNode always reports zero replicas.
func (f *fakeModelRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
	return 0, nil
}

// FindNodeWithVRAM returns the canned VRAM answer regardless of size.
func (f *fakeModelRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
	node, err := f.findVRAMNode, f.findVRAMErr
	return node, err
}

// FindIdleNode returns the canned idle-node answer.
func (f *fakeModelRouter) FindIdleNode(_ context.Context) (*BackendNode, error) {
	node, err := f.findIdleNode, f.findIdleErr
	return node, err
}

// FindLeastLoadedNode returns the canned least-loaded answer.
func (f *fakeModelRouter) FindLeastLoadedNode(_ context.Context) (*BackendNode, error) {
	node, err := f.findLeastLoadedNode, f.findLeastLoadedErr
	return node, err
}

// FindGlobalLRUModelWithZeroInFlight returns the canned global-LRU answer.
func (f *fakeModelRouter) FindGlobalLRUModelWithZeroInFlight(_ context.Context) (*NodeModel, error) {
	nm, err := f.findGlobalLRUModel, f.findGlobalLRUErr
	return nm, err
}

// FindLRUModel returns the canned per-node LRU answer for any node.
func (f *fakeModelRouter) FindLRUModel(_ context.Context, _ string) (*NodeModel, error) {
	nm, err := f.findLRUModel, f.findLRUErr
	return nm, err
}

// Get returns the canned node for any ID.
func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error) {
	node, err := f.getNode, f.getErr
	return node, err
}

// GetModelScheduling returns the canned scheduling config for any model.
func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
	cfg, err := f.getModelScheduling, f.getModelSchedErr
	return cfg, err
}

// FindNodesBySelector returns the canned selector match for any selector.
func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
	nodes, err := f.findBySelectorNodes, f.findBySelectorErr
	return nodes, err
}

// FindNodesWithFreeSlot mirrors FindNodesBySelector: it returns
// findBySelectorNodes/findBySelectorErr, so a spec that needs a particular
// free-slot answer configures those same fields.
func (f *fakeModelRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
	nodes, err := f.findBySelectorNodes, f.findBySelectorErr
	return nodes, err
}

// ReserveVRAM is a no-op that always succeeds.
func (f *fakeModelRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
	return nil
}

// ReleaseVRAM is a no-op that always succeeds.
func (f *fakeModelRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
	return nil
}

// FindNodeWithVRAMFromSet returns the canned answer regardless of inputs.
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
	node, err := f.findVRAMFromSetNode, f.findVRAMFromSetErr
	return node, err
}

// FindIdleNodeFromSet returns the canned answer regardless of the set.
func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	node, err := f.findIdleFromSetNode, f.findIdleFromSetErr
	return node, err
}

// FindLeastLoadedNodeFromSet returns the canned answer regardless of the set.
func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	node, err := f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
	return node, err
}

// GetNodeLabels returns the canned labels for any node.
func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
	labels, err := f.getNodeLabels, f.getNodeLabelsErr
	return labels, err
}

// FindNodesWithModel returns the configured error if set, otherwise the nodes
// stored for modelName (nil when absent).
func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) {
	if err := f.findNodesWithModelErr; err != nil {
		return nil, err
	}
	nodes := f.findNodesWithModelByName[modelName]
	return nodes, nil
}
// fakeConflictResolver is a ConcurrencyConflictResolver backed by a fixed map
// from a model name to the model names it conflicts with.
type fakeConflictResolver struct {
	conflicts map[string][]string
}

// GetModelsConflictingWith returns the configured conflicts for name. A nil
// resolver is safe to call and, like a name missing from the map, reports no
// conflicts (nil).
func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string {
	var conflicting []string
	if f != nil {
		conflicting = f.conflicts[name]
	}
	return conflicting
}
// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------
// stubBackend is a grpc.Backend stand-in whose HealthCheck and LoadModel
// answers come straight from its fields. Only HealthCheck, LoadModel and
// IsBusy are defined here; every other Backend method is promoted from the
// embedded interface, which these specs never set, so reaching one panics
// with a nil dereference and makes an unexpected call fail loudly.
type stubBackend struct {
	grpc.Backend // intentionally nil; see the type comment

	healthResult bool
	healthErr    error
	loadResult   *pb.Result
	loadErr      error
}

// HealthCheck reports healthResult and healthErr verbatim.
func (f *stubBackend) HealthCheck(_ context.Context) (bool, error) {
	healthy := f.healthResult
	return healthy, f.healthErr
}

// LoadModel ignores the requested options and call options and returns
// loadResult and loadErr verbatim.
func (f *stubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	result := f.loadResult
	return result, f.loadErr
}

// IsBusy always reports the backend as idle.
func (f *stubBackend) IsBusy() bool {
	return false
}
// stubClientFactory is a client factory that ignores the requested address
// and flag and hands back the same stubBackend on every call, so one
// configured stub serves every client the router asks for.
type stubClientFactory struct {
	client *stubBackend
}

// NewClient returns the shared stub client regardless of its arguments.
func (f *stubClientFactory) NewClient(_ string, _ bool) grpc.Backend {
	shared := f.client
	return shared
}
// ---------------------------------------------------------------------------
// Fake NodeCommandSender (unloader)
// ---------------------------------------------------------------------------
// fakeUnloader is a NodeCommandSender test double. Every command it receives
// is recorded in a per-command slice. The reply/error it returns is whatever
// the matching field was configured with.
type fakeUnloader struct {
	// mu guards every recording slice (installCalls, upgradeCalls, stopCalls,
	// unloadCalls). Concurrent test goroutines (e.g. singleflight specs) would
	// otherwise race the slice appends. Specs that read a slice directly
	// should do so only after the calls they care about have returned.
	mu           sync.Mutex
	installReply *messaging.BackendInstallReply
	installErr   error
	installCalls []installCall // every InstallBackend invocation, in order
	// installHook, if non-nil, runs at the start of InstallBackend before
	// the call is recorded. Concurrency tests use it as a deterministic
	// "block here" seam: set it to a function that sleeps or blocks on a
	// channel so two callers overlap.
	installHook  func()
	upgradeReply *messaging.BackendUpgradeReply
	upgradeErr   error
	upgradeCalls []upgradeCall // every UpgradeBackend invocation, in order
	stopCalls    []string      // "nodeID:model", one per StopBackend call
	stopErr      error
	unloadCalls  []string // "nodeID:modelName", one per UnloadModelOnNode call
	unloadErr    error
}

// installCall captures the args we care about when asserting that the
// reconciler / router did or did not fire a NATS install. The fake records
// every call so tests can verify both presence and shape (e.g. that backend
// is non-empty).
type installCall struct {
	nodeID  string
	backend string
	modelID string
	replica int
}

// upgradeCall captures the args of one UpgradeBackend invocation.
type upgradeCall struct {
	nodeID  string
	backend string
	replica int
}

// InstallBackend runs installHook (if any), records the call and returns the
// configured installReply/installErr.
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
	// installHook intentionally runs OUTSIDE the mutex. The hook may block
	// on a channel, and holding mu would serialize concurrent callers and
	// defeat the singleflight-overlap test.
	if f.installHook != nil {
		f.installHook()
	}
	f.mu.Lock()
	f.installCalls = append(f.installCalls, installCall{nodeID, backend, modelID, replica})
	f.mu.Unlock()
	return f.installReply, f.installErr
}

// UpgradeBackend records the call and returns the configured
// upgradeReply/upgradeErr.
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
	f.mu.Lock()
	f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica})
	f.mu.Unlock()
	return f.upgradeReply, f.upgradeErr
}

// DeleteBackend always reports success and records nothing.
func (f *fakeUnloader) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
	return &messaging.BackendDeleteReply{Success: true}, nil
}

// ListBackends always returns an empty listing and records nothing.
func (f *fakeUnloader) ListBackends(_ string) (*messaging.BackendListReply, error) {
	return &messaging.BackendListReply{}, nil
}

// StopBackend records "nodeID:backend" and returns stopErr. The append is
// taken under mu like the install/upgrade recorders, so a stop issued from a
// router goroutine does not race a concurrent install/upgrade or stop.
func (f *fakeUnloader) StopBackend(nodeID, backend string) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.stopCalls = append(f.stopCalls, nodeID+":"+backend)
	return f.stopErr
}

// UnloadModelOnNode records "nodeID:modelName" and returns unloadErr. It
// locks mu for the same reason StopBackend does.
func (f *fakeUnloader) UnloadModelOnNode(nodeID, modelName string) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.unloadCalls = append(f.unloadCalls, nodeID+":"+modelName)
	return f.unloadErr
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
var _ = Describe("SmartRouter", func() {
// -----------------------------------------------------------------------
// Unit tests using mock interfaces (no DB required)
// -----------------------------------------------------------------------
Describe("Route (mock-based)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
reg = &fakeModelRouter{}
backend = &stubBackend{}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
Context("model already loaded on a healthy node", func() {
It("returns the client and a release function", func() {
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
nm := &NodeModel{NodeID: "n1", ModelName: "my-model", Address: "10.0.0.1:9001"}
reg.findAndLockNode = node
reg.findAndLockNM = nm
backend.healthResult = true
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "my-model", "models/my-model.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("n1"))
// TouchNodeModel should have been called
Expect(reg.touchCalls).To(ContainElement("n1:my-model"))
// The initial in-flight reservation from FindAndLockNodeWithModel is released
// after the first inference call completes via OnFirstComplete callback.
// Release only closes the client.
result.Release()
// No decrement on Release — it happens via OnFirstComplete after first Predict
Expect(reg.decrementCalls).To(BeEmpty())
})
})
Context("model not loaded, falls through to scheduling", func() {
It("schedules on an idle node and records the model", func() {
// FindAndLockNodeWithModel always fails — simulates no cached model
// (equivalent to the health-check-failure fallthrough path).
idleNode := &BackendNode{ID: "n2", Name: "idle-node", Address: "10.0.0.2:50051"}
reg2 := &fakeModelRouter{
findAndLockErr: errors.New("not found"),
findIdleNode: idleNode,
}
backend.loadResult = &pb.Result{Success: true}
router := NewSmartRouter(reg2, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "some-model", "models/some-model.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("n2"))
// SetNodeModel should record the model as loaded on the node
Expect(reg2.setCalls).To(HaveLen(1))
Expect(reg2.setCalls[0]).To(ContainSubstring("n2:some-model:loaded"))
})
})
Context("model not loaded, no DB (advisory lock bypassed)", func() {
It("schedules on an available node via FindIdleNode", func() {
reg.findAndLockErr = errors.New("not found")
idleNode := &BackendNode{ID: "n3", Name: "idle", Address: "10.0.0.3:50051"}
reg.findIdleNode = idleNode
backend.loadResult = &pb.Result{Success: true}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
// DB is nil — no advisory lock
})
result, err := router.Route(context.Background(), "new-model", "models/new.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("n3"))
})
})
Context("worker wedges mid-install (dead node holding the lock)", func() {
It("aborts the load at the ModelLoadCeiling instead of blocking forever", func() {
// Simulate the production incident: the chosen worker accepts the
// backend.install but never replies (it died), so InstallBackend
// would otherwise block for its full NATS deadline (15m by
// default) while pinning the per-model advisory lock. Route must
// give up at the ceiling so the lock is released promptly.
reg.findAndLockErr = errors.New("not found")
reg.findIdleNode = &BackendNode{ID: "n4", Name: "dead-node", Address: "10.0.0.4:50051"}
block := make(chan struct{})
defer close(block) // let the background install goroutine drain at test end
unloader.installHook = func() { <-block }
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
ModelLoadCeiling: 200 * time.Millisecond,
})
done := make(chan error, 1)
start := time.Now()
go func() {
defer GinkgoRecover()
_, err := router.Route(context.Background(), "wedged-model",
"models/wedged.gguf", "llama-cpp",
&pb.ModelOptions{Model: "models/wedged.gguf"}, false)
done <- err
}()
var routeErr error
Eventually(done, 5*time.Second).Should(Receive(&routeErr),
"Route must not block on a wedged install past the ceiling")
Expect(routeErr).To(HaveOccurred())
Expect(time.Since(start)).To(BeNumerically("<", 5*time.Second))
})
})
})
Describe("scheduleNewModel (mock-based, via Route)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
reg = &fakeModelRouter{
findAndLockErr: errors.New("not found"),
}
backend = &stubBackend{
loadResult: &pb.Result{Success: true},
}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
It("finds a node with sufficient VRAM first", func() {
vramNode := &BackendNode{ID: "vram-node", Name: "gpu-box", Address: "10.0.0.10:50051"}
reg.findVRAMNode = vramNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
// Pass non-nil ModelOptions so estimateModelVRAM runs (returns 0 for
// missing files, so FindNodeWithVRAM won't actually be called unless
// estimatedVRAM > 0). To trigger VRAM path we need estimatedVRAM > 0,
// but that requires real files. Instead test the fallback: VRAM returns
// error, idle succeeds.
// Actually, estimateModelVRAM returns 0 when model files don't exist,
// so the VRAM branch is skipped and we go to idle/least-loaded.
// To properly test VRAM path, we'd need to mock estimateModelVRAM.
// For now, verify the fallback paths work correctly.
// With no real model files, estimatedVRAM=0, so VRAM path is skipped.
// Set idle node to test that path.
reg.findVRAMNode = nil
reg.findVRAMErr = errors.New("no vram nodes")
idleNode := &BackendNode{ID: "idle-vram", Name: "idle", Address: "10.0.0.11:50051"}
reg.findIdleNode = idleNode
result, err := router.Route(context.Background(), "m1", "models/m1.gguf", "llama-cpp", &pb.ModelOptions{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("idle-vram"))
})
It("falls back to idle when VRAM search fails", func() {
reg.findVRAMErr = errors.New("no vram")
idleNode := &BackendNode{ID: "idle-1", Name: "idle-node", Address: "10.0.0.20:50051"}
reg.findIdleNode = idleNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "m2", "models/m2.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("idle-1"))
})
It("falls back to least-loaded when both VRAM and idle fail", func() {
reg.findVRAMErr = errors.New("no vram")
reg.findIdleErr = errors.New("no idle")
llNode := &BackendNode{ID: "ll-1", Name: "least-loaded", Address: "10.0.0.30:50051"}
reg.findLeastLoadedNode = llNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "m3", "models/m3.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("ll-1"))
})
It("returns error when no nodes are available and no DB for eviction", func() {
reg.findVRAMErr = errors.New("no vram")
reg.findIdleErr = errors.New("no idle")
reg.findLeastLoadedErr = errors.New("no nodes")
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
// DB is nil — evictLRUAndFreeNode will fail because r.db is nil
})
_, err := router.Route(context.Background(), "m4", "models/m4.gguf", "llama-cpp", nil, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no available nodes"))
})
})
Describe("UnloadModel (mock-based)", func() {
It("calls StopBackend and removes the model from the registry", func() {
reg := &fakeModelRouter{}
unloader := &fakeUnloader{}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
})
err := router.UnloadModel(context.Background(), "node-1", "model-a")
Expect(err).ToNot(HaveOccurred())
Expect(unloader.stopCalls).To(ContainElement("node-1:model-a"))
Expect(reg.removeCalls).To(ContainElement("node-1:model-a"))
})
It("returns error when no unloader is configured", func() {
reg := &fakeModelRouter{}
router := NewSmartRouter(reg, SmartRouterOptions{})
err := router.UnloadModel(context.Background(), "node-1", "model-a")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no remote unloader"))
})
})
Describe("EvictLRU (mock-based)", func() {
It("finds LRU model and unloads it", func() {
reg := &fakeModelRouter{
findLRUModel: &NodeModel{NodeID: "n1", ModelName: "old-model"},
}
unloader := &fakeUnloader{}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
})
evicted, err := router.EvictLRU(context.Background(), "n1")
Expect(err).ToNot(HaveOccurred())
Expect(evicted).To(Equal("old-model"))
Expect(unloader.stopCalls).To(ContainElement("n1:old-model"))
Expect(reg.removeCalls).To(ContainElement("n1:old-model"))
})
It("returns error when no LRU model is found", func() {
reg := &fakeModelRouter{
findLRUErr: errors.New("no models loaded"),
}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: &fakeUnloader{},
})
_, err := router.EvictLRU(context.Background(), "n1")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("finding LRU model"))
})
})
Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
reg = &fakeModelRouter{
findAndLockErr: errors.New("not found"),
}
backend = &stubBackend{
loadResult: &pb.Result{Success: true},
}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
It("uses *FromSet methods when model has a node selector", func() {
gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
reg.getModelScheduling = &ModelSchedulingConfig{
ModelName: "selector-model",
NodeSelector: `{"gpu.vendor":"nvidia"}`,
}
reg.findBySelectorNodes = []BackendNode{*gpuNode}
reg.findIdleFromSetNode = gpuNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("gpu-1"))
})
It("returns error when no nodes match selector", func() {
reg.getModelScheduling = &ModelSchedulingConfig{
ModelName: "no-match-model",
NodeSelector: `{"gpu.vendor":"tpu"}`,
}
reg.findBySelectorNodes = nil
reg.findBySelectorErr = nil
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
})
It("uses regular methods when model has no scheduling config", func() {
reg.getModelScheduling = nil
idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
reg.findIdleNode = idleNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("regular-1"))
})
})
Describe("Route with selector validation on cached model (mock-based)", func() {
It("falls through when cached node no longer matches selector", func() {
cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}
backend := &stubBackend{
healthResult: true,
loadResult: &pb.Result{Success: true},
}
factory := &stubClientFactory{client: backend}
unloader := &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.71:9001",
},
}
reg := &fakeModelRouter{
// Step 1: cached model found on old node
findAndLockNode: cachedNode,
findAndLockNM: &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
// Scheduling config with selector that old node does NOT match
getModelScheduling: &ModelSchedulingConfig{
ModelName: "sel-model",
NodeSelector: `{"gpu.vendor":"nvidia"}`,
},
// Old node has no labels matching the selector
getNodeLabels: []NodeLabel{
{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
},
// For scheduling fallthrough: selector matches new node
findBySelectorNodes: []BackendNode{*newNode},
findIdleFromSetNode: newNode,
}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
// Should have fallen through to the new node
Expect(result.Node.ID).To(Equal("n-new"))
// Old node should have had its in-flight decremented
Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
})
})
Describe("ScheduleAndLoadModel (mock-based)", func() {
It("returns an error and does not fire a NATS install when no load info is stored", func() {
// Reproduces the reconciler scale-up bug: when GetModelLoadInfo
// returns ErrRecordNotFound (no replica has ever been loaded),
// the previous fallback called scheduleNewModel with an empty
// backend type, which the worker rejected on every reconciler
// tick. The fix bails out cleanly with an explanatory error and
// never sends backend.install.
unloader := &fakeUnloader{}
reg := &fakeModelRouter{}
router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader})
node, err := router.ScheduleAndLoadModel(context.Background(), "never-loaded", nil)
Expect(err).To(HaveOccurred())
Expect(node).To(BeNil())
Expect(err.Error()).To(ContainSubstring("never-loaded"))
Expect(unloader.installCalls).To(BeEmpty(),
"reconciler must not fire backend.install when there is no load info to replicate")
})
})
// -----------------------------------------------------------------------
// Integration tests using real PostgreSQL (existing)
// -----------------------------------------------------------------------
Describe("evictLRUAndFreeNode (integration)", func() {
	var (
		db       *gorm.DB
		registry *NodeRegistry
	)

	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()
		var regErr error
		registry, regErr = NewNodeRegistry(db)
		Expect(regErr).ToNot(HaveOccurred())
	})

	// seedBusyModel registers a backend node, marks model as loaded on it
	// and gives it one in-flight request so it cannot be evicted.
	seedBusyModel := func(name, address, model string) {
		bg := context.Background()
		n := &BackendNode{Name: name, NodeType: NodeTypeBackend, Address: address}
		Expect(registry.Register(bg, n, true)).To(Succeed())
		Expect(registry.SetNodeModel(bg, n.ID, model, 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(bg, n.ID, model, 0)).To(Succeed())
	}

	It("returns ErrEvictionBusy in under 5 seconds when all models are busy", func() {
		seedBusyModel("busy-evict", "10.0.0.100:50051", "busy-model")
		sr := NewSmartRouter(registry, SmartRouterOptions{DB: db})

		began := time.Now()
		_, evictErr := sr.evictLRUAndFreeNode(context.Background())
		took := time.Since(began)

		Expect(evictErr).To(MatchError(ErrEvictionBusy))
		// Nominal retry budget is 5 retries * 500ms = 2.5s; the bound is deliberately loose.
		Expect(took).To(BeNumerically("<", 5*time.Second))
	})

	It("respects context cancellation", func() {
		seedBusyModel("cancel-evict", "10.0.0.101:50051", "cancel-model")
		sr := NewSmartRouter(registry, SmartRouterOptions{DB: db})

		doneCtx, cancel := context.WithCancel(context.Background())
		cancel() // the context is already done before the call starts

		began := time.Now()
		_, evictErr := sr.evictLRUAndFreeNode(doneCtx)
		took := time.Since(began)

		Expect(evictErr).To(HaveOccurred())
		Expect(evictErr.Error()).To(ContainSubstring("context cancelled"))
		// An already-cancelled context should short-circuit almost immediately.
		Expect(took).To(BeNumerically("<", 2*time.Second))
	})
})
Describe("stageModelFiles (integration)", func() {
	var (
		db       *gorm.DB
		registry *NodeRegistry
	)

	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()
		var regErr error
		registry, regErr = NewNodeRegistry(db)
		Expect(regErr).ToNot(HaveOccurred())
	})

	It("does not mutate the original ModelOptions", func() {
		sr := NewSmartRouter(registry, SmartRouterOptions{
			FileStager: &fakeFileStager{},
			DB:         db,
		})
		target := &BackendNode{
			ID:      "stage-node-id",
			Name:    "stage-node",
			Address: "10.0.0.200:50051",
		}
		opts := &pb.ModelOptions{
			Model:     "test-backend/models/test.gguf",
			ModelFile: "/models/test-backend/models/test.gguf",
			MMProj:    "",
		}

		// Snapshot the caller-visible fields before staging.
		wantModel, wantModelFile, wantMMProj := opts.Model, opts.ModelFile, opts.MMProj

		// None of these paths exist on disk, so stageModelFiles is expected
		// to skip them (presumably clearing missing optional fields on its
		// own copy). The staging result and error are deliberately ignored:
		// the only property under test is that the caller's proto pointer
		// is left untouched.
		_, _ = sr.stageModelFiles(context.Background(), target, opts, "test-model")

		Expect(opts.Model).To(Equal(wantModel))
		Expect(opts.ModelFile).To(Equal(wantModelFile))
		Expect(opts.MMProj).To(Equal(wantMMProj))
	})
})
// -----------------------------------------------------------------------
// narrowByGroupAntiAffinity
// -----------------------------------------------------------------------
Describe("narrowByGroupAntiAffinity", func() {
	var (
		fakeReg      *fakeModelRouter
		fakeResolver *fakeConflictResolver
		sr           *SmartRouter
		bg           context.Context
	)

	BeforeEach(func() {
		bg = context.Background()
		fakeReg = &fakeModelRouter{}
		fakeResolver = &fakeConflictResolver{conflicts: map[string][]string{}}
		sr = NewSmartRouter(fakeReg, SmartRouterOptions{ConflictResolver: fakeResolver})
	})

	It("returns the input set unchanged when the model has no conflicts", func() {
		in := []string{"n1", "n2", "n3"}

		got, err := sr.narrowByGroupAntiAffinity(bg, "lonely", in)

		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(in))
	})

	It("removes nodes that already host a conflicting model", func() {
		fakeResolver.conflicts["b"] = []string{"a"}
		fakeReg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
		}

		got, err := sr.narrowByGroupAntiAffinity(bg, "b", []string{"n1", "n2"})

		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(ConsistOf("n2"))
	})

	It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() {
		fakeResolver.conflicts["b"] = []string{"a"}
		fakeReg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}
		in := []string{"n1", "n2"}

		got, err := sr.narrowByGroupAntiAffinity(bg, "b", in)

		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(in))
	})

	It("removes nodes hosting any of multiple conflicting models", func() {
		fakeResolver.conflicts["c"] = []string{"a", "b"}
		fakeReg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
			"b": {{ID: "n2"}},
		}

		got, err := sr.narrowByGroupAntiAffinity(bg, "c", []string{"n1", "n2", "n3"})

		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(ConsistOf("n3"))
	})

	It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() {
		fakeResolver.conflicts["b"] = []string{"a"}
		fakeReg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}

		got, err := sr.narrowByGroupAntiAffinity(bg, "b", nil)

		Expect(err).ToNot(HaveOccurred())
		// A nil input means "any healthy node" to the caller, so nil must
		// come back out. Hard-narrowing nil would silently exclude every
		// other node.
		Expect(got).To(BeNil())
	})

	It("is a no-op when no resolver is configured", func() {
		noResolver := NewSmartRouter(fakeReg, SmartRouterOptions{})
		in := []string{"n1", "n2"}

		got, err := noResolver.narrowByGroupAntiAffinity(bg, "b", in)

		Expect(err).ToNot(HaveOccurred())
		Expect(got).To(Equal(in))
	})
})
Describe("installBackendOnNode singleflight", func() {
It("coalesces concurrent identical installs into one NATS call", func() {
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
// Slow install reply so concurrent calls overlap deterministically.
started := make(chan struct{}, 5)
release := make(chan struct{})
unloader := &fakeUnloader{
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
}
unloader.installHook = func() {
started <- struct{}{}
<-release
}
router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
Unloader: unloader,
ClientFactory: &stubClientFactory{client: &stubBackend{}},
})
// Fire 5 concurrent identical installBackendOnNode calls.
done := make(chan error, 5)
for i := 0; i < 5; i++ {
go func() {
_, err := router.installBackendOnNode(context.Background(), node, "llama-cpp", "my-model", 0)
done <- err
}()
}
// Only ONE call should have entered the unloader hook (the
// singleflight leader). The other 4 are coalesced and waiting on
// the leader's result.
Eventually(started).Should(Receive())
Consistently(started, 100*time.Millisecond).ShouldNot(Receive())
// Release the leader; the other 4 callers receive the same result.
close(release)
for i := 0; i < 5; i++ {
Expect(<-done).ToNot(HaveOccurred())
}
Expect(unloader.installCalls).To(HaveLen(1),
"singleflight should coalesce 5 concurrent identical loads into 1 NATS call")
})
It("does NOT coalesce installs for different (modelID, replica) keys", func() {
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
unloader := &fakeUnloader{
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
}
router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
Unloader: unloader,
ClientFactory: &stubClientFactory{client: &stubBackend{}},
})
_, err1 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 0)
_, err2 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-B", 0)
_, err3 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 1)
Expect(err1).ToNot(HaveOccurred())
Expect(err2).ToNot(HaveOccurred())
Expect(err3).ToNot(HaveOccurred())
Expect(unloader.installCalls).To(HaveLen(3))
})
})
})
// ---------------------------------------------------------------------------
// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests
// ---------------------------------------------------------------------------
// observeRecord captures the arguments of a single fakePrefixProvider.Observe
// call so tests can assert what the router reported to the prefix provider.
type observeRecord struct {
// model is the model name passed to Observe.
model string
// chain is the prefix chain passed to Observe. Observe stores a copy, not
// the caller's slice.
chain []uint64
// key identifies the replica the prefix was observed on.
key prefixcache.ReplicaKey
}
// invalidateRecord captures the arguments of a single
// fakePrefixProvider.Invalidate call.
type invalidateRecord struct {
// model is the model name passed to Invalidate.
model string
// key identifies the replica whose prefix entries were invalidated.
key prefixcache.ReplicaKey
}
// fakePrefixProvider is a hand-rolled prefix-cache provider double. It
// counts Decide calls, logs every Observe, Invalidate and InvalidateNode
// call in order, and answers Decide with a fixed, test-configured decision.
type fakePrefixProvider struct {
	// decideCalls counts how many times Decide was invoked.
	decideCalls int
	// observed holds one entry per Observe call, in call order.
	observed []observeRecord
	// invalidated holds one entry per Invalidate call, in call order.
	invalidated []invalidateRecord
	// invalidatedNode holds a "model:nodeID" entry per InvalidateNode call.
	invalidatedNode []string
	// decision is returned verbatim by every Decide call.
	decision prefixcache.PrefixDecision
}

// Decide bumps the call counter and returns the configured decision. All
// arguments are ignored.
func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision {
	answer := f.decision
	f.decideCalls++
	return answer
}

// Observe records model, a private copy of chain (so later mutation by the
// caller cannot rewrite history) and key, then returns true.
func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool {
	var chainCopy []uint64
	chainCopy = append(chainCopy, chain...)
	rec := observeRecord{model: model, chain: chainCopy, key: key}
	f.observed = append(f.observed, rec)
	return true
}

// Invalidate records the (model, replica) pair it was asked to drop.
func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) {
	rec := invalidateRecord{key: key, model: model}
	f.invalidated = append(f.invalidated, rec)
}

// InvalidateNode records the request as a "model:nodeID" string.
func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) {
	entry := model + ":" + nodeID
	f.invalidatedNode = append(f.invalidatedNode, entry)
}

// Evict is a no-op; the fake keeps no time-based state.
func (f *fakePrefixProvider) Evict(_ time.Time) {}
// SmartRouter prefix-cache routing specs. They cover four behaviors:
//   - how Route turns the prefix provider's decision into the FindAndLock
//     preference;
//   - when Route observes a pick;
//   - when forced-disturb pressure is recorded;
//   - that UnloadModel goes through the registry removal chokepoint.
//
// Specs that inspect fakeModelRouter.findAndLockPrefs first assert that it
// is non-empty. A routing regression then fails with a readable assertion
// instead of an index-out-of-range panic.
var _ = Describe("SmartRouter prefix-cache routing", func() {
	var (
		backend  *stubBackend
		factory  *stubClientFactory
		unloader *fakeUnloader
	)
	BeforeEach(func() {
		backend = &stubBackend{healthResult: true}
		factory = &stubClientFactory{client: backend}
		unloader = &fakeUnloader{
			installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"},
		}
	})
	// loadedReg builds a fake registry with one loaded healthy replica for
	// "m" on node "X", plus matching replica stats so buildPreference can run.
	loadedReg := func() *fakeModelRouter {
		node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
		nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"}
		return &fakeModelRouter{
			findAndLockNode: node,
			findAndLockNM:   nm,
			getModelScheduling: &ModelSchedulingConfig{
				RoutePolicy: "prefix_cache",
			},
			loadedReplicaStatsByName: map[string][]ReplicaCandidate{
				"m": {{NodeID: "X", InFlight: 0}},
			},
		}
	}
	Context("nil provider (round-robin floor)", func() {
		It("passes a nil preference and never decides or observes", func() {
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			for _, p := range reg.findAndLockPrefs {
				Expect(p).To(BeNil())
			}
		})
	})
	Context("with a provider", func() {
		It("passes the decided node as the preference and observes the pick", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(prov.decideCalls).To(BeNumerically(">=", 1))
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			Expect(reg.findAndLockPrefs[0]).ToNot(BeNil())
			Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X"))
			Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0))
			Expect(prov.observed).To(HaveLen(1))
			Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
			Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3}))
		})
		It("routes a recurring prefix back to the previously observed node", func() {
			// Real Index as the provider: the first request observes X, and a
			// second request with the same chain must yield PreferredNodeID == X.
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			// The first request landed on X (cold placement on the only
			// candidate) and observed the prefix there.
			dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now())
			Expect(dFirst.HasHot).To(BeTrue())
			Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
			// Second request with the same chain: X is now the warm-cache hot
			// match, so the preference must point at it.
			_, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1]
			Expect(last).ToNot(BeNil())
			Expect(last.PreferredNodeID).To(Equal("X"))
			Expect(last.PreferredReplica).To(Equal(0))
		})
		It("prefers the exact hot replica when two replicas share a node", func() {
			// Two replicas of "m" live on the SAME node X: replica 0 and
			// replica 1. A hot prefix observed on (X,0) must produce a
			// preference that locks replica 0 specifically, NOT the sibling
			// replica 1 on the same node. This is the replica-granular
			// regression this change fixes.
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
			nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"}
			reg := &fakeModelRouter{
				findAndLockNode: node,
				findAndLockNM:   nm,
				getModelScheduling: &ModelSchedulingConfig{
					RoutePolicy: "prefix_cache",
				},
				loadedReplicaStatsByName: map[string][]ReplicaCandidate{
					"m": {
						{NodeID: "X", ReplicaIndex: 0, InFlight: 0},
						{NodeID: "X", ReplicaIndex: 1, InFlight: 0},
					},
				},
			}
			// Seed the index so (X,0) is the warm replica for this chain.
			idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now())
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			pref := reg.findAndLockPrefs[0]
			Expect(pref).ToNot(BeNil())
			Expect(pref.PreferredNodeID).To(Equal("X"))
			Expect(pref.PreferredReplica).To(Equal(0),
				"the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen")
		})
		It("does not decide or observe when no prefix chain is present", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			_, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(prov.decideCalls).To(Equal(0))
			Expect(prov.observed).To(BeEmpty())
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			Expect(reg.findAndLockPrefs[0]).To(BeNil())
		})
		It("does not observe for round-robin models even with a chain", func() {
			reg := loadedReg()
			reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"}
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(prov.decideCalls).To(Equal(0))
			Expect(prov.observed).To(BeEmpty())
			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			Expect(reg.findAndLockPrefs[0]).To(BeNil())
		})
	})
	Context("forced-disturb pressure", func() {
		// disturbReg builds a registry with two candidate replicas for "m".
		// The hot node X is saturated (high in_flight) and Y is free, so
		// Select rejects the hot node and picks Y, which is the
		// forced-disturb signal. findAndLockNode returns Y so Route succeeds.
		disturbReg := func() *fakeModelRouter {
			nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"}
			nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"}
			return &fakeModelRouter{
				findAndLockNode: nodeY,
				findAndLockNM:   nm,
				getModelScheduling: &ModelSchedulingConfig{
					RoutePolicy: "prefix_cache",
				},
				loadedReplicaStatsByName: map[string][]ReplicaCandidate{
					"m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}},
				},
			}
		}
		It("records pressure when a strong hot match was forced off the warm node", func() {
			reg := disturbReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				Hot:        prefixcache.ReplicaKey{NodeID: "X"},
				HasHot:     true,
				MatchRatio: 1.0,
				ColdOrder:  []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}},
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0),
				"hot match existed but the load guard forced us off X: must record pressure")
		})
		It("does not record pressure when the hot node is itself eligible", func() {
			reg := loadedReg() // single node X, in_flight 0 → X stays eligible
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				Hot:        prefixcache.ReplicaKey{NodeID: "X"},
				HasHot:     true,
				MatchRatio: 1.0,
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(pressure.Count("m", time.Now())).To(Equal(0),
				"chosen == hot node, no disturb")
		})
		It("does not record pressure for an all-unique workload with no hot match", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				HasHot:     false, // no prefix match at all
				MatchRatio: 0,
				ColdOrder:  []prefixcache.ReplicaKey{{NodeID: "X"}},
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(pressure.Count("m", time.Now())).To(Equal(0),
				"no hot match means no cache to disturb: must not false-positive")
		})
	})
	Context("removal chokepoint on unload", func() {
		It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() {
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})
			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6})
			// Warm the cache: X now holds the prefix.
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
			// UnloadModel must route the eviction through the registry
			// removal chokepoint (RemoveAllNodeModelReplicas).
			//
			// In production, the registry's SetReplicaRemovedHook is what
			// invalidates the prefix index; the router no longer invalidates
			// directly. This fake registry records the removal but fires no
			// hook, so the spec asserts that the chokepoint is exercised.
			// The downstream invalidation is covered by the registry hook
			// integration tests.
			Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed())
			Expect(reg.removeCalls).To(ContainElement("X:m"),
				"UnloadModel must remove the replica via the registry removal chokepoint")
		})
	})
})