mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-12 10:47:23 -04:00
* fix(galleryop): self-evict terminal ops from OpCache.GetStatus

The processingBackends map (the UI 'reinstalling' spinner source) only cleared an op when a client polled /api/backends/job/:uid. The Manage-page Reinstall and Upgrade buttons never poll, so completed installs leaked into processingBackends forever and the backend card spun 'reinstalling' even though the install had finished.

Evict terminal ops on the list read instead; DeleteUUID already broadcasts the eviction so peer replicas converge.

Reproduced on a live 5-node distributed cluster: 5 backends sat in processingBackends with underlying jobs reporting completed:true,progress:100.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(nodes): clear pending backend ops behind offline/draining nodes

ListDuePendingBackendOps filters status=healthy, so a backend op queued against a node that went offline (stale heartbeat) or draining (admin action) was never retried, aged out, or deleted - it leaked forever and kept the UI operation spinning.

Add DeleteStalePendingBackendOps and run it each reconcile pass: draining nodes are cleared immediately (model rows already purged), offline nodes once their heartbeat is older than a grace window (blip protection).

Reproduced on a live cluster: orphaned llama-cpp install rows targeting an offline (nvidia-thor) and a draining (mac-mini-m4) node sat at attempts=0 indefinitely.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(nodes): stream per-node progress during backend upgrade

The install dispatch subscribed to a per-op progress subject and streamed per-node download ticks; the upgrade dispatch did a bare 15-minute blocking NATS round-trip with no subscription, so the UI showed progress:0 the whole time (the 'reinstalling but nothing happens' report on a slow node).

Thread the op ID through BackendManager.UpgradeBackend -> the distributed manager -> the adapter, and have the adapter subscribe to the per-op progress subject before the request (extracted into a shared subscribeProgress helper reused by install/upgrade/force-fallback). The worker's upgradeBackend now creates the same DebouncedInstallProgressPublisher installBackend uses.

An upgrade is a force-reinstall, so it reuses SubjectNodeBackendInstallProgress rather than minting a new subject - no new NATS permission, no new rolling-update compat surface. Reconciler-driven retries pass empty opID/onProgress and stay on the silent path.

Reproduced on a live cluster: upgrade of llama-cpp-development on agx-orin-slow sat at progress:0 for 4+ minutes with no per-node feedback.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(galleryop): persist cancellation + periodically reap orphaned ops

Two distributed gaps surfaced when a replica was killed mid-upgrade on a live cluster, leaving the backend stuck 'processing' in the UI forever:

1. CancelOperation flipped the in-memory status to cancelled and broadcast a NATS event but never persisted the terminal status. On the next replica restart the still-active row re-hydrated straight back into processingBackends and the UI spun again. It now calls store.Cancel(id) so the cancel survives a restart.

2. CleanStale (which marks abandoned active ops failed) only ran once on startup, so an op orphaned AFTER startup - its owning replica's foreground handler goroutine gone - was never reaped until the next restart. Add GalleryService.ReapStaleOperations and run it on a 15m ticker (CleanStale now returns the reaped count for observability).

Neither is covered by the OpCache self-evict fix: an orphaned op never reaches Processed, so it would never self-evict.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(review): address self-review findings on the distributed install fixes

Three findings from an adversarial review of this branch:

1. CRITICAL - OpCache.GetStatus crashed under concurrent load. m.Map() returns the live internal map by reference, so deleting from it on the read path was an unsynchronized write to a map four HTTP handlers poll every ~1s -> a 'concurrent map writes' fatal. Rewritten to iterate a Keys() snapshot, build a fresh result map, and apply evictions via the locked DeleteUUID after the loop. Added a -race concurrency regression guard.

2. HIGH - GetStatus evicted failed ops too, hiding them from /api/operations and breaking the dismiss-failed-op flow (the panel keeps Error != nil ops so the admin can read the error and click Dismiss). Eviction now fires only for terminal ops with Error == nil (success/cancelled); failures are retained.

3. MEDIUM - DeleteStalePendingBackendOps missed StatusUnhealthy nodes. A node marked unhealthy on a NATS ErrNoResponders never transitions to offline (health.go skips re-marking it), so its pending ops leaked exactly like the offline case. Unhealthy is now reaped via the same stale-heartbeat grace path (a fresh-heartbeat node is recovering and keeps its op).

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(review-2): don't evict the still-installing soft-path; don't spin on failed ops

Second review pass found two issues:

1. MEDIUM (Go) - OpCache.GetStatus evicted the ErrWorkerStillInstalling soft-path op. That op is deliberately Processed=true with no error to show a yellow in-progress state when a worker timed out the NATS round-trip but is still installing in the background; the reconciler confirms the real outcome later. Evicting it (and broadcasting OpEnd + marking the DB completed) hid an install that may still fail. Eviction is now scoped to a clean success (progress 100 + 'completed', matching the job-poll's historical condition) or a cancellation - the soft-path (progress != 100) and failures are kept.

2. MEDIUM (React) - the Backends gallery card rendered ANY operation as an 'Installing...' spinner, so a failed op (now intentionally kept in the list for the OperationsBar error + Dismiss) spun forever. Exclude errored ops from the card spinner, mirroring Models.jsx (isInstalling already excludes op.error). The error + Dismiss still surface in the global OperationsBar.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(ui): refresh Manage backends table when an operation settles

The Manage backends table fetched installed backends only on mount/after delete and checked upgrades only on tab activation. After a reinstall/upgrade completed neither re-ran, so the installed-version cell and the 'update available' badge stayed stale until the user switched tabs - the op looked like it 'did nothing'.

Watch the operations list (via useOperations) and re-fetch installed backends + available upgrades whenever the count settles, mirroring the operations.length watch Backends.jsx already uses. Consolidates the prior tab-activation upgrades check into the same effect.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
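The eviction rule the commit message arrives at (iterate a key snapshot, build a fresh result map, delete through the locked path after the loop, and evict only clean successes and cancellations) can be sketched as follows. This is a minimal illustration, not the LocalAI implementation: opStatus, opCache, keys, get, deleteUUID, evictable and newDemoCache are hypothetical names, and leaving evicted ops out of the returned map is an assumption.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// opStatus is a hypothetical, trimmed-down stand-in for an operation status.
type opStatus struct {
	Processed bool
	Cancelled bool
	Progress  float64
	Err       error
}

// opCache is a hypothetical cache; every access to ops goes through mu.
type opCache struct {
	mu  sync.Mutex
	ops map[string]*opStatus
}

// keys returns a snapshot of the IDs so callers never range over the live map.
func (c *opCache) keys() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	ids := make([]string, 0, len(c.ops))
	for id := range c.ops {
		ids = append(ids, id)
	}
	return ids
}

func (c *opCache) get(id string) (*opStatus, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	s, ok := c.ops[id]
	return s, ok
}

// deleteUUID removes an op under the lock.
func (c *opCache) deleteUUID(id string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.ops, id)
}

// evictable is true only for a cancellation or a clean success (processed,
// progress 100, no error). Failures and the still-installing soft path
// (processed but progress != 100) are kept.
func evictable(s *opStatus) bool {
	if s.Err != nil {
		return false
	}
	return s.Cancelled || (s.Processed && s.Progress == 100)
}

// getStatus builds a fresh result map from a key snapshot and applies the
// evictions only after the loop, so the read path never mutates a map that
// another goroutine may be iterating.
func (c *opCache) getStatus() map[string]*opStatus {
	result := make(map[string]*opStatus)
	var evict []string
	for _, id := range c.keys() {
		s, ok := c.get(id)
		if !ok {
			continue // removed concurrently since the snapshot
		}
		if evictable(s) {
			evict = append(evict, id)
			continue
		}
		result[id] = s
	}
	for _, id := range evict {
		c.deleteUUID(id)
	}
	return result
}

// newDemoCache seeds one op of each kind discussed in the commit message.
func newDemoCache() *opCache {
	return &opCache{ops: map[string]*opStatus{
		"done":       {Processed: true, Progress: 100},
		"cancelled":  {Cancelled: true},
		"failed":     {Processed: true, Err: errors.New("download failed")},
		"installing": {Processed: true, Progress: 40},
	}}
}

func main() {
	c := newDemoCache()
	status := c.getStatus()
	fmt.Println(len(status), len(c.keys())) // prints "2 2"
}
```

Only the failed and still-installing ops survive the read; the clean success and the cancellation are evicted after the loop rather than during iteration.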
1428 lines
50 KiB
Go
package nodes

import (
	"context"
	"errors"
	"fmt"
	"runtime"
	"sync"
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/core/services/nodes/prefixcache"
	"github.com/mudler/LocalAI/core/services/testutil"
	"github.com/mudler/LocalAI/pkg/distributedhdr"
	grpc "github.com/mudler/LocalAI/pkg/grpc"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	ggrpc "google.golang.org/grpc"
	"gorm.io/gorm"
)

// ---------------------------------------------------------------------------
// Fake FileStager (pre-existing)
// ---------------------------------------------------------------------------

// fakeFileStager is a minimal FileStager that records calls and returns
// predictable remote paths without touching the filesystem or network.
type fakeFileStager struct {
	ensureCalls []ensureCall
}

type ensureCall struct {
	nodeID, localPath, key string
}

func (f *fakeFileStager) EnsureRemote(_ context.Context, nodeID, localPath, key string) (string, error) {
	f.ensureCalls = append(f.ensureCalls, ensureCall{nodeID, localPath, key})
	return "/remote/" + key, nil
}

func (f *fakeFileStager) FetchRemote(_ context.Context, _, _, _ string) error { return nil }

func (f *fakeFileStager) FetchRemoteByKey(_ context.Context, _, _, _ string) error { return nil }

func (f *fakeFileStager) AllocRemoteTemp(_ context.Context, _ string) (string, error) {
	return "/remote/tmp", nil
}

func (f *fakeFileStager) StageRemoteToStore(_ context.Context, _, _, _ string) error { return nil }

func (f *fakeFileStager) ListRemoteDir(_ context.Context, _, _ string) ([]string, error) {
	return nil, nil
}

// ---------------------------------------------------------------------------
// Fake ModelRouter
// ---------------------------------------------------------------------------

// fakeModelRouter implements ModelRouter with configurable return values.
type fakeModelRouter struct {
	// FindAndLockNodeWithModel returns
	findAndLockNode *BackendNode
	findAndLockNM   *NodeModel
	findAndLockErr  error

	// FindNodeWithVRAM returns
	findVRAMNode *BackendNode
	findVRAMErr  error

	// FindIdleNode returns
	findIdleNode *BackendNode
	findIdleErr  error

	// FindLeastLoadedNode returns
	findLeastLoadedNode *BackendNode
	findLeastLoadedErr  error

	// FindGlobalLRUModelWithZeroInFlight returns
	findGlobalLRUModel *NodeModel
	findGlobalLRUErr   error

	// FindLRUModel returns
	findLRUModel *NodeModel
	findLRUErr   error

	// Get returns
	getNode *BackendNode
	getErr  error

	// GetModelScheduling returns
	getModelScheduling *ModelSchedulingConfig
	getModelSchedErr   error

	// FindNodesBySelector returns
	findBySelectorNodes []BackendNode
	findBySelectorErr   error

	// *FromSet variants
	findVRAMFromSetNode        *BackendNode
	findVRAMFromSetErr         error
	findIdleFromSetNode        *BackendNode
	findIdleFromSetErr         error
	findLeastLoadedFromSetNode *BackendNode
	findLeastLoadedFromSetErr  error

	// GetNodeLabels returns
	getNodeLabels    []NodeLabel
	getNodeLabelsErr error

	// FindNodesWithModel returns (keyed by model name)
	findNodesWithModelByName map[string][]BackendNode
	findNodesWithModelErr    error

	// LoadedReplicaStats returns (keyed by model name)
	loadedReplicaStatsByName map[string][]ReplicaCandidate
	loadedReplicaStatsErr    error

	// Track calls for assertions
	decrementCalls []string // "nodeID:modelName"
	incrementCalls []string
	removeCalls    []string
	setCalls       []string
	touchCalls     []string

	// Preferences passed to FindAndLockNodeWithModel, in call order. nil
	// entries are recorded too, so tests can assert "preference was nil".
	findAndLockPrefs []*RoutePreference
}

func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string, _ []string, pref *RoutePreference) (*BackendNode, *NodeModel, error) {
	f.findAndLockPrefs = append(f.findAndLockPrefs, pref)
	return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
}

func (f *fakeModelRouter) LoadedReplicaStats(_ context.Context, modelName string, _ []string) ([]ReplicaCandidate, error) {
	if f.loadedReplicaStatsErr != nil {
		return nil, f.loadedReplicaStatsErr
	}
	return f.loadedReplicaStatsByName[modelName], nil
}

func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
	f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
	return nil
}

func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
	f.incrementCalls = append(f.incrementCalls, nodeID+":"+modelName)
	return nil
}

func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
	f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
	return nil
}

func (f *fakeModelRouter) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
	// Same recorded key as RemoveNodeModel so existing tests that assert "the
	// model was removed" don't need to know whether the production code used
	// the per-replica or all-replicas variant.
	f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
	return nil
}

func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string, _ int) {
	f.touchCalls = append(f.touchCalls, nodeID+":"+modelName)
}

func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName string, _ int, state, address string, _ int) error {
	f.setCalls = append(f.setCalls, fmt.Sprintf("%s:%s:%s:%s", nodeID, modelName, state, address))
	return nil
}

func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
	return nil
}

func (f *fakeModelRouter) UpsertModelLoadInfo(_ context.Context, _, _ string, _ []byte) error {
	return nil
}

func (f *fakeModelRouter) GetModelLoadInfo(_ context.Context, _ string) (string, []byte, error) {
	return "", nil, fmt.Errorf("not found")
}

func (f *fakeModelRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
	return 0, nil
}

func (f *fakeModelRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
	return 0, nil
}

func (f *fakeModelRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
	return f.findVRAMNode, f.findVRAMErr
}

func (f *fakeModelRouter) FindIdleNode(_ context.Context) (*BackendNode, error) {
	return f.findIdleNode, f.findIdleErr
}

func (f *fakeModelRouter) FindLeastLoadedNode(_ context.Context) (*BackendNode, error) {
	return f.findLeastLoadedNode, f.findLeastLoadedErr
}

func (f *fakeModelRouter) FindGlobalLRUModelWithZeroInFlight(_ context.Context) (*NodeModel, error) {
	return f.findGlobalLRUModel, f.findGlobalLRUErr
}

func (f *fakeModelRouter) FindLRUModel(_ context.Context, _ string) (*NodeModel, error) {
	return f.findLRUModel, f.findLRUErr
}

func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error) {
	return f.getNode, f.getErr
}

func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
	return f.getModelScheduling, f.getModelSchedErr
}

func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
	return f.findBySelectorNodes, f.findBySelectorErr
}

func (f *fakeModelRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
	// Default: same answer as FindNodesBySelector. Tests that need a
	// specific filter can override by reusing findBySelectorNodes.
	return f.findBySelectorNodes, f.findBySelectorErr
}

func (f *fakeModelRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
	return nil
}

func (f *fakeModelRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
	return nil
}

func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
	return f.findVRAMFromSetNode, f.findVRAMFromSetErr
}

func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return f.findIdleFromSetNode, f.findIdleFromSetErr
}

func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
	return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
}

func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
	return f.getNodeLabels, f.getNodeLabelsErr
}

func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) {
	if f.findNodesWithModelErr != nil {
		return nil, f.findNodesWithModelErr
	}
	return f.findNodesWithModelByName[modelName], nil
}

// fakeConflictResolver implements ConcurrencyConflictResolver from a static map.
type fakeConflictResolver struct {
	conflicts map[string][]string
}

func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string {
	if f == nil {
		return nil
	}
	return f.conflicts[name]
}

// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------

// stubBackend implements grpc.Backend with configurable HealthCheck and LoadModel.
type stubBackend struct {
	grpc.Backend // embed to satisfy interface; unused methods will panic if called

	healthResult bool
	healthErr    error
	loadResult   *pb.Result
	loadErr      error
}

func (f *stubBackend) HealthCheck(_ context.Context) (bool, error) {
	return f.healthResult, f.healthErr
}

func (f *stubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return f.loadResult, f.loadErr
}

func (f *stubBackend) IsBusy() bool { return false }

// stubClientFactory returns the same stubBackend for every call.
type stubClientFactory struct {
	client *stubBackend
}

func (f *stubClientFactory) NewClient(_ string, _ bool) grpc.Backend {
	return f.client
}

// ---------------------------------------------------------------------------
// Fake NodeCommandSender (unloader)
// ---------------------------------------------------------------------------

type fakeUnloader struct {
	// mu guards installCalls and upgradeCalls so concurrent test
	// goroutines (e.g. singleflight specs) don't race the slice appends.
	mu sync.Mutex

	installReply *messaging.BackendInstallReply
	installErr   error
	installCalls []installCall // every InstallBackend invocation, in order
	// installHook, if non-nil, runs at the start of InstallBackend before
	// the call is recorded. Used by concurrency tests as a deterministic
	// "block here" seam - set installHook to a function that sleeps or
	// blocks on a channel to overlap two callers.
	installHook func()

	upgradeReply *messaging.BackendUpgradeReply
	upgradeErr   error
	upgradeCalls []upgradeCall // every UpgradeBackend invocation, in order

	stopCalls   []string // "nodeID:model"
	stopErr     error
	unloadCalls []string
	unloadErr   error
}

// installCall captures the args we care about when asserting that the
// reconciler / router did or did not fire a NATS install. The fake records
// every call so tests can verify both presence and shape (e.g. that backend
// is non-empty).
type installCall struct {
	nodeID  string
	backend string
	modelID string
	replica int
}

type upgradeCall struct {
	nodeID  string
	backend string
	replica int
}

func (f *fakeUnloader) InstallBackend(nodeID, backend, modelID, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendInstallReply, error) {
	// installHook intentionally runs OUTSIDE the mutex: the hook may block
	// on a channel and we don't want to serialize concurrent callers,
	// which would defeat the singleflight-overlap test.
	if f.installHook != nil {
		f.installHook()
	}
	f.mu.Lock()
	f.installCalls = append(f.installCalls, installCall{nodeID, backend, modelID, replica})
	f.mu.Unlock()
	return f.installReply, f.installErr
}

func (f *fakeUnloader) UpgradeBackend(nodeID, backend, _, _, _, _ string, replica int, _ string, _ func(messaging.BackendInstallProgressEvent)) (*messaging.BackendUpgradeReply, error) {
	f.mu.Lock()
	f.upgradeCalls = append(f.upgradeCalls, upgradeCall{nodeID, backend, replica})
	f.mu.Unlock()
	return f.upgradeReply, f.upgradeErr
}

func (f *fakeUnloader) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
	return &messaging.BackendDeleteReply{Success: true}, nil
}

func (f *fakeUnloader) ListBackends(_ string) (*messaging.BackendListReply, error) {
	return &messaging.BackendListReply{}, nil
}

func (f *fakeUnloader) StopBackend(nodeID, backend string) error {
	f.stopCalls = append(f.stopCalls, nodeID+":"+backend)
	return f.stopErr
}

func (f *fakeUnloader) UnloadModelOnNode(nodeID, modelName string) error {
	f.unloadCalls = append(f.unloadCalls, nodeID+":"+modelName)
	return f.unloadErr
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

var _ = Describe("SmartRouter", func() {
	// -----------------------------------------------------------------------
	// Unit tests using mock interfaces (no DB required)
	// -----------------------------------------------------------------------
	Describe("Route (mock-based)", func() {
		var (
			reg      *fakeModelRouter
			backend  *stubBackend
			factory  *stubClientFactory
			unloader *fakeUnloader
		)

		BeforeEach(func() {
			reg = &fakeModelRouter{}
			backend = &stubBackend{}
			factory = &stubClientFactory{client: backend}
			unloader = &fakeUnloader{
				installReply: &messaging.BackendInstallReply{
					Success: true,
					Address: "10.0.0.1:9001",
				},
			}
		})

		Context("model already loaded on a healthy node", func() {
			It("returns the client and a release function", func() {
				node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
				nm := &NodeModel{NodeID: "n1", ModelName: "my-model", Address: "10.0.0.1:9001"}
				reg.findAndLockNode = node
				reg.findAndLockNM = nm
				backend.healthResult = true

				router := NewSmartRouter(reg, SmartRouterOptions{
					Unloader:      unloader,
					ClientFactory: factory,
				})

				result, err := router.Route(context.Background(), "my-model", "models/my-model.gguf", "llama-cpp", nil, false)
				Expect(err).ToNot(HaveOccurred())
				Expect(result).ToNot(BeNil())
				Expect(result.Node.ID).To(Equal("n1"))

				// TouchNodeModel should have been called
				Expect(reg.touchCalls).To(ContainElement("n1:my-model"))

				// The initial in-flight reservation from FindAndLockNodeWithModel is released
				// after the first inference call completes via OnFirstComplete callback.
				// Release only closes the client.
				result.Release()
				// No decrement on Release - it happens via OnFirstComplete after first Predict
				Expect(reg.decrementCalls).To(BeEmpty())
			})
		})

		Context("model not loaded, falls through to scheduling", func() {
			It("schedules on an idle node and records the model", func() {
				// FindAndLockNodeWithModel always fails - simulates no cached model
				// (equivalent to the health-check-failure fallthrough path).
				idleNode := &BackendNode{ID: "n2", Name: "idle-node", Address: "10.0.0.2:50051"}
				reg2 := &fakeModelRouter{
					findAndLockErr: errors.New("not found"),
					findIdleNode:   idleNode,
				}
				backend.loadResult = &pb.Result{Success: true}

				router := NewSmartRouter(reg2, SmartRouterOptions{
					Unloader:      unloader,
					ClientFactory: factory,
				})

				result, err := router.Route(context.Background(), "some-model", "models/some-model.gguf", "llama-cpp", nil, false)
				Expect(err).ToNot(HaveOccurred())
				Expect(result).ToNot(BeNil())
				Expect(result.Node.ID).To(Equal("n2"))

				// SetNodeModel should record the model as loaded on the node
				Expect(reg2.setCalls).To(HaveLen(1))
				Expect(reg2.setCalls[0]).To(ContainSubstring("n2:some-model:loaded"))
			})
		})

		Context("model not loaded, no DB (advisory lock bypassed)", func() {
			It("schedules on an available node via FindIdleNode", func() {
				reg.findAndLockErr = errors.New("not found")
				idleNode := &BackendNode{ID: "n3", Name: "idle", Address: "10.0.0.3:50051"}
				reg.findIdleNode = idleNode
				backend.loadResult = &pb.Result{Success: true}

				router := NewSmartRouter(reg, SmartRouterOptions{
					Unloader:      unloader,
					ClientFactory: factory,
					// DB is nil - no advisory lock
				})

				result, err := router.Route(context.Background(), "new-model", "models/new.gguf", "llama-cpp", nil, false)
				Expect(err).ToNot(HaveOccurred())
				Expect(result.Node.ID).To(Equal("n3"))
			})
		})
	})

	Describe("scheduleNewModel (mock-based, via Route)", func() {
		var (
			reg      *fakeModelRouter
			backend  *stubBackend
			factory  *stubClientFactory
			unloader *fakeUnloader
		)

		BeforeEach(func() {
			reg = &fakeModelRouter{
				findAndLockErr: errors.New("not found"),
			}
			backend = &stubBackend{
				loadResult: &pb.Result{Success: true},
			}
			factory = &stubClientFactory{client: backend}
			unloader = &fakeUnloader{
				installReply: &messaging.BackendInstallReply{
					Success: true,
					Address: "10.0.0.1:9001",
				},
			}
		})

		It("finds a node with sufficient VRAM first", func() {
			vramNode := &BackendNode{ID: "vram-node", Name: "gpu-box", Address: "10.0.0.10:50051"}
			reg.findVRAMNode = vramNode

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			// Non-nil ModelOptions are passed so estimateModelVRAM runs, but it
			// returns 0 when the model files don't exist. With estimatedVRAM == 0
			// the VRAM branch is skipped and FindNodeWithVRAM is never called, so
			// scheduling goes straight to the idle/least-loaded path. Exercising
			// the VRAM branch itself would require mocking estimateModelVRAM.
			//
			// For now this spec only verifies the fallback: clear the VRAM node,
			// make the VRAM lookup fail, and expect the idle node to be chosen.
			reg.findVRAMNode = nil
			reg.findVRAMErr = errors.New("no vram nodes")
			idleNode := &BackendNode{ID: "idle-vram", Name: "idle", Address: "10.0.0.11:50051"}
			reg.findIdleNode = idleNode

			result, err := router.Route(context.Background(), "m1", "models/m1.gguf", "llama-cpp", &pb.ModelOptions{}, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result.Node.ID).To(Equal("idle-vram"))
		})

		It("falls back to idle when VRAM search fails", func() {
			reg.findVRAMErr = errors.New("no vram")
			idleNode := &BackendNode{ID: "idle-1", Name: "idle-node", Address: "10.0.0.20:50051"}
			reg.findIdleNode = idleNode

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			result, err := router.Route(context.Background(), "m2", "models/m2.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result.Node.ID).To(Equal("idle-1"))
		})

		It("falls back to least-loaded when both VRAM and idle fail", func() {
			reg.findVRAMErr = errors.New("no vram")
			reg.findIdleErr = errors.New("no idle")
			llNode := &BackendNode{ID: "ll-1", Name: "least-loaded", Address: "10.0.0.30:50051"}
			reg.findLeastLoadedNode = llNode

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			result, err := router.Route(context.Background(), "m3", "models/m3.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result.Node.ID).To(Equal("ll-1"))
		})

		It("returns error when no nodes are available and no DB for eviction", func() {
			reg.findVRAMErr = errors.New("no vram")
			reg.findIdleErr = errors.New("no idle")
			reg.findLeastLoadedErr = errors.New("no nodes")

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
				// DB is nil - evictLRUAndFreeNode will fail because r.db is nil
			})

			_, err := router.Route(context.Background(), "m4", "models/m4.gguf", "llama-cpp", nil, false)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no available nodes"))
		})
	})

	Describe("UnloadModel (mock-based)", func() {
		It("calls StopBackend and removes the model from the registry", func() {
			reg := &fakeModelRouter{}
			unloader := &fakeUnloader{}

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader: unloader,
			})

			err := router.UnloadModel(context.Background(), "node-1", "model-a")
			Expect(err).ToNot(HaveOccurred())

			Expect(unloader.stopCalls).To(ContainElement("node-1:model-a"))
			Expect(reg.removeCalls).To(ContainElement("node-1:model-a"))
		})

		It("returns error when no unloader is configured", func() {
			reg := &fakeModelRouter{}
			router := NewSmartRouter(reg, SmartRouterOptions{})

			err := router.UnloadModel(context.Background(), "node-1", "model-a")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no remote unloader"))
		})
	})

	Describe("EvictLRU (mock-based)", func() {
		It("finds LRU model and unloads it", func() {
			reg := &fakeModelRouter{
				findLRUModel: &NodeModel{NodeID: "n1", ModelName: "old-model"},
			}
			unloader := &fakeUnloader{}

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader: unloader,
			})

			evicted, err := router.EvictLRU(context.Background(), "n1")
			Expect(err).ToNot(HaveOccurred())
			Expect(evicted).To(Equal("old-model"))
			Expect(unloader.stopCalls).To(ContainElement("n1:old-model"))
			Expect(reg.removeCalls).To(ContainElement("n1:old-model"))
		})

		It("returns error when no LRU model is found", func() {
			reg := &fakeModelRouter{
				findLRUErr: errors.New("no models loaded"),
			}

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader: &fakeUnloader{},
			})

			_, err := router.EvictLRU(context.Background(), "n1")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("finding LRU model"))
		})
	})

Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
|
|
var (
|
|
reg *fakeModelRouter
|
|
backend *stubBackend
|
|
factory *stubClientFactory
|
|
unloader *fakeUnloader
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
reg = &fakeModelRouter{
|
|
findAndLockErr: errors.New("not found"),
|
|
}
|
|
backend = &stubBackend{
|
|
loadResult: &pb.Result{Success: true},
|
|
}
|
|
factory = &stubClientFactory{client: backend}
|
|
unloader = &fakeUnloader{
|
|
installReply: &messaging.BackendInstallReply{
|
|
Success: true,
|
|
Address: "10.0.0.1:9001",
|
|
},
|
|
}
|
|
})
|
|
|
|
		It("uses *FromSet methods when model has a node selector", func() {
			gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
			reg.getModelScheduling = &ModelSchedulingConfig{
				ModelName:    "selector-model",
				NodeSelector: `{"gpu.vendor":"nvidia"}`,
			}
			reg.findBySelectorNodes = []BackendNode{*gpuNode}
			reg.findIdleFromSetNode = gpuNode

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			Expect(result.Node.ID).To(Equal("gpu-1"))
		})

		It("returns error when no nodes match selector", func() {
			reg.getModelScheduling = &ModelSchedulingConfig{
				ModelName:    "no-match-model",
				NodeSelector: `{"gpu.vendor":"tpu"}`,
			}
			reg.findBySelectorNodes = nil
			reg.findBySelectorErr = nil

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
		})

		It("uses regular methods when model has no scheduling config", func() {
			reg.getModelScheduling = nil
			idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
			reg.findIdleNode = idleNode

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			Expect(result.Node.ID).To(Equal("regular-1"))
		})
	})

	Describe("Route with selector validation on cached model (mock-based)", func() {
		It("falls through when cached node no longer matches selector", func() {
			cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
			newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}

			backend := &stubBackend{
				healthResult: true,
				loadResult:   &pb.Result{Success: true},
			}
			factory := &stubClientFactory{client: backend}
			unloader := &fakeUnloader{
				installReply: &messaging.BackendInstallReply{
					Success: true,
					Address: "10.0.0.71:9001",
				},
			}

			reg := &fakeModelRouter{
				// Step 1: cached model found on old node
				findAndLockNode: cachedNode,
				findAndLockNM:   &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
				// Scheduling config with selector that old node does NOT match
				getModelScheduling: &ModelSchedulingConfig{
					ModelName:    "sel-model",
					NodeSelector: `{"gpu.vendor":"nvidia"}`,
				},
				// Old node has no labels matching the selector
				getNodeLabels: []NodeLabel{
					{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
				},
				// For scheduling fallthrough: selector matches new node
				findBySelectorNodes: []BackendNode{*newNode},
				findIdleFromSetNode: newNode,
			}

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})

			result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			// Should have fallen through to the new node
			Expect(result.Node.ID).To(Equal("n-new"))
			// Old node should have had its in-flight decremented
			Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
		})
	})

	Describe("ScheduleAndLoadModel (mock-based)", func() {
		It("returns an error and does not fire a NATS install when no load info is stored", func() {
			// Reproduces the reconciler scale-up bug: when GetModelLoadInfo
			// returns ErrRecordNotFound (no replica has ever been loaded),
			// the previous fallback called scheduleNewModel with an empty
			// backend type, which the worker rejected on every reconciler
			// tick. The fix bails out cleanly with an explanatory error and
			// never sends backend.install.
			unloader := &fakeUnloader{}
			reg := &fakeModelRouter{}
			router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader})

			node, err := router.ScheduleAndLoadModel(context.Background(), "never-loaded", nil)

			Expect(err).To(HaveOccurred())
			Expect(node).To(BeNil())
			Expect(err.Error()).To(ContainSubstring("never-loaded"))
			Expect(unloader.installCalls).To(BeEmpty(),
				"reconciler must not fire backend.install when there is no load info to replicate")
		})
	})

	// -----------------------------------------------------------------------
	// Integration tests using real PostgreSQL (existing)
	// -----------------------------------------------------------------------
	Describe("evictLRUAndFreeNode (integration)", func() {
		var (
			db       *gorm.DB
			registry *NodeRegistry
		)

		BeforeEach(func() {
			if runtime.GOOS == "darwin" {
				Skip("testcontainers requires Docker, not available on macOS CI")
			}
			db = testutil.SetupTestDB()
			var err error
			registry, err = NewNodeRegistry(db)
			Expect(err).ToNot(HaveOccurred())
		})

		It("returns ErrEvictionBusy in under 5 seconds when all models are busy", func() {
			node := &BackendNode{
				Name:     "busy-evict",
				NodeType: NodeTypeBackend,
				Address:  "10.0.0.100:50051",
			}
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())

			// Load a model and give it in-flight requests so it cannot be evicted
			Expect(registry.SetNodeModel(context.Background(), node.ID, "busy-model", 0, "loaded", "", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "busy-model", 0)).To(Succeed())

			router := NewSmartRouter(registry, SmartRouterOptions{DB: db})

			start := time.Now()
			_, err := router.evictLRUAndFreeNode(context.Background())
			elapsed := time.Since(start)

			Expect(err).To(MatchError(ErrEvictionBusy))
			// 5 retries * 500ms = 2.5s nominal; allow generous upper bound
			Expect(elapsed).To(BeNumerically("<", 5*time.Second))
		})

		It("respects context cancellation", func() {
			node := &BackendNode{
				Name:     "cancel-evict",
				NodeType: NodeTypeBackend,
				Address:  "10.0.0.101:50051",
			}
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), node.ID, "cancel-model", 0, "loaded", "", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "cancel-model", 0)).To(Succeed())

			router := NewSmartRouter(registry, SmartRouterOptions{DB: db})

			ctx, cancel := context.WithCancel(context.Background())
			cancel() // cancel immediately

			start := time.Now()
			_, err := router.evictLRUAndFreeNode(ctx)
			elapsed := time.Since(start)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("context cancelled"))
			// Should return very quickly since context is already done
			Expect(elapsed).To(BeNumerically("<", 2*time.Second))
		})
	})

	Describe("stageModelFiles (integration)", func() {
		var (
			db       *gorm.DB
			registry *NodeRegistry
		)

		BeforeEach(func() {
			if runtime.GOOS == "darwin" {
				Skip("testcontainers requires Docker, not available on macOS CI")
			}
			db = testutil.SetupTestDB()
			var err error
			registry, err = NewNodeRegistry(db)
			Expect(err).ToNot(HaveOccurred())
		})

		It("does not mutate the original ModelOptions", func() {
			stager := &fakeFileStager{}
			router := NewSmartRouter(registry, SmartRouterOptions{
				FileStager: stager,
				DB:         db,
			})

			node := &BackendNode{
				ID:      "stage-node-id",
				Name:    "stage-node",
				Address: "10.0.0.200:50051",
			}

			original := &pb.ModelOptions{
				Model:     "test-backend/models/test.gguf",
				ModelFile: "/models/test-backend/models/test.gguf",
				MMProj:    "",
			}

			// Capture original values before staging
			origModel := original.Model
			origModelFile := original.ModelFile
			origMMProj := original.MMProj

			// stageModelFiles creates temp files for os.Stat checks.
			// Since none of our test paths exist on disk, stageModelFiles will
			// skip them (clearing non-existent optional fields). The key property
			// is that the original proto pointer is not modified.
			_, _ = router.stageModelFiles(context.Background(), node, original, "test-model")

			// Verify the original proto was not mutated
			Expect(original.Model).To(Equal(origModel))
			Expect(original.ModelFile).To(Equal(origModelFile))
			Expect(original.MMProj).To(Equal(origMMProj))
		})
	})

	// -----------------------------------------------------------------------
	// narrowByGroupAntiAffinity
	// -----------------------------------------------------------------------
	Describe("narrowByGroupAntiAffinity", func() {
		var (
			reg      *fakeModelRouter
			resolver *fakeConflictResolver
			router   *SmartRouter
			ctx      context.Context
		)

		BeforeEach(func() {
			reg = &fakeModelRouter{}
			resolver = &fakeConflictResolver{conflicts: map[string][]string{}}
			router = NewSmartRouter(reg, SmartRouterOptions{
				ConflictResolver: resolver,
			})
			ctx = context.Background()
		})

		It("returns the input set unchanged when the model has no conflicts", func() {
			candidates := []string{"n1", "n2", "n3"}
			out, err := router.narrowByGroupAntiAffinity(ctx, "lonely", candidates)
			Expect(err).ToNot(HaveOccurred())
			Expect(out).To(Equal(candidates))
		})

		It("removes nodes that already host a conflicting model", func() {
			resolver.conflicts["b"] = []string{"a"}
			reg.findNodesWithModelByName = map[string][]BackendNode{
				"a": {{ID: "n1"}},
			}
			out, err := router.narrowByGroupAntiAffinity(ctx, "b", []string{"n1", "n2"})
			Expect(err).ToNot(HaveOccurred())
			Expect(out).To(ConsistOf("n2"))
		})

		It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() {
			resolver.conflicts["b"] = []string{"a"}
			reg.findNodesWithModelByName = map[string][]BackendNode{
				"a": {{ID: "n1"}, {ID: "n2"}},
			}
			candidates := []string{"n1", "n2"}
			out, err := router.narrowByGroupAntiAffinity(ctx, "b", candidates)
			Expect(err).ToNot(HaveOccurred())
			Expect(out).To(Equal(candidates))
		})

		It("removes nodes hosting any of multiple conflicting models", func() {
			resolver.conflicts["c"] = []string{"a", "b"}
			reg.findNodesWithModelByName = map[string][]BackendNode{
				"a": {{ID: "n1"}},
				"b": {{ID: "n2"}},
			}
			out, err := router.narrowByGroupAntiAffinity(ctx, "c", []string{"n1", "n2", "n3"})
			Expect(err).ToNot(HaveOccurred())
			Expect(out).To(ConsistOf("n3"))
		})

		It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() {
			resolver.conflicts["b"] = []string{"a"}
			reg.findNodesWithModelByName = map[string][]BackendNode{
				"a": {{ID: "n1"}, {ID: "n2"}},
			}
			out, err := router.narrowByGroupAntiAffinity(ctx, "b", nil)
			Expect(err).ToNot(HaveOccurred())
			// nil in → nil out: caller's "any healthy node" semantics preserved.
			// Hard-narrowing nil would silently exclude every other node.
			Expect(out).To(BeNil())
		})

		It("is a no-op when no resolver is configured", func() {
			plain := NewSmartRouter(reg, SmartRouterOptions{})
			candidates := []string{"n1", "n2"}
			out, err := plain.narrowByGroupAntiAffinity(ctx, "b", candidates)
			Expect(err).ToNot(HaveOccurred())
			Expect(out).To(Equal(candidates))
		})
	})

	Describe("installBackendOnNode singleflight", func() {
		It("coalesces concurrent identical installs into one NATS call", func() {
			node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}

			// Slow install reply so concurrent calls overlap deterministically.
			started := make(chan struct{}, 5)
			release := make(chan struct{})
			unloader := &fakeUnloader{
				installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
			}
			unloader.installHook = func() {
				started <- struct{}{}
				<-release
			}

			router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: &stubClientFactory{client: &stubBackend{}},
			})

			// Fire 5 concurrent identical installBackendOnNode calls.
			done := make(chan error, 5)
			for i := 0; i < 5; i++ {
				go func() {
					_, err := router.installBackendOnNode(context.Background(), node, "llama-cpp", "my-model", 0)
					done <- err
				}()
			}

			// Only ONE call should have entered the unloader hook (the
			// singleflight leader). The other 4 are coalesced and waiting on
			// the leader's result.
			Eventually(started).Should(Receive())
			Consistently(started, 100*time.Millisecond).ShouldNot(Receive())

			// Release the leader; the other 4 callers receive the same result.
			close(release)
			for i := 0; i < 5; i++ {
				Expect(<-done).ToNot(HaveOccurred())
			}
			Expect(unloader.installCalls).To(HaveLen(1),
				"singleflight should coalesce 5 concurrent identical loads into 1 NATS call")
		})

		It("does NOT coalesce installs for different (modelID, replica) keys", func() {
			node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
			unloader := &fakeUnloader{
				installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:50100"},
			}
			router := NewSmartRouter(&fakeModelRouter{}, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: &stubClientFactory{client: &stubBackend{}},
			})

			_, err1 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 0)
			_, err2 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-B", 0)
			_, err3 := router.installBackendOnNode(context.Background(), node, "llama-cpp", "model-A", 1)
			Expect(err1).ToNot(HaveOccurred())
			Expect(err2).ToNot(HaveOccurred())
			Expect(err3).ToNot(HaveOccurred())
			Expect(unloader.installCalls).To(HaveLen(3))
		})
	})
})

// ---------------------------------------------------------------------------
// Fake prefixcache.Provider for SmartRouter prefix-cache routing tests
// ---------------------------------------------------------------------------

type observeRecord struct {
	model string
	chain []uint64
	key   prefixcache.ReplicaKey
}

type invalidateRecord struct {
	model string
	key   prefixcache.ReplicaKey
}

// fakePrefixProvider records all interactions and returns a configurable
// decision.
type fakePrefixProvider struct {
	decideCalls     int
	observed        []observeRecord
	invalidated     []invalidateRecord
	invalidatedNode []string
	decision        prefixcache.PrefixDecision
}

func (f *fakePrefixProvider) Decide(_ string, _ []uint64, _ []prefixcache.ReplicaKey, _ time.Time) prefixcache.PrefixDecision {
	f.decideCalls++
	return f.decision
}

func (f *fakePrefixProvider) Observe(model string, chain []uint64, key prefixcache.ReplicaKey, _ time.Time) bool {
	f.observed = append(f.observed, observeRecord{model: model, chain: append([]uint64(nil), chain...), key: key})
	return true
}

func (f *fakePrefixProvider) Invalidate(model string, key prefixcache.ReplicaKey) {
	f.invalidated = append(f.invalidated, invalidateRecord{model: model, key: key})
}

func (f *fakePrefixProvider) InvalidateNode(model, nodeID string) {
	f.invalidatedNode = append(f.invalidatedNode, model+":"+nodeID)
}

func (f *fakePrefixProvider) Evict(_ time.Time) {}

var _ = Describe("SmartRouter prefix-cache routing", func() {
	var (
		backend  *stubBackend
		factory  *stubClientFactory
		unloader *fakeUnloader
	)

	BeforeEach(func() {
		backend = &stubBackend{healthResult: true}
		factory = &stubClientFactory{client: backend}
		unloader = &fakeUnloader{
			installReply: &messaging.BackendInstallReply{Success: true, Address: "10.0.0.1:9001"},
		}
	})

	// loadedReg builds a fake registry with one loaded healthy replica for
	// "m" on node "X", plus matching replica stats so buildPreference can run.
	loadedReg := func() *fakeModelRouter {
		node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
		nm := &NodeModel{NodeID: "X", ModelName: "m", Address: "10.0.0.1:9001"}
		return &fakeModelRouter{
			findAndLockNode: node,
			findAndLockNM:   nm,
			getModelScheduling: &ModelSchedulingConfig{
				RoutePolicy: "prefix_cache",
			},
			loadedReplicaStatsByName: map[string][]ReplicaCandidate{
				"m": {{NodeID: "X", InFlight: 0}},
			},
		}
	}

	Context("nil provider (round-robin floor)", func() {
		It("passes a nil preference and never decides or observes", func() {
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{Unloader: unloader, ClientFactory: factory})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			Expect(reg.findAndLockPrefs).ToNot(BeEmpty())
			for _, p := range reg.findAndLockPrefs {
				Expect(p).To(BeNil())
			}
		})
	})

	Context("with a provider", func() {
		It("passes the decided node as the preference and observes the pick", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			Expect(prov.decideCalls).To(BeNumerically(">=", 1))
			Expect(reg.findAndLockPrefs[0]).ToNot(BeNil())
			Expect(reg.findAndLockPrefs[0].PreferredNodeID).To(Equal("X"))
			Expect(reg.findAndLockPrefs[0].PreferredReplica).To(Equal(0))
			Expect(prov.observed).To(HaveLen(1))
			Expect(prov.observed[0].key).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))
			Expect(prov.observed[0].chain).To(Equal([]uint64{1, 2, 3}))
		})

		It("routes a recurring prefix back to the previously observed node", func() {
			// Real Index as the provider: first request observes X, second
			// request with the same chain must yield PreferredNodeID == X.
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{7, 8, 9})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			// First request landed on X (cold placement on the only candidate)
			// and observed the prefix there.
			dFirst := idx.Decide("m", []uint64{7, 8, 9}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now())
			Expect(dFirst.HasHot).To(BeTrue())
			Expect(dFirst.Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))

			// Second request, same chain: X is now the warm-cache hot match, so
			// the preference must point at it.
			_, err = router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			last := reg.findAndLockPrefs[len(reg.findAndLockPrefs)-1]
			Expect(last).ToNot(BeNil())
			Expect(last.PreferredNodeID).To(Equal("X"))
			Expect(last.PreferredReplica).To(Equal(0))
		})

		It("prefers the exact hot replica when two replicas share a node", func() {
			// Two replicas of "m" live on the SAME node X: replica 0 and replica
			// 1. A hot prefix observed on (X,0) must produce a preference that
			// locks replica 0 specifically, NOT the sibling replica 1 on the same
			// node. This is the replica-granular regression this change fixes.
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			node := &BackendNode{ID: "X", Name: "node-x", Address: "10.0.0.1:50051"}
			nm := &NodeModel{NodeID: "X", ModelName: "m", ReplicaIndex: 0, Address: "10.0.0.1:9001"}
			reg := &fakeModelRouter{
				findAndLockNode: node,
				findAndLockNM:   nm,
				getModelScheduling: &ModelSchedulingConfig{
					RoutePolicy: "prefix_cache",
				},
				loadedReplicaStatsByName: map[string][]ReplicaCandidate{
					"m": {
						{NodeID: "X", ReplicaIndex: 0, InFlight: 0},
						{NodeID: "X", ReplicaIndex: 1, InFlight: 0},
					},
				},
			}
			// Seed the index so (X,0) is the warm replica for this chain.
			idx.Observe("m", []uint64{1, 2, 3}, prefixcache.ReplicaKey{NodeID: "X", Replica: 0}, time.Now())

			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			pref := reg.findAndLockPrefs[0]
			Expect(pref).ToNot(BeNil())
			Expect(pref.PreferredNodeID).To(Equal("X"))
			Expect(pref.PreferredReplica).To(Equal(0),
				"the hot prefix lives on replica 0; the same-node sibling replica 1 must NOT be chosen")
		})

		It("does not decide or observe when no prefix chain is present", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			_, err := router.Route(context.Background(), "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(prov.decideCalls).To(Equal(0))
			Expect(prov.observed).To(BeEmpty())
			Expect(reg.findAndLockPrefs[0]).To(BeNil())
		})

		It("does not observe for round-robin models even with a chain", func() {
			reg := loadedReg()
			reg.getModelScheduling = &ModelSchedulingConfig{RoutePolicy: "round_robin"}
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{Hot: prefixcache.ReplicaKey{NodeID: "X"}, HasHot: true, MatchRatio: 1.0}}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(prov.decideCalls).To(Equal(0))
			Expect(prov.observed).To(BeEmpty())
			Expect(reg.findAndLockPrefs[0]).To(BeNil())
		})
	})

	Context("forced-disturb pressure", func() {
		// disturbReg builds a registry with two candidate replicas for "m":
		// the hot node X is saturated (high in_flight) and Y is free. Select
		// will therefore reject the hot node and pick Y, which is the
		// forced-disturb signal. findAndLockNode returns Y so Route succeeds.
		disturbReg := func() *fakeModelRouter {
			nodeY := &BackendNode{ID: "Y", Name: "node-y", Address: "10.0.0.2:50051"}
			nm := &NodeModel{NodeID: "Y", ModelName: "m", Address: "10.0.0.2:9001"}
			return &fakeModelRouter{
				findAndLockNode: nodeY,
				findAndLockNM:   nm,
				getModelScheduling: &ModelSchedulingConfig{
					RoutePolicy: "prefix_cache",
				},
				loadedReplicaStatsByName: map[string][]ReplicaCandidate{
					"m": {{NodeID: "X", InFlight: 50}, {NodeID: "Y", InFlight: 0}},
				},
			}
		}

		It("records pressure when a strong hot match was forced off the warm node", func() {
			reg := disturbReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				Hot:        prefixcache.ReplicaKey{NodeID: "X"},
				HasHot:     true,
				MatchRatio: 1.0,
				ColdOrder:  []prefixcache.ReplicaKey{{NodeID: "Y"}, {NodeID: "X"}},
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			Expect(pressure.Count("m", time.Now())).To(BeNumerically(">", 0),
				"hot match existed but the load guard forced us off X: must record pressure")
		})

		It("does not record pressure when the hot node is itself eligible", func() {
			reg := loadedReg() // single node X, in_flight 0 → X stays eligible
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				Hot:        prefixcache.ReplicaKey{NodeID: "X"},
				HasHot:     true,
				MatchRatio: 1.0,
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			Expect(pressure.Count("m", time.Now())).To(Equal(0),
				"chosen == hot node, no disturb")
		})

		It("does not record pressure for an all-unique workload with no hot match", func() {
			reg := loadedReg()
			prov := &fakePrefixProvider{decision: prefixcache.PrefixDecision{
				HasHot:     false, // no prefix match at all
				MatchRatio: 0,
				ColdOrder:  []prefixcache.ReplicaKey{{NodeID: "X"}},
			}}
			pressure := prefixcache.NewPressure(time.Minute)
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: prov,
				PrefixConfig:   prefixcache.DefaultConfig(),
				Pressure:       pressure,
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{1, 2, 3})
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())

			Expect(pressure.Count("m", time.Now())).To(Equal(0),
				"no hot match means no cache to disturb: must not false-positive")
		})
	})

	Context("removal chokepoint on unload", func() {
		It("removes the replica via the registry so the removal hook invalidates the prefix entry", func() {
			idx := prefixcache.NewIndex(prefixcache.DefaultConfig())
			reg := loadedReg()
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:       unloader,
				ClientFactory:  factory,
				PrefixProvider: idx,
				PrefixConfig:   prefixcache.DefaultConfig(),
			})

			ctx := distributedhdr.WithPrefixChain(context.Background(), []uint64{5, 6})
			// Warm the cache: X now holds the prefix.
			_, err := router.Route(ctx, "m", "models/m.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(idx.Decide("m", []uint64{5, 6}, []prefixcache.ReplicaKey{{NodeID: "X", Replica: 0}}, time.Now()).Hot).To(Equal(prefixcache.ReplicaKey{NodeID: "X", Replica: 0}))

			// UnloadModel must route the eviction through the registry removal
			// chokepoint (RemoveAllNodeModelReplicas). The registry's
			// SetReplicaRemovedHook is what invalidates the prefix index in
			// production; the router no longer invalidates directly. Here the
			// fake registry records the removal but fires no hook, so we assert
			// the chokepoint is exercised rather than the downstream
			// invalidation (covered by the registry hook integration tests).
			Expect(router.UnloadModel(context.Background(), "X", "m")).To(Succeed())
			Expect(reg.removeCalls).To(ContainElement("X:m"),
				"UnloadModel must remove the replica via the registry removal chokepoint")
		})
	})
})