mirror of
https://github.com/mudler/LocalAI.git
synced 2026-07-03 04:46:54 -04:00
* fix(distributed): don't let a dead worker pin the model-load advisory lock In distributed mode a chat request could fail with: failed to route model with internal loader: routing model ...: loading model ...: advisorylock: acquiring lock <id>: ERROR: canceling statement due to lock timeout (SQLSTATE 55P03) Root cause is two independent defects in the cross-replica model-load path: 1. SmartRouter.Route holds a per-model PostgreSQL advisory lock for the whole cold-load sequence, which includes installBackendOnNode -> InstallBackend, a NATS request-reply with a 15m deadline (DefaultBackendInstallTimeout) that ignored ctx. When the chosen worker died mid-install, the holder sat on the lock for up to 15m. The detached loadCtx (WithoutCancel) had no deadline, so nothing capped the hold. 2. The acquiring statement, pg_advisory_lock(), is subject to any deployment global lock_timeout. A common operator setting (e.g. 10s) aborts the wait with SQLSTATE 55P03, so every other replica's request for that model hard-errored instead of waiting for the in-progress load and reusing it. For the ~15m window the model was effectively unroutable. Fixes: - advisorylock.WithLockCtx (postgres): SET lock_timeout = 0 on its dedicated connection (RESET before it returns to the pool) so the Go context, not a deployment-wide GUC, governs how long we wait. Waiters now block and then re-check, reusing the model another replica just loaded. - SmartRouter: bound the detached loadCtx with a single ModelLoadCeiling so the lock is always released in bounded time even if a sub-step wedges. Default is the configured backend.install deadline + 10m (staging + LoadModel margin), so a legitimately slow load is never cut. - installBackendOnNode: use singleflight.DoChan + select on ctx.Done() so the install wait honors cancellation; the ceiling can then actually free a caller pinned behind a dead worker. The shared install still coalesces via singleflight. 
Reproduced both defects as failing tests first (a real 55P03 against a testcontainer with a short lock_timeout; a wedged install that blocks Route) and confirmed green. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-8 [Claude Code] * fix(distributed): bound advisory-lock wait instead of disabling lock_timeout Setting lock_timeout = 0 to override a deployment's short global lock_timeout meant "wait forever" server-side. Safe for SmartRouter.Route (its loadCtx now carries the model-load ceiling) but unsafe for the schema-migration callers that pass context.Background(): a holder whose session never releases would hang them indefinitely. Derive the server-side lock_timeout from the caller's context instead: its remaining budget plus a margin (so the Go context's cancellation still wins with a clean error and the server bound is only a backstop), or a finite 30m backstop when the context has no deadline. Never zero - "wait forever" is no longer possible, while a deployment's hostile short lock_timeout is still overridden so legitimate cross-replica waits don't fail with 55P03. Added a spec proving a deadline-less waiter gives up at the (shrunk) backstop rather than hanging. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-8 [Claude Code] --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
1466 lines
52 KiB
Go
1466 lines
52 KiB
Go
package nodes
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
"github.com/mudler/LocalAI/core/services/messaging"
|
|
"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
|
|
"github.com/mudler/LocalAI/core/services/testutil"
|
|
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
ggrpc "google.golang.org/grpc"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Fake FileStager (pre-existing)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// fakeFileStager is a minimal FileStager that records calls and returns
// predictable remote paths without touching the filesystem or network.
type fakeFileStager struct {
	// ensureCalls holds one entry per EnsureRemote invocation, in call order.
	ensureCalls []ensureCall
}

// ensureCall captures the arguments of a single EnsureRemote invocation.
type ensureCall struct {
	nodeID, localPath, key string
}

// EnsureRemote records the call in ensureCalls and returns "/remote/"+key
// with a nil error. The context is ignored.
func (f *fakeFileStager) EnsureRemote(_ context.Context, nodeID, localPath, key string) (string, error) {
	f.ensureCalls = append(f.ensureCalls, ensureCall{
		nodeID:    nodeID,
		localPath: localPath,
		key:       key,
	})
	return "/remote/" + key, nil
}

// FetchRemote is a no-op that always succeeds.
func (f *fakeFileStager) FetchRemote(_ context.Context, _, _, _ string) error { return nil }

// FetchRemoteByKey is a no-op that always succeeds.
func (f *fakeFileStager) FetchRemoteByKey(_ context.Context, _, _, _ string) error { return nil }

// AllocRemoteTemp always returns the fixed path "/remote/tmp" and a nil error.
func (f *fakeFileStager) AllocRemoteTemp(_ context.Context, _ string) (string, error) {
	return "/remote/tmp", nil
}

// StageRemoteToStore is a no-op that always succeeds.
func (f *fakeFileStager) StageRemoteToStore(_ context.Context, _, _, _ string) error { return nil }

// ListRemoteDir always returns a nil listing and a nil error.
func (f *fakeFileStager) ListRemoteDir(_ context.Context, _, _ string) ([]string, error) {
	return nil, nil
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Fake ModelRouter
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// fakeModelRouter implements ModelRouter with configurable return values.
|
|
type fakeModelRouter struct {
|
|
// FindAndLockNodeWithModel returns
|
|
findAndLockNode *BackendNode
|
|
findAndLockNM *NodeModel
|
|
findAndLockErr error
|
|
|
|
// FindNodeWithVRAM returns
|
|
findVRAMNode *BackendNode
|
|
findVRAMErr error
|
|
|
|
// FindIdleNode returns
|
|
findIdleNode *BackendNode
|
|
findIdleErr error
|
|
|
|
// FindLeastLoadedNode returns
|
|
findLeastLoadedNode *BackendNode
|
|
findLeastLoadedErr error
|
|
|
|
// FindGlobalLRUModelWithZeroInFlight returns
|
|
findGlobalLRUModel *NodeModel
|
|
findGlobalLRUErr error
|
|
|
|
// FindLRUModel returns
|
|
findLRUModel *NodeModel
|
|
findLRUErr error
|
|
|
|
// Get returns
|
|
getNode *BackendNode
|
|
getErr error
|
|
|
|
// GetModelScheduling returns
|
|
getModelScheduling *ModelSchedulingConfig
|
|
getModelSchedErr error
|
|
|
|
// FindNodesBySelector returns
|
|
findBySelectorNodes []BackendNode
|
|
findBySelectorErr error
|
|
|
|
// *FromSet variants
|
|
findVRAMFromSetNode *BackendNode
|
|
findVRAMFromSetErr error
|
|
findIdleFromSetNode *BackendNode
|
|
findIdleFromSetErr error
|
|
findLeastLoadedFromSetNode *BackendNode
|
|
findLeastLoadedFromSetErr error
|
|
|
|
// GetNodeLabels returns
|
|
getNodeLabels []NodeLabel
|
|
getNodeLabelsErr error
|
|
|
|
// FindNodesWithModel returns (keyed by model name)
|
|
findNodesWithModelByName map[string][]BackendNode
|
|
findNodesWithModelErr error
|
|
|
|
// LoadedReplicaStats returns (keyed by model name)
|
|
loadedReplicaStatsByName map[string][]ReplicaCandidate
|
|
loadedReplicaStatsErr error
|
|
|
|
// Track calls for assertions
|
|
decrementCalls []string // "nodeID:modelName"
|
|
incrementCalls []string
|
|
removeCalls []string
|
|
setCalls []string
|
|
touchCalls []string
|
|
|
|
// Preferences passed to FindAndLockNodeWithModel, in call order. nil
|
|
// entries are recorded too, so tests can assert "preference was nil".
|
|
findAndLockPrefs []*RoutePreference
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
|
|
f.findAndLockPrefs = append(f.findAndLockPrefs, pref)
|
|
return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) {
|
|
if f.loadedReplicaStatsErr != nil {
|
|
return nil, f.loadedReplicaStatsErr
|
|
}
|
|
return f.loadedReplicaStatsByName[modelName], nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
|
|
f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
|
|
f.incrementCalls = append(f.incrementCalls, nodeID+":"+modelName)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
|
|
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
|
|
// Same recorded key as RemoveNodeModel so existing tests that assert "the
|
|
// model was removed" don't need to know whether the production code used
|
|
// the per-replica or all-replicas variant.
|
|
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string, _ int) {
|
|
f.touchCalls = append(f.touchCalls, nodeID+":"+modelName)
|
|
}
|
|
|
|
func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName string, _ int, state, address string, _ int) error {
|
|
f.setCalls = append(f.setCalls, fmt.Sprintf("%s:%s:%s:%s", nodeID, modelName, state, address))
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) UpsertModelLoadInfo(_ context.Context, _, _ string, _ []byte) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) GetModelLoadInfo(_ context.Context, _ string) (string, []byte, error) {
|
|
return "", nil, fmt.Errorf("not found")
|
|
}
|
|
|
|
func (f *fakeModelRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
|
|
return 0, nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
|
|
return f.findVRAMNode, f.findVRAMErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindIdleNode(_ context.Context) (*BackendNode, error) {
|
|
return f.findIdleNode, f.findIdleErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindLeastLoadedNode(_ context.Context) (*BackendNode, error) {
|
|
return f.findLeastLoadedNode, f.findLeastLoadedErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindGlobalLRUModelWithZeroInFlight(_ context.Context) (*NodeModel, error) {
|
|
return f.findGlobalLRUModel, f.findGlobalLRUErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindLRUModel(_ context.Context, _ string) (*NodeModel, error) {
|
|
return f.findLRUModel, f.findLRUErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error) {
|
|
return f.getNode, f.getErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
|
|
return f.getModelScheduling, f.getModelSchedErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
|
|
return f.findBySelectorNodes, f.findBySelectorErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
|
|
// Default: same answer as FindNodesBySelector. Tests that need a
|
|
// specific filter can override by reusing findBySelectorNodes.
|
|
return f.findBySelectorNodes, f.findBySelectorErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
|
|
return f.findVRAMFromSetNode, f.findVRAMFromSetErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
|
|
return f.findIdleFromSetNode, f.findIdleFromSetErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
|
|
return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
|
|
return f.getNodeLabels, f.getNodeLabelsErr
|
|
}
|
|
|
|
func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) {
|
|
if f.findNodesWithModelErr != nil {
|
|
return nil, f.findNodesWithModelErr
|
|
}
|
|
return f.findNodesWithModelByName[modelName], nil
|
|
}
|
|
|
|
// fakeConflictResolver implements ConcurrencyConflictResolver from a static map.
type fakeConflictResolver struct {
	// conflicts maps a model name to the names of the models it conflicts with.
	conflicts map[string][]string
}

// GetModelsConflictingWith returns the entry stored for name in conflicts, or
// nil when there is none. It is safe to call on a nil receiver, which also
// yields nil.
func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string {
	if f == nil {
		return nil
	}
	return f.conflicts[name]
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Fake BackendClientFactory + Backend
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// stubBackend implements grpc.Backend with configurable HealthCheck and LoadModel.
|
|
type stubBackend struct {
|
|
grpc.Backend // embed to satisfy interface; unused methods will panic if called
|
|
|
|
healthResult bool
|
|
healthErr error
|
|
loadResult *pb.Result
|
|
loadErr error
|
|
}
|
|
|
|
func (f *stubBackend) HealthCheck(_ context.Context) (bool, error) {
|
|
return f.healthResult, f.healthErr
|
|
}
|
|
|
|
func (f *stubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return f.loadResult, f.loadErr
|
|
}
|
|
|
|
// IsBusy always reports false.
func (f *stubBackend) IsBusy() bool {
	return false
}
|
|
|
|
// stubClientFactory returns the same stubBackend for every call.
|
|
type stubClientFactory struct {
|
|
client *stubBackend
|
|
}
|
|
|
|
func (f *stubClientFactory) NewClient(_ string, _ bool) grpc.Backend {
|
|
return f.client
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Fake NodeCommandSender (unloader)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
type fakeUnloader struct {
|
|
// mu guards installCalls and upgradeCalls so concurrent test
|
|
// goroutines (e.g. singleflight specs) don't race the slice appends.
|
|
mu sync.Mutex
|
|
|
|
installReply *messaging.BackendInstallReply
|
|
installErr error
|
|
installCalls []installCall // every InstallBackend invocation, in order
|
|
// installHook, if non-nil, runs at the start of InstallBackend before
|
|
// the call is recorded. Used by concurrency tests as a deterministic
|
|
// "block here" seam — set installHook to a function that sleeps or
|
|
// blocks on a channel to overlap two callers.
|
|
installHook func()
|
|
|
|
upgradeReply *messaging.BackendUpgradeReply
|
|
upgradeErr error
|
|
upgradeCalls []upgradeCall // every UpgradeBackend invocation, in order
|
|
|
|
stopCalls []string // "nodeID:model"
|
|
stopErr error
|
|
unloadCalls []string
|
|
unloadErr error
|
|
}
|
|
|
|
// installCall captures the args we care about when asserting that the
// reconciler / router did or did not fire a NATS install. The fake records
// every call so tests can verify both presence and shape (e.g. that backend
// is non-empty).
type installCall struct {
	nodeID  string
	backend string
	modelID string
	replica int
}
|
|
|
|
// upgradeCall captures the arguments of one UpgradeBackend invocation that
// tests assert on.
type upgradeCall struct {
	nodeID  string
	backend string
	replica int
}
|
|
|
|
func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
|
|
// installHook intentionally runs OUTSIDE the mutex: the hook may block
|
|
// on a channel and we don't want to serialize concurrent callers,
|
|
// which would defeat the singleflight-overlap test.
|
|
if f.installHook != nil {
|
|
f.installHook()
|
|
}
|
|
f.mu.Lock()
|
|
f.installCalls = append(f.installCalls, installCall{nodeID, backend, modelID, replica})
|
|
f.mu.Unlock()
|
|
return f.installReply, f.installErr
|
|
}
|
|
|
|
func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
|
|
f.mu.Lock()
|
|
f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica})
|
|
f.mu.Unlock()
|
|
return f.upgradeReply, f.upgradeErr
|
|
}
|
|
|
|
func (f *fakeUnloader) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
|
|
return &messaging.BackendDeleteReply{Success: true}, nil
|
|
}
|
|
|
|
func (f *fakeUnloader) ListBackends(_ string) (*messaging.BackendListReply, error) {
|
|
return &messaging.BackendListReply{}, nil
|
|
}
|
|
|
|
func (f *fakeUnloader) StopBackend(nodeID, backend string) error {
|
|
f.stopCalls = append(f.stopCalls, nodeID+":"+backend)
|
|
return f.stopErr
|
|
}
|
|
|
|
func (f *fakeUnloader) UnloadModelOnNode(nodeID, modelName string) error {
|
|
f.unloadCalls = append(f.unloadCalls, nodeID+":"+modelName)
|
|
return f.unloadErr
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
var _ = Describe("SmartRouter", func() {
|
|
// -----------------------------------------------------------------------
|
|
// Unit tests using mock interfaces (no DB required)
|
|
// -----------------------------------------------------------------------
|
|
Describe("Route (mock-based)", func() {
|
|
var (
|
|
reg *fakeModelRouter
|
|
backend *stubBackend
|
|
factory *stubClientFactory
|
|
unloader *fakeUnloader
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
reg = &fakeModelRouter{}
|
|
backend = &stubBackend{}
|
|
factory = &stubClientFactory{client: backend}
|
|
unloader = &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{
|
|
Success: true,
|
|
Address: "10.0.0.1:9001",
|
|
},
|
|
}
|
|
})
|
|
|
|
Context("model already loaded on a healthy node", func() {
|
|
It("returns the client and a release function", func() {
|
|
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
|
|
nm := &NodeModel{NodeID: "n1", ModelName: "my-model", Address: "10.0.0.1:9001"}
|
|
reg.findAndLockNode = node
|
|
reg.findAndLockNM = nm
|
|
backend.healthResult = true
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "my-model", "models/my-model.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.ID).To(Equal("n1"))
|
|
|
|
// TouchNodeModel should have been called
|
|
Expect(reg.touchCalls).To(ContainElement("n1:my-model"))
|
|
|
|
// The initial in-flight reservation from FindAndLockNodeWithModel is released
|
|
// after the first inference call completes via OnFirstComplete callback.
|
|
// Release only closes the client.
|
|
result.Release()
|
|
// No decrement on Release — it happens via OnFirstComplete after first Predict
|
|
Expect(reg.decrementCalls).To(BeEmpty())
|
|
})
|
|
})
|
|
|
|
Context("model not loaded, falls through to scheduling", func() {
|
|
It("schedules on an idle node and records the model", func() {
|
|
// FindAndLockNodeWithModel always fails — simulates no cached model
|
|
// (equivalent to the health-check-failure fallthrough path).
|
|
idleNode := &BackendNode{ID: "n2", Name: "idle-node", Address: "10.0.0.2:50051"}
|
|
reg2 := &fakeModelRouter{
|
|
findAndLockErr: errors.New("not found"),
|
|
findIdleNode: idleNode,
|
|
}
|
|
backend.loadResult = &pb.Result{Success: true}
|
|
|
|
router := NewSmartRouter(reg2, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "some-model", "models/some-model.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.ID).To(Equal("n2"))
|
|
|
|
// SetNodeModel should record the model as loaded on the node
|
|
Expect(reg2.setCalls).To(HaveLen(1))
|
|
Expect(reg2.setCalls[0]).To(ContainSubstring("n2:some-model:loaded"))
|
|
})
|
|
})
|
|
|
|
Context("model not loaded, no DB (advisory lock bypassed)", func() {
|
|
It("schedules on an available node via FindIdleNode", func() {
|
|
reg.findAndLockErr = errors.New("not found")
|
|
idleNode := &BackendNode{ID: "n3", Name: "idle", Address: "10.0.0.3:50051"}
|
|
reg.findIdleNode = idleNode
|
|
backend.loadResult = &pb.Result{Success: true}
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
// DB is nil — no advisory lock
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "new-model", "models/new.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.ID).To(Equal("n3"))
|
|
})
|
|
})
|
|
|
|
Context("worker wedges mid-install (dead node holding the lock)", func() {
|
|
It("aborts the load at the ModelLoadCeiling instead of blocking forever", func() {
|
|
// Simulate the production incident: the chosen worker accepts the
|
|
// backend.install but never replies (it died), so InstallBackend
|
|
// would otherwise block for its full NATS deadline (15m by
|
|
// default) while pinning the per-model advisory lock. Route must
|
|
// give up at the ceiling so the lock is released promptly.
|
|
reg.findAndLockErr = errors.New("not found")
|
|
reg.findIdleNode = &BackendNode{ID: "n4", Name: "dead-node", Address: "10.0.0.4:50051"}
|
|
|
|
block := make(chan struct{})
|
|
defer close(block) // let the background install goroutine drain at test end
|
|
unloader.installHook = func() { <-block }
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
ModelLoadCeiling: 200 * time.Millisecond,
|
|
})
|
|
|
|
done := make(chan error, 1)
|
|
start := time.Now()
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
_, err := router.Route(context.Background(), "wedged-model",
|
|
"models/wedged.gguf", "llama-cpp",
|
|
&pb.ModelOptions{Model: "models/wedged.gguf"}, false)
|
|
done <- err
|
|
}()
|
|
|
|
var routeErr error
|
|
Eventually(done, 5*time.Second).Should(Receive(&routeErr),
|
|
"Route must not block on a wedged install past the ceiling")
|
|
Expect(routeErr).To(HaveOccurred())
|
|
Expect(time.Since(start)).To(BeNumerically("<", 5*time.Second))
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("scheduleNewModel (mock-based, via Route)", func() {
|
|
var (
|
|
reg *fakeModelRouter
|
|
backend *stubBackend
|
|
factory *stubClientFactory
|
|
unloader *fakeUnloader
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
reg = &fakeModelRouter{
|
|
findAndLockErr: errors.New("not found"),
|
|
}
|
|
backend = &stubBackend{
|
|
loadResult: &pb.Result{Success: true},
|
|
}
|
|
factory = &stubClientFactory{client: backend}
|
|
unloader = &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{
|
|
Success: true,
|
|
Address: "10.0.0.1:9001",
|
|
},
|
|
}
|
|
})
|
|
|
|
It("finds a node with sufficient VRAM first", func() {
|
|
vramNode := &BackendNode{ID: "vram-node", Name: "gpu-box", Address: "10.0.0.10:50051"}
|
|
reg.findVRAMNode = vramNode
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
// Pass non-nil ModelOptions so estimateModelVRAM runs (returns 0 for
|
|
// missing files, so FindNodeWithVRAM won't actually be called unless
|
|
// estimatedVRAM > 0). To trigger VRAM path we need estimatedVRAM > 0,
|
|
// but that requires real files. Instead test the fallback: VRAM returns
|
|
// error, idle succeeds.
|
|
// Actually, estimateModelVRAM returns 0 when model files don't exist,
|
|
// so the VRAM branch is skipped and we go to idle/least-loaded.
|
|
// To properly test VRAM path, we'd need to mock estimateModelVRAM.
|
|
// For now, verify the fallback paths work correctly.
|
|
|
|
// With no real model files, estimatedVRAM=0, so VRAM path is skipped.
|
|
// Set idle node to test that path.
|
|
reg.findVRAMNode = nil
|
|
reg.findVRAMErr = errors.New("no vram nodes")
|
|
idleNode := &BackendNode{ID: "idle-vram", Name: "idle", Address: "10.0.0.11:50051"}
|
|
reg.findIdleNode = idleNode
|
|
|
|
result, err := router.Route(context.Background(), "m1", "models/m1.gguf", "llama-cpp", &pb.ModelOptions{}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.ID).To(Equal("idle-vram"))
|
|
})
|
|
|
|
It("falls back to idle when VRAM search fails", func() {
|
|
reg.findVRAMErr = errors.New("no vram")
|
|
idleNode := &BackendNode{ID: "idle-1", Name: "idle-node", Address: "10.0.0.20:50051"}
|
|
reg.findIdleNode = idleNode
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "m2", "models/m2.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.ID).To(Equal("idle-1"))
|
|
})
|
|
|
|
It("falls back to least-loaded when both VRAM and idle fail", func() {
|
|
reg.findVRAMErr = errors.New("no vram")
|
|
reg.findIdleErr = errors.New("no idle")
|
|
llNode := &BackendNode{ID: "ll-1", Name: "least-loaded", Address: "10.0.0.30:50051"}
|
|
reg.findLeastLoadedNode = llNode
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "m3", "models/m3.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.ID).To(Equal("ll-1"))
|
|
})
|
|
|
|
It("returns error when no nodes are available and no DB for eviction", func() {
|
|
reg.findVRAMErr = errors.New("no vram")
|
|
reg.findIdleErr = errors.New("no idle")
|
|
reg.findLeastLoadedErr = errors.New("no nodes")
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
// DB is nil — evictLRUAndFreeNode will fail because r.db is nil
|
|
})
|
|
|
|
_, err := router.Route(context.Background(), "m4", "models/m4.gguf", "llama-cpp", nil, false)
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("no available nodes"))
|
|
})
|
|
})
|
|
|
|
Describe("UnloadModel (mock-based)", func() {
|
|
It("calls StopBackend and removes the model from the registry", func() {
|
|
reg := &fakeModelRouter{}
|
|
unloader := &fakeUnloader{}
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
})
|
|
|
|
err := router.UnloadModel(context.Background(), "node-1", "model-a")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(unloader.stopCalls).To(ContainElement("node-1:model-a"))
|
|
Expect(reg.removeCalls).To(ContainElement("node-1:model-a"))
|
|
})
|
|
|
|
It("returns error when no unloader is configured", func() {
|
|
reg := &fakeModelRouter{}
|
|
router := NewSmartRouter(reg, SmartRouterOptions{})
|
|
|
|
err := router.UnloadModel(context.Background(), "node-1", "model-a")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("no remote unloader"))
|
|
})
|
|
})
|
|
|
|
Describe("EvictLRU (mock-based)", func() {
|
|
It("finds LRU model and unloads it", func() {
|
|
reg := &fakeModelRouter{
|
|
findLRUModel: &NodeModel{NodeID: "n1", ModelName: "old-model"},
|
|
}
|
|
unloader := &fakeUnloader{}
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
})
|
|
|
|
evicted, err := router.EvictLRU(context.Background(), "n1")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(evicted).To(Equal("old-model"))
|
|
Expect(unloader.stopCalls).To(ContainElement("n1:old-model"))
|
|
Expect(reg.removeCalls).To(ContainElement("n1:old-model"))
|
|
})
|
|
|
|
It("returns error when no LRU model is found", func() {
|
|
reg := &fakeModelRouter{
|
|
findLRUErr: errors.New("no models loaded"),
|
|
}
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: &fakeUnloader{},
|
|
})
|
|
|
|
_, err := router.EvictLRU(context.Background(), "n1")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("finding LRU model"))
|
|
})
|
|
})
|
|
|
|
Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
|
|
var (
|
|
reg *fakeModelRouter
|
|
backend *stubBackend
|
|
factory *stubClientFactory
|
|
unloader *fakeUnloader
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
reg = &fakeModelRouter{
|
|
findAndLockErr: errors.New("not found"),
|
|
}
|
|
backend = &stubBackend{
|
|
loadResult: &pb.Result{Success: true},
|
|
}
|
|
factory = &stubClientFactory{client: backend}
|
|
unloader = &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{
|
|
Success: true,
|
|
Address: "10.0.0.1:9001",
|
|
},
|
|
}
|
|
})
|
|
|
|
It("uses *FromSet methods when model has a node selector", func() {
|
|
gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
|
|
reg.getModelScheduling = &ModelSchedulingConfig{
|
|
ModelName: "selector-model",
|
|
NodeSelector: `{"gpu.vendor":"nvidia"}`,
|
|
}
|
|
reg.findBySelectorNodes = []BackendNode{*gpuNode}
|
|
reg.findIdleFromSetNode = gpuNode
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.ID).To(Equal("gpu-1"))
|
|
})
|
|
|
|
It("returns error when no nodes match selector", func() {
|
|
reg.getModelScheduling = &ModelSchedulingConfig{
|
|
ModelName: "no-match-model",
|
|
NodeSelector: `{"gpu.vendor":"tpu"}`,
|
|
}
|
|
reg.findBySelectorNodes = nil
|
|
reg.findBySelectorErr = nil
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
|
|
})
|
|
|
|
It("uses regular methods when model has no scheduling config", func() {
|
|
reg.getModelScheduling = nil
|
|
idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
|
|
reg.findIdleNode = idleNode
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
})
|
|
|
|
result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.ID).To(Equal("regular-1"))
|
|
})
|
|
})
|
|
|
|
// Route must re-validate a cached placement against the model's node
// selector and fall through to fresh scheduling when it no longer matches.
Describe("Route with selector validation on cached model (mock-based)", func() {
	It("falls through when cached node no longer matches selector", func() {
		staleNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
		freshNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}

		// Backend stub that reports healthy and accepts the model load.
		client := &stubBackend{
			healthResult: true,
			loadResult:   &pb.Result{Success: true},
		}
		clients := &stubClientFactory{client: client}

		// Install on the replacement node succeeds.
		installer := &fakeUnloader{
			installReply: &messaging.BackendInstallReply{
				Success: true,
				Address: "10.0.0.71:9001",
			},
		}

		fakeReg := &fakeModelRouter{
			// The cached lookup hits the old node first.
			findAndLockNode: staleNode,
			findAndLockNM:   &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
			// The model requires an nvidia GPU...
			getModelScheduling: &ModelSchedulingConfig{
				ModelName:    "sel-model",
				NodeSelector: `{"gpu.vendor":"nvidia"}`,
			},
			// ...but the old node is labelled amd, so it fails the selector.
			getNodeLabels: []NodeLabel{
				{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
			},
			// Fresh scheduling resolves the selector to the new node.
			findBySelectorNodes: []BackendNode{*freshNode},
			findIdleFromSetNode: freshNode,
		}

		router := NewSmartRouter(fakeReg, SmartRouterOptions{
			Unloader:      installer,
			ClientFactory: clients,
		})

		routed, routeErr := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
		Expect(routeErr).ToNot(HaveOccurred())
		Expect(routed).ToNot(BeNil())

		// The request ends up on the replacement node.
		Expect(routed.Node.ID).To(Equal("n-new"))
		// The in-flight count taken on the stale node was released.
		Expect(fakeReg.decrementCalls).To(ContainElement("n-old:sel-model"))
	})
})
|
|
|
|
// ScheduleAndLoadModel must refuse to schedule a model it has no stored
// load info for.
Describe("ScheduleAndLoadModel (mock-based)", func() {
	It("returns an error and does not fire a NATS install when no load info is stored", func() {
		// Regression guard for the reconciler scale-up bug. With no replica
		// ever loaded, GetModelLoadInfo returns ErrRecordNotFound; the old
		// fallback then called scheduleNewModel with an empty backend type,
		// which the worker rejected on every reconciler tick. The fixed
		// path returns an explanatory error and never sends backend.install.
		installer := &fakeUnloader{}
		router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{Unloader: installer})

		scheduled, schedErr := router.ScheduleAndLoadModel(context.Background(), "never-loaded", nil)

		Expect(schedErr).To(HaveOccurred())
		Expect(scheduled).To(BeNil())
		Expect(schedErr.Error()).To(ContainSubstring("never-loaded"))
		Expect(installer.installCalls).To(BeEmpty(),
			"reconciler must not fire backend.install when there is no load info to replicate")
	})
})
|
|
|
|
// -----------------------------------------------------------------------
|
|
// Integration tests using real PostgreSQL (existing)
|
|
// -----------------------------------------------------------------------
|
|
// Integration specs for evictLRUAndFreeNode, run against the database
// provided by testutil.SetupTestDB. Skipped on darwin.
Describe("evictLRUAndFreeNode (integration)", func() {
	var db *gorm.DB
	var registry *NodeRegistry

	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()

		created, err := NewNodeRegistry(db)
		registry = created
		Expect(err).ToNot(HaveOccurred())
	})

	It("returns ErrEvictionBusy in under 5 seconds when all models are busy", func() {
		bg := context.Background()

		busyNode := &BackendNode{
			Name:     "busy-evict",
			NodeType: NodeTypeBackend,
			Address:  "10.0.0.100:50051",
		}
		Expect(registry.Register(bg, busyNode, true)).To(Succeed())

		// The node's only model has an in-flight request, so nothing on it
		// is evictable.
		Expect(registry.SetNodeModel(bg, busyNode.ID, "busy-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(bg, busyNode.ID, "busy-model", 0)).To(Succeed())

		router := NewSmartRouter(registry, SmartRouterOptions{DB: db})

		began := time.Now()
		_, evictErr := router.evictLRUAndFreeNode(bg)
		took := time.Since(began)

		Expect(evictErr).To(MatchError(ErrEvictionBusy))
		// Nominal budget is 5 retries * 500ms = 2.5s; the bound is generous.
		Expect(took).To(BeNumerically("<", 5*time.Second))
	})

	It("respects context cancellation", func() {
		bg := context.Background()

		pinnedNode := &BackendNode{
			Name:     "cancel-evict",
			NodeType: NodeTypeBackend,
			Address:  "10.0.0.101:50051",
		}
		Expect(registry.Register(bg, pinnedNode, true)).To(Succeed())
		Expect(registry.SetNodeModel(bg, pinnedNode.ID, "cancel-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.IncrementInFlight(bg, pinnedNode.ID, "cancel-model", 0)).To(Succeed())

		router := NewSmartRouter(registry, SmartRouterOptions{DB: db})

		// Hand the router a context that is already cancelled.
		cancelled, cancel := context.WithCancel(bg)
		cancel()

		began := time.Now()
		_, evictErr := router.evictLRUAndFreeNode(cancelled)
		took := time.Since(began)

		Expect(evictErr).To(HaveOccurred())
		Expect(evictErr.Error()).To(ContainSubstring("context cancelled"))
		// A done context should short-circuit the retry loop.
		Expect(took).To(BeNumerically("<", 2*time.Second))
	})
})
|
|
|
|
// Integration spec for stageModelFiles, run against the database provided
// by testutil.SetupTestDB. Skipped on darwin.
Describe("stageModelFiles (integration)", func() {
	var db *gorm.DB
	var registry *NodeRegistry

	BeforeEach(func() {
		if runtime.GOOS == "darwin" {
			Skip("testcontainers requires Docker, not available on macOS CI")
		}
		db = testutil.SetupTestDB()

		created, err := NewNodeRegistry(db)
		registry = created
		Expect(err).ToNot(HaveOccurred())
	})

	It("does not mutate the original ModelOptions", func() {
		router := NewSmartRouter(registry, SmartRouterOptions{
			FileStager: &fakeFileStager{},
			DB:         db,
		})

		target := &BackendNode{
			ID:      "stage-node-id",
			Name:    "stage-node",
			Address: "10.0.0.200:50051",
		}

		original := &pb.ModelOptions{
			Model:     "test-backend/models/test.gguf",
			ModelFile: "/models/test-backend/models/test.gguf",
			MMProj:    "",
		}

		// Snapshot the fields we care about before staging runs.
		before := struct {
			model, modelFile, mmProj string
		}{
			model:     original.Model,
			modelFile: original.ModelFile,
			mmProj:    original.MMProj,
		}

		// None of these paths exist on disk, so staging is expected to skip
		// them. The result and error are deliberately ignored: this spec
		// only checks that the caller's proto is left untouched.
		_, _ = router.stageModelFiles(context.Background(), target, original, "test-model")

		Expect(original.Model).To(Equal(before.model))
		Expect(original.ModelFile).To(Equal(before.modelFile))
		Expect(original.MMProj).To(Equal(before.mmProj))
	})
})
|
|
|
|
// -----------------------------------------------------------------------
|
|
// narrowByGroupAntiAffinity
|
|
// -----------------------------------------------------------------------
|
|
// Specs for narrowByGroupAntiAffinity: candidate nodes hosting a model that
// conflicts with the requested one are dropped, unless that would leave no
// candidates at all.
Describe("narrowByGroupAntiAffinity", func() {
	var (
		reg      *fakeModelRouter
		resolver *fakeConflictResolver
		router   *SmartRouter
		ctx      context.Context
	)

	BeforeEach(func() {
		ctx = context.Background()
		reg = &fakeModelRouter{}
		resolver = &fakeConflictResolver{conflicts: map[string][]string{}}
		router = NewSmartRouter(reg, SmartRouterOptions{
			ConflictResolver: resolver,
		})
	})

	It("returns the input set unchanged when the model has no conflicts", func() {
		input := []string{"n1", "n2", "n3"}

		narrowed, err := router.narrowByGroupAntiAffinity(ctx, "lonely", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})

	It("removes nodes that already host a conflicting model", func() {
		// "b" conflicts with "a", and "a" lives on n1.
		resolver.conflicts["b"] = []string{"a"}
		reg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
		}

		narrowed, err := router.narrowByGroupAntiAffinity(ctx, "b", []string{"n1", "n2"})

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(ConsistOf("n2"))
	})

	It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() {
		// "a" occupies both candidates, so narrowing would leave nothing.
		resolver.conflicts["b"] = []string{"a"}
		reg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}
		input := []string{"n1", "n2"}

		narrowed, err := router.narrowByGroupAntiAffinity(ctx, "b", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})

	It("removes nodes hosting any of multiple conflicting models", func() {
		resolver.conflicts["c"] = []string{"a", "b"}
		reg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}},
			"b": {{ID: "n2"}},
		}

		narrowed, err := router.narrowByGroupAntiAffinity(ctx, "c", []string{"n1", "n2", "n3"})

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(ConsistOf("n3"))
	})

	It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() {
		resolver.conflicts["b"] = []string{"a"}
		reg.findNodesWithModelByName = map[string][]BackendNode{
			"a": {{ID: "n1"}, {ID: "n2"}},
		}

		narrowed, err := router.narrowByGroupAntiAffinity(ctx, "b", nil)

		Expect(err).ToNot(HaveOccurred())
		// A nil input means "any healthy node" to the caller, so it must
		// come back as nil; hard-narrowing it would silently exclude every
		// other node.
		Expect(narrowed).To(BeNil())
	})

	It("is a no-op when no resolver is configured", func() {
		unconfigured := NewSmartRouter(reg, SmartRouterOptions{})
		input := []string{"n1", "n2"}

		narrowed, err := unconfigured.narrowByGroupAntiAffinity(ctx, "b", input)

		Expect(err).ToNot(HaveOccurred())
		Expect(narrowed).To(Equal(input))
	})
})
|
|
|
|
Describe("installBackendOnNode singleflight", func() {
|
|
It("coalesces concurrent identical installs into one NATS call", func() {
|
|
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
|
|
|
|
// Slow install reply so concurrent calls overlap deterministically.
|
|
started := make(chan struct{}, 5)
|
|
release := make(chan struct{})
|
|
unloader := &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
|
|
}
|
|
unloader.installHook = func() {
|
|
started <- struct{}{}
|
|
<-release
|
|
}
|
|
|
|
router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: &stubClientFactory{client: &stubBackend{}},
|
|
})
|
|
|
|
// Fire 5 concurrent identical installBackendOnNode calls.
|
|
done := make(chan error, 5)
|
|
for i := 0; i < 5; i++ {
|
|
go func() {
|
|
_, err := router.installBackendOnNode(context.Background(), node, "llama-cpp", "my-model", 0)
|
|
done <- err
|
|
}()
|
|
}
|
|
|
|
// Only ONE call should have entered the unloader hook (the
|
|
// singleflight leader). The other 4 are coalesced and waiting on
|
|
// the leader's result.
|
|
Eventually(started).Should(Receive())
|
|
Consistently(started, 100*time.Millisecond).ShouldNot(Receive())
|
|
|
|
// Release the leader; the other 4 callers receive the same result.
|
|
close(release)
|
|
for i := 0; i < 5; i++ {
|
|
Expect(<-done).ToNot(HaveOccurred())
|
|
}
|
|
Expect(unloader.installCalls).To(HaveLen(1),
|
|
"singleflight should coalesce 5 concurrent identical loads into 1 NATS call")
|
|
})
|
|
|
|
It("does NOT coalesce installs for different (modelID, replica) keys", func() {
|
|
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
|
|
unloader := &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
|
|
}
|
|
router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: &stubClientFactory{client: &stubBackend{}},
|
|
})
|
|
|
|
_, err1 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 0)
|
|
_, err2 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-B", 0)
|
|
_, err3 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 1)
|
|
Expect(err1).ToNot(HaveOccurred())
|
|
Expect(err2).ToNot(HaveOccurred())
|
|
Expect(err3).ToNot(HaveOccurred())
|
|
Expect(unloader.installCalls).To(HaveLen(3))
|
|
})
|
|
})
|
|
})
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
type observeRecord struct {
|
|
model string
|
|
chain []uint64
|
|
key prefixcache.ReplicaKey
|
|
}
|
|
|
|
type invalidateRecord struct {
|
|
model string
|
|
key prefixcache.ReplicaKey
|
|
}
|
|
|
|
// fakePrefixProvider records all interactions and returns a configurable
|
|
// decision.
|
|
type fakePrefixProvider struct {
|
|
decideCalls int
|
|
observed []observeRecord
|
|
invalidated []invalidateRecord
|
|
invalidatedNode []string
|
|
decision prefixcache.PrefixDecision
|
|
}
|
|
|
|
func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision {
|
|
f.decideCalls++
|
|
return f.decision
|
|
}
|
|
|
|
func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool {
|
|
f.observed = append(f.observed, observeRecord{model: model, chain: append([]uint64(nil), chain...), key: key})
|
|
return true
|
|
}
|
|
|
|
func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) {
|
|
f.invalidated = append(f.invalidated, invalidateRecord{model: model, key: key})
|
|
}
|
|
|
|
func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) {
|
|
f.invalidatedNode = append(f.invalidatedNode, model+":"+nodeID)
|
|
}
|
|
|
|
// Evict is a no-op; the fake records nothing for it.
func (f *fakePrefixProvider) Evict(_ time.Time) {
}
|
|
|
|
var _ = Describe("SmartRouter prefix-cache routing", func() {
|
|
var (
|
|
backend *stubBackend
|
|
factory *stubClientFactory
|
|
unloader *fakeUnloader
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
backend = &stubBackend{healthResult: true}
|
|
factory = &stubClientFactory{client: backend}
|
|
unloader = &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"},
|
|
}
|
|
})
|
|
|
|
// loadedReg builds a fake registry with one loaded healthy replica for
|
|
// "m" on node "X", plus matching replica stats so buildPreference can run.
|
|
loadedReg := func() *fakeModelRouter {
|
|
node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
|
|
nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"}
|
|
return &fakeModelRouter{
|
|
findAndLockNode: node,
|
|
findAndLockNM: nm,
|
|
getModelScheduling: &ModelSchedulingConfig{
|
|
RoutePolicy: "prefix_cache",
|
|
},
|
|
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
|
"m": {{NodeID: "X", InFlight: 0}},
|
|
},
|
|
}
|
|
}
|
|
|
|
Context("nil provider (round-robin floor)", func() {
|
|
It("passes a nil preference and never decides or observes", func() {
|
|
reg := loadedReg()
|
|
router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
|
|
for _, p := range reg.findAndLockPrefs {
|
|
Expect(p).To(BeNil())
|
|
}
|
|
})
|
|
})
|
|
|
|
Context("with a provider", func() {
|
|
It("passes the decided node as the preference and observes the pick", func() {
|
|
reg := loadedReg()
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(prov.decideCalls).To(BeNumerically(">=", 1))
|
|
Expect(reg.findAndLockPrefs[0]).ToNot(BeNil())
|
|
Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X"))
|
|
Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0))
|
|
Expect(prov.observed).To(HaveLen(1))
|
|
Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
|
Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3}))
|
|
})
|
|
|
|
It("routes a recurring prefix back to the previously observed node", func() {
|
|
// Real Index as the provider: first request observes X, second
|
|
// request with the same chain must yield PreferredNodeID == X.
|
|
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
|
reg := loadedReg()
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: idx,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
// First request landed on X (cold placement on the only candidate)
|
|
// and observed the prefix there.
|
|
dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now())
|
|
Expect(dFirst.HasHot).To(BeTrue())
|
|
Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
|
|
|
// Second request, same chain: X is now the warm-cache hot match, so
|
|
// the preference must point at it.
|
|
_, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1]
|
|
Expect(last).ToNot(BeNil())
|
|
Expect(last.PreferredNodeID).To(Equal("X"))
|
|
Expect(last.PreferredReplica).To(Equal(0))
|
|
})
|
|
|
|
It("prefers the exact hot replica when two replicas share a node", func() {
|
|
// Two replicas of "m" live on the SAME node X: replica 0 and replica
|
|
// 1. A hot prefix observed on (X,0) must produce a preference that
|
|
// locks replica 0 specifically, NOT the sibling replica 1 on the same
|
|
// node. This is the replica-granular regression this change fixes.
|
|
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
|
node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
|
|
nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"}
|
|
reg := &fakeModelRouter{
|
|
findAndLockNode: node,
|
|
findAndLockNM: nm,
|
|
getModelScheduling: &ModelSchedulingConfig{
|
|
RoutePolicy: "prefix_cache",
|
|
},
|
|
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
|
"m": {
|
|
{NodeID: "X", ReplicaIndex: 0, InFlight: 0},
|
|
{NodeID: "X", ReplicaIndex: 1, InFlight: 0},
|
|
},
|
|
},
|
|
}
|
|
// Seed the index so (X,0) is the warm replica for this chain.
|
|
idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now())
|
|
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: idx,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
pref := reg.findAndLockPrefs[0]
|
|
Expect(pref).ToNot(BeNil())
|
|
Expect(pref.PreferredNodeID).To(Equal("X"))
|
|
Expect(pref.PreferredReplica).To(Equal(0),
|
|
"the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen")
|
|
})
|
|
|
|
It("does not decide or observe when no prefix chain is present", func() {
|
|
reg := loadedReg()
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
_, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(prov.decideCalls).To(Equal(0))
|
|
Expect(prov.observed).To(BeEmpty())
|
|
Expect(reg.findAndLockPrefs[0]).To(BeNil())
|
|
})
|
|
|
|
It("does not observe for round-robin models even with a chain", func() {
|
|
reg := loadedReg()
|
|
reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"}
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(prov.decideCalls).To(Equal(0))
|
|
Expect(prov.observed).To(BeEmpty())
|
|
Expect(reg.findAndLockPrefs[0]).To(BeNil())
|
|
})
|
|
})
|
|
|
|
Context("forced-disturb pressure", func() {
|
|
// disturbReg builds a registry with two candidate replicas for "m":
|
|
// the hot node X is saturated (high in_flight) and Y is free. Select
|
|
// will therefore reject the hot node and pick Y, which is the
|
|
// forced-disturb signal. findAndLockNode returns Y so Route succeeds.
|
|
disturbReg := func() *fakeModelRouter {
|
|
nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"}
|
|
nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"}
|
|
return &fakeModelRouter{
|
|
findAndLockNode: nodeY,
|
|
findAndLockNM: nm,
|
|
getModelScheduling: &ModelSchedulingConfig{
|
|
RoutePolicy: "prefix_cache",
|
|
},
|
|
loadedReplicaStatsByName: map[string][]ReplicaCandidate{
|
|
"m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}},
|
|
},
|
|
}
|
|
}
|
|
|
|
It("records pressure when a strong hot match was forced off the warm node", func() {
|
|
reg := disturbReg()
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
|
Hot: prefixcache.ReplicaKey{NodeID: "X"},
|
|
HasHot: true,
|
|
MatchRatio: 1.0,
|
|
ColdOrder: []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}},
|
|
}}
|
|
pressure := prefixcache.NewPressure(time.Minute)
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
Pressure: pressure,
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0),
|
|
"hot match existed but the load guard forced us off X: must record pressure")
|
|
})
|
|
|
|
It("does not record pressure when the hot node is itself eligible", func() {
|
|
reg := loadedReg() // single node X, in_flight 0 → X stays eligible
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
|
Hot: prefixcache.ReplicaKey{NodeID: "X"},
|
|
HasHot: true,
|
|
MatchRatio: 1.0,
|
|
}}
|
|
pressure := prefixcache.NewPressure(time.Minute)
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
Pressure: pressure,
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(pressure.Count("m", time.Now())).To(Equal(0),
|
|
"chosen == hot node, no disturb")
|
|
})
|
|
|
|
It("does not record pressure for an all-unique workload with no hot match", func() {
|
|
reg := loadedReg()
|
|
prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
|
|
HasHot: false, // no prefix match at all
|
|
MatchRatio: 0,
|
|
ColdOrder: []prefixcache.ReplicaKey{{NodeID: "X"}},
|
|
}}
|
|
pressure := prefixcache.NewPressure(time.Minute)
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: prov,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
Pressure: pressure,
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
Expect(pressure.Count("m", time.Now())).To(Equal(0),
|
|
"no hot match means no cache to disturb: must not false-positive")
|
|
})
|
|
})
|
|
|
|
Context("removal chokepoint on unload", func() {
|
|
It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() {
|
|
idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
|
|
reg := loadedReg()
|
|
router := NewSmartRouter(reg, SmartRouterOptions{
|
|
Unloader: unloader,
|
|
ClientFactory: factory,
|
|
PrefixProvider: idx,
|
|
PrefixConfig: prefixcache.DefaultConfig(),
|
|
})
|
|
|
|
ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6})
|
|
// Warm the cache: X now holds the prefix.
|
|
_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
|
|
|
|
// UnloadModel must route the eviction through the registry removal
|
|
// chokepoint (RemoveAllNodeModelReplicas). The registry's
|
|
// SetReplicaRemovedHook is what invalidates the prefix index in
|
|
// production; the router no longer invalidates directly. Here the
|
|
// fake registry records the removal but fires no hook, so we assert
|
|
// the chokepoint is exercised rather than the downstream
|
|
// invalidation (covered by the registry hook integration tests).
|
|
Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed())
|
|
Expect(reg.removeCalls).To(ContainElement("X:m"),
|
|
"UnloadModel must remove the replica via the registry removal chokepoint")
|
|
})
|
|
})
|
|
})
|