LocalAI/core/services/nodes/router_test.go
Ettore Di Giacinto 6b63b47f61 feat(distributed): support multiple replicas of one model on the same node (#9583)
* feat(distributed): support multiple replicas of one model on the same node

The distributed scheduler implicitly assumed `(node_id, model_name)` was
unique, but the schema didn't enforce it and the worker keyed all gRPC
processes by model name alone. With `MinReplicas=2` against a single
worker, the reconciler "scaled up" every 30s but the registry never
advanced past 1 row — the worker re-loaded the model in-place every tick
until VRAM fragmented and the gRPC process died.

This change introduces multi-replica-per-node as a first-class concept,
with capacity-aware scheduling, a circuit breaker, and VRAM
soft-reservation. Operators can declare per-node capacity via the worker
flag `--max-replicas-per-model` (mirrored as auto-label
`node.replica-slots=N`) or override per-node from the UI.

* Schema: BackendNode gains MaxReplicasPerModel (default 1) and
  ReservedVRAM. NodeModel gains ReplicaIndex (composite with node_id +
  model_name). ModelSchedulingConfig gains UnsatisfiableUntil/Ticks for
  the reconciler circuit breaker.

* Registry: replica_index threaded through SetNodeModel, RemoveNodeModel,
  IncrementInFlight, DecrementInFlight, TouchNodeModel, GetNodeModel,
  SetNodeModelLoadInfo and the InFlightTrackingClient. New helpers:
  CountReplicasOnNode, NextFreeReplicaIndex (with ErrNoFreeSlot),
  RemoveAllNodeModelReplicas, FindNodesWithFreeSlot,
  ClusterCapacityForModel, ReserveVRAM/ReleaseVRAM (atomic UPDATE with
  ErrInsufficientVRAM), and the unsatisfiable-flag CRUD.

* Worker: processKey now `<modelID>#<replicaIndex>` so concurrent loads
  of the same model land on distinct ports. Adds CLI flag
  --max-replicas-per-model (env LOCALAI_MAX_REPLICAS_PER_MODEL, default 1)
  and emits the auto-label.

* Router: scheduleNewModel filters candidates by free slot, allocates the
  replica index, and soft-reserves VRAM before installing the backend.
  evictLRUAndFreeNode now deletes the targeted row by ID instead of all
  replicas of the model on the node — fixes a latent bug where evicting
  one replica orphaned its siblings.

* Reconciler: caps scale-up at ClusterCapacityForModel so a misconfig
  (MinReplicas > capacity) doesn't loop forever. After 3 consecutive
  ticks of capacity==0 it sets UnsatisfiableUntil for a 5m cooldown and
  emits a warning. ClearAllUnsatisfiable fires from Register,
  ApproveNode, SetNodeLabel(s), RemoveNodeLabel and
  UpdateMaxReplicasPerModel so a new node joining or label changes wake
  the reconciler immediately. scaleDownIdle removes highest-replica-index
  first to keep slots compact.

* Heartbeat resets reserved_vram to 0 — worker is the source of truth
  for actual free VRAM; the reservation is only for the in-tick race
  window between two scheduling decisions.

* Probe path (reconciler.probeLoadedModels and health.doCheckAll) now
  pass the row's replica_index to RemoveNodeModel so an unreachable
  replica doesn't orphan healthy siblings.

* Admin override: PUT /api/nodes/:id/max-replicas-per-model sets a
  sticky override (preserved across worker re-registration). DELETE
  clears the override so the worker's flag applies again on next
  register. Required because Kong defaults the worker flag to 1, so
  every worker restart would have silently reverted the UI value.

* React UI: always-visible slot badge on the node row (muted at default
  1, accented when >1); inline editor in the expanded drawer with
  pencil-to-edit, Save/Cancel, Esc/Enter, "(override)" indicator when
  the value is admin-set, and a "Reset" button to hand control back to
  the worker. Soft confirm when shrinking the cap below the count of
  loaded replicas. Scheduling rules table gets an "Unsatisfiable until
  HH:MM" status badge surfacing the cooldown.

* node.replica-slots filtered out of the labels strip on the row to
  avoid duplicating the slot badge.
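
The worker key and slot allocation described in the Registry and Worker
items above can be pictured with a minimal in-memory sketch. The names
`processKey`, `nextFreeReplicaIndex` and `errNoFreeSlot` are hypothetical
stand-ins for the real (database-backed) helpers, and the lowest-free-index
policy is an assumption, not something the commit states:

```go
package main

import (
	"errors"
	"fmt"
)

// errNoFreeSlot is a hypothetical stand-in for the registry's ErrNoFreeSlot.
var errNoFreeSlot = errors.New("no free replica slot")

// processKey builds the "<modelID>#<replicaIndex>" key, so two replicas of
// the same model are tracked as distinct processes.
func processKey(modelID string, replicaIndex int) string {
	return fmt.Sprintf("%s#%d", modelID, replicaIndex)
}

// nextFreeReplicaIndex returns the lowest index in [0, maxReplicas) that is
// not in used (assumed policy), or errNoFreeSlot when every slot is taken.
func nextFreeReplicaIndex(used []int, maxReplicas int) (int, error) {
	taken := make(map[int]bool, len(used))
	for _, i := range used {
		taken[i] = true
	}
	for i := 0; i < maxReplicas; i++ {
		if !taken[i] {
			return i, nil
		}
	}
	return 0, errNoFreeSlot
}

func main() {
	// Slots 0 and 2 are occupied on a node with three slots: slot 1 is next.
	idx, err := nextFreeReplicaIndex([]int{0, 2}, 3)
	fmt.Println(processKey("my-model", idx), err)

	// Both slots occupied on a node with two slots: allocation is refused.
	_, err = nextFreeReplicaIndex([]int{0, 1}, 2)
	fmt.Println(err)
}
```

Filling the lowest gap first pairs naturally with scaleDownIdle removing
the highest index first: between them the occupied indices tend to stay
contiguous.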

23 new Ginkgo specs (registry, reconciler, inflight, health) cover:
multi-replica row independence, RemoveNodeModel of one replica
preserving siblings, NextFreeReplicaIndex slot allocation including
ErrNoFreeSlot, capacity-gated scale-up with circuit breaker tripping
and recovery on Register, scaleDownIdle ordering, ClusterCapacity
math, ReserveVRAM admission gating, Heartbeat reset, override survival
across worker re-registration, and ResetMaxReplicasPerModel handing
control back. Plus 8 stdlib tests for the worker processKey / CLI /
auto-label.
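
As an illustration of the ReserveVRAM admission gating and Heartbeat reset
named in the list above, here is a small in-memory sketch. The real helper
is described as an atomic UPDATE against the registry; this version swaps
in a mutex-guarded struct, and the `node` type, its fields and method names
are hypothetical:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// errInsufficientVRAM is a hypothetical stand-in for ErrInsufficientVRAM.
var errInsufficientVRAM = errors.New("insufficient VRAM")

// node is an illustrative stand-in for a BackendNode row.
type node struct {
	mu            sync.Mutex
	availableVRAM uint64 // last free-VRAM value reported by the worker
	reservedVRAM  uint64 // soft reservations made since the last heartbeat
}

// reserveVRAM admits a load only if it fits in what is left after earlier
// reservations. Check and update happen under one lock, mirroring the
// single atomic UPDATE described in the commit.
func (n *node) reserveVRAM(bytes uint64) error {
	n.mu.Lock()
	defer n.mu.Unlock()
	if bytes > n.availableVRAM-n.reservedVRAM {
		return errInsufficientVRAM
	}
	n.reservedVRAM += bytes
	return nil
}

// heartbeat records the worker's own free-VRAM figure and drops all soft
// reservations, since the worker is the source of truth.
func (n *node) heartbeat(availableVRAM uint64) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.availableVRAM = availableVRAM
	n.reservedVRAM = 0
}

func main() {
	const gib = 1 << 30
	n := &node{availableVRAM: 24 * gib}

	fmt.Println(n.reserveVRAM(16 * gib)) // first decision fits
	fmt.Println(n.reserveVRAM(16 * gib)) // second decision in the same window is refused

	n.heartbeat(8 * gib) // worker reports 8 GiB actually free
	fmt.Println(n.reserveVRAM(8 * gib))
}
```

The reservation only has to survive the window between two scheduling
decisions, which is why clearing it on every heartbeat is safe.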

Closes the flap reproduced on Qwen3.6-35B against the nvidia-thor
worker (single 128 GiB node, MinReplicas=2): the reconciler now caps
the scale-up at the cluster's actual capacity instead of looping.
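
A rough sketch of the capped scale-up and circuit breaker, under stated
assumptions: `schedState`, `replicasToAdd`, `clearUnsatisfiable` and
`tickAt` are hypothetical names, and "capacity" is taken to mean the number
of free slots cluster-wide. Only the 30s tick, the 3-tick threshold and the
5m cooldown come from the description above:

```go
package main

import (
	"fmt"
	"time"
)

const (
	tickInterval       = 30 * time.Second // reconciler period
	unsatisfiableTicks = 3                // consecutive zero-capacity ticks before tripping
	cooldown           = 5 * time.Minute  // how long the rule stays parked
)

// schedState is a hypothetical stand-in for the UnsatisfiableUntil/Ticks
// fields on ModelSchedulingConfig.
type schedState struct {
	zeroTicks          int
	unsatisfiableUntil time.Time
}

// tickAt returns the time of the n-th reconciler tick (illustrative clock).
func tickAt(n int) time.Time {
	return time.Unix(0, 0).UTC().Add(time.Duration(n) * tickInterval)
}

// replicasToAdd returns how many replicas to schedule on this tick.
func (s *schedState) replicasToAdd(now time.Time, minReplicas, current, freeCapacity int) int {
	if now.Before(s.unsatisfiableUntil) {
		return 0 // breaker open: skip until the cooldown expires
	}
	missing := minReplicas - current
	if missing <= 0 {
		s.zeroTicks = 0
		return 0
	}
	if freeCapacity == 0 {
		s.zeroTicks++
		if s.zeroTicks >= unsatisfiableTicks {
			s.unsatisfiableUntil = now.Add(cooldown)
			s.zeroTicks = 0
		}
		return 0
	}
	s.zeroTicks = 0
	if missing > freeCapacity {
		return freeCapacity // never ask for more than the cluster can hold
	}
	return missing
}

// clearUnsatisfiable mimics what ClearAllUnsatisfiable does for one rule.
func (s *schedState) clearUnsatisfiable() {
	s.zeroTicks = 0
	s.unsatisfiableUntil = time.Time{}
}

func main() {
	s := &schedState{}
	// One node, one slot, one replica loaded, MinReplicas=2: nothing is free.
	for n := 0; n < 4; n++ {
		add := s.replicasToAdd(tickAt(n), 2, 1, 0)
		parked := tickAt(n).Before(s.unsatisfiableUntil)
		fmt.Printf("tick %d: add=%d parked=%v\n", n, add, parked)
	}
	// A slot becomes available (for example a node registers).
	s.clearUnsatisfiable()
	fmt.Println("after clear: add =", s.replicasToAdd(tickAt(4), 2, 1, 1))
}
```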

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Read] [Edit] [Bash] [Skill:critique] [Skill:audit] [Skill:polish] [Skill:golang-testing]

* refactor(react-ui/nodes): tighten capacity editor copy + adopt ActionMenu for row actions

* Capacity editor hint trimmed from operator-doc-style ("Sourced from
  the worker's `--max-replicas-per-model` flag. Changing it here makes it
  a sticky admin override that survives worker restarts." → "Saved
  values stick across worker restarts.") and the override-state copy
  similarly compressed. The full mechanic is no longer needed in the UI
  — the override pill carries the meaning and the docs cover the rest.

* Node row actions migrated from an inline cluster of icon buttons
  (Drain / Resume / Trash) to the kebab ActionMenu used by /manage for
  per-row model actions, so dense Nodes tables stay clean. Approve
  stays as a prominent primary button — it's a stateful admission gate,
  not a routine action, and elevating it matches how /manage surfaces
  install-time decisions outside the menu.

* The expanded drawer's Labels section now filters node.replica-slots
  out of the editable label list. The label is owned by the Capacity
  editor above; surfacing it again as an editable label invited
  confusion (the Capacity save would clobber any direct edit).

Both backend and agent workers benefit — they share the row rendering
path, so the action menu and label filter apply to both.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:critique] [Skill:audit] [Skill:polish]

* fix(react-ui/nodes): suppress slot badge on agent workers

Agent workers don't load models, so the per-node replica capacity is
inapplicable to them. Showing "1× slots" on agent rows was a tiny
inconsistency from the unified rendering path — gate the badge on
node_type !== 'agent' so it only appears on backend workers.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp]

* refactor(react-ui/nodes): distill expanded drawer + restyle scheduling form

The expanded node drawer used to stack five panels — slot badge,
filled capacity box, Loaded Models h4+empty-state, Installed Backends
h4+empty-state, Labels h4+chips+form — making routine inspections feel
like a control panel. The scheduling rule form wrapped its mode toggle
as two 50%-width filled buttons that competed visually with the actual
primary action.

* Drawer: collapse three rarely-touched config zones (Capacity,
  Backends, Labels) into one `<details>` "Manage" disclosure (closed by
  default) with small uppercase eyebrow labels for each zone instead of
  parallel h4 sub-headings. Loaded Models stays as the at-a-glance
  headline with a single-line empty hint instead of a boxed empty state.
  CapacityEditor renders flat (no filled background) — the Manage
  disclosure provides framing.

* Scheduling form: replace the chunky 50%-width button-tabs with the
  project's existing `.segmented` control (icon + label, sized to
  content). Mode hint becomes a single tied line below. Fields stack
  vertically with helper text under inputs and a hairline divider above
  the right-aligned Save / Cancel.

The empty drawer collapses from ~5 stacked sections (~280px tall) to
two lines (~80px). The scheduling form now reads as a designed dialog
instead of raw building blocks. Both surfaces now match the typographic
density and weight of the rest of the admin pages.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:distill] [Skill:audit] [Skill:polish]

* feat(react-ui/nodes): replace scheduling form's model picker with searchable combobox

The native <select> made operators scroll through every gallery entry to
find a model name. The project already has SearchableModelSelect (used
in Studio/Talk/etc.) which combines free-text search with the gallery
list and accepts typed model names that aren't installed yet — useful
for pre-staging a scheduling rule before the node it'll run on has
finished bootstrapping.

Also drops the now-unused useModels import (the combobox manages the
gallery hook internally).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit]

* refactor(react-ui/nodes): consolidate key/value chip editor + add replica preset chips

The Nodes page was rendering the same key=value chip pattern in two
places with subtly different markup: the Labels editor in the expanded
drawer and (post-distill) the Node Selector input in the scheduling
form. The form's input was also a comma-separated string that operators
were getting wrong.

* Extract <KeyValueChips> as a fully controlled chip-builder. Parent
  owns the map and decides what onAdd/onRemove does — form state for the
  scheduling form, API calls for the live drawer Labels editor. Same
  visuals everywhere; one component to change when polish is needed.

* Replace the comma-separated Node Selector text input with KeyValueChips.
  Operators were copying syntax from docs and missing commas; the chip
  vocabulary makes the key=value structure self-documenting.

* Add <ReplicaInput>: numeric input + quick-pick preset chips for Min/Max
  replicas. Picked over a slider because replica counts are exact specs
  derived from VRAM math (operator decision, not a fuzzy estimate). The
  chips give one-click access to common values (1/2/3/4 for Min,
  0=no-limit/2/4/8 for Max) without the slider's special-value problem
  (MaxReplicas=0 is categorical, not a position on a continuum).

* Drop the now-unused labelInputs state in the Nodes page (the inline
  label editor's per-node draft state lived there and is now owned by
  KeyValueChips).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Skill:distill]

* test: fix CI fallout from multi-replica refactor (e2e/distributed + playwright)

Two breakages caught by CI that didn't surface in the local run:

* tests/e2e/distributed/*.go — multiple files used the pre-PR2 registry
  signatures for SetNodeModel / IncrementInFlight / DecrementInFlight /
  RemoveNodeModel / TouchNodeModel / GetNodeModel / SetNodeModelLoadInfo
  and one stale adapter.InstallBackend call in node_lifecycle_test.go.
  All updated to pass replicaIndex=0 — these tests don't exercise
  multi-replica behavior, they just need to compile against the new
  signatures. The Ginkgo specs in core/services/nodes/ already
  cover the multi-replica logic.

* core/http/react-ui/e2e/nodes-per-node-backend-actions.spec.js — the
  drawer's distill refactor moved Backends inside a "Manage" <details>
  disclosure that's collapsed by default. The test helper expanded the
  node row but never opened Manage, so the per-node backend table was
  never in the DOM. Helper now clicks `.node-manage > summary` after
  expanding the row.

All 100 playwright tests pass locally; tests/e2e/distributed compiles
clean.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Bash]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-04-27 21:20:05 +02:00


package nodes
import (
"context"
"errors"
"fmt"
"runtime"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/testutil"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
"gorm.io/gorm"
)
// ---------------------------------------------------------------------------
// Fake FileStager (pre-existing)
// ---------------------------------------------------------------------------
// fakeFileStager is a minimal FileStager that records calls and returns
// predictable remote paths without touching the filesystem or network.
type fakeFileStager struct {
ensureCalls []ensureCall
}
type ensureCall struct {
nodeID, localPath, key string
}
func (f *fakeFileStager) EnsureRemote(_ context.Context, nodeID, localPath, key string) (string, error) {
f.ensureCalls = append(f.ensureCalls, ensureCall{nodeID, localPath, key})
return "/remote/" + key, nil
}
func (f *fakeFileStager) FetchRemote(_ context.Context, _, _, _ string) error { return nil }
func (f *fakeFileStager) FetchRemoteByKey(_ context.Context, _, _, _ string) error { return nil }
func (f *fakeFileStager) AllocRemoteTemp(_ context.Context, _ string) (string, error) {
return "/remote/tmp", nil
}
func (f *fakeFileStager) StageRemoteToStore(_ context.Context, _, _, _ string) error { return nil }
func (f *fakeFileStager) ListRemoteDir(_ context.Context, _, _ string) ([]string, error) {
return nil, nil
}
// ---------------------------------------------------------------------------
// Fake ModelRouter
// ---------------------------------------------------------------------------
// fakeModelRouter implements ModelRouter with configurable return values.
type fakeModelRouter struct {
// FindAndLockNodeWithModel returns
findAndLockNode *BackendNode
findAndLockNM *NodeModel
findAndLockErr error
// FindNodeWithVRAM returns
findVRAMNode *BackendNode
findVRAMErr error
// FindIdleNode returns
findIdleNode *BackendNode
findIdleErr error
// FindLeastLoadedNode returns
findLeastLoadedNode *BackendNode
findLeastLoadedErr error
// FindGlobalLRUModelWithZeroInFlight returns
findGlobalLRUModel *NodeModel
findGlobalLRUErr error
// FindLRUModel returns
findLRUModel *NodeModel
findLRUErr error
// Get returns
getNode *BackendNode
getErr error
// GetModelScheduling returns
getModelScheduling *ModelSchedulingConfig
getModelSchedErr error
// FindNodesBySelector returns
findBySelectorNodes []BackendNode
findBySelectorErr error
// *FromSet variants
findVRAMFromSetNode *BackendNode
findVRAMFromSetErr error
findIdleFromSetNode *BackendNode
findIdleFromSetErr error
findLeastLoadedFromSetNode *BackendNode
findLeastLoadedFromSetErr error
// GetNodeLabels returns
getNodeLabels []NodeLabel
getNodeLabelsErr error
// Track calls for assertions
decrementCalls []string // "nodeID:modelName"
incrementCalls []string
removeCalls []string
setCalls []string
touchCalls []string
}
func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName string) (*BackendNode, *NodeModel, error) {
return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
}
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.incrementCalls = append(f.incrementCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
// Same recorded key as RemoveNodeModel so existing tests that assert "the
// model was removed" don't need to know whether the production code used
// the per-replica or all-replicas variant.
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string, _ int) {
f.touchCalls = append(f.touchCalls, nodeID+":"+modelName)
}
func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName string, _ int, state, address string, _ int) error {
f.setCalls = append(f.setCalls, fmt.Sprintf("%s:%s:%s:%s", nodeID, modelName, state, address))
return nil
}
func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
return nil
}
func (f *fakeModelRouter) GetModelLoadInfo(_ context.Context, _ string) (string, []byte, error) {
return "", nil, fmt.Errorf("not found")
}
func (f *fakeModelRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
return 0, nil
}
func (f *fakeModelRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
return 0, nil
}
func (f *fakeModelRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
return f.findVRAMNode, f.findVRAMErr
}
func (f *fakeModelRouter) FindIdleNode(_ context.Context) (*BackendNode, error) {
return f.findIdleNode, f.findIdleErr
}
func (f *fakeModelRouter) FindLeastLoadedNode(_ context.Context) (*BackendNode, error) {
return f.findLeastLoadedNode, f.findLeastLoadedErr
}
func (f *fakeModelRouter) FindGlobalLRUModelWithZeroInFlight(_ context.Context) (*NodeModel, error) {
return f.findGlobalLRUModel, f.findGlobalLRUErr
}
func (f *fakeModelRouter) FindLRUModel(_ context.Context, _ string) (*NodeModel, error) {
return f.findLRUModel, f.findLRUErr
}
func (f *fakeModelRouter) Get(_ context.Context, _ string) (*BackendNode, error) {
return f.getNode, f.getErr
}
func (f *fakeModelRouter) GetModelScheduling(_ context.Context, _ string) (*ModelSchedulingConfig, error) {
return f.getModelScheduling, f.getModelSchedErr
}
func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
return f.findBySelectorNodes, f.findBySelectorErr
}
func (f *fakeModelRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
// Default: same answer as FindNodesBySelector. Tests that need a
// specific filter can override by reusing findBySelectorNodes.
return f.findBySelectorNodes, f.findBySelectorErr
}
func (f *fakeModelRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
return f.findVRAMFromSetNode, f.findVRAMFromSetErr
}
func (f *fakeModelRouter) FindIdleNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
return f.findIdleFromSetNode, f.findIdleFromSetErr
}
func (f *fakeModelRouter) FindLeastLoadedNodeFromSet(_ context.Context, _ []string) (*BackendNode, error) {
return f.findLeastLoadedFromSetNode, f.findLeastLoadedFromSetErr
}
func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) {
return f.getNodeLabels, f.getNodeLabelsErr
}
// ---------------------------------------------------------------------------
// Fake BackendClientFactory + Backend
// ---------------------------------------------------------------------------
// stubBackend implements grpc.Backend with configurable HealthCheck and LoadModel.
type stubBackend struct {
grpc.Backend // embed to satisfy interface; unused methods will panic if called
healthResult bool
healthErr error
loadResult *pb.Result
loadErr error
}
func (f *stubBackend) HealthCheck(_ context.Context) (bool, error) {
return f.healthResult, f.healthErr
}
func (f *stubBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
return f.loadResult, f.loadErr
}
func (f *stubBackend) IsBusy() bool { return false }
// stubClientFactory returns the same stubBackend for every call.
type stubClientFactory struct {
client *stubBackend
}
func (f *stubClientFactory) NewClient(_ string, _ bool) grpc.Backend {
return f.client
}
// ---------------------------------------------------------------------------
// Fake NodeCommandSender (unloader)
// ---------------------------------------------------------------------------
type fakeUnloader struct {
installReply *messaging.BackendInstallReply
installErr error
stopCalls []string // "nodeID:model"
stopErr error
unloadCalls []string
unloadErr error
}
func (f *fakeUnloader) InstallBackend(_, _, _, _, _, _, _ string, _ int) (*messaging.BackendInstallReply, error) {
return f.installReply, f.installErr
}
func (f *fakeUnloader) DeleteBackend(_, _ string) (*messaging.BackendDeleteReply, error) {
return &messaging.BackendDeleteReply{Success: true}, nil
}
func (f *fakeUnloader) ListBackends(_ string) (*messaging.BackendListReply, error) {
return &messaging.BackendListReply{}, nil
}
func (f *fakeUnloader) StopBackend(nodeID, backend string) error {
f.stopCalls = append(f.stopCalls, nodeID+":"+backend)
return f.stopErr
}
func (f *fakeUnloader) UnloadModelOnNode(nodeID, modelName string) error {
f.unloadCalls = append(f.unloadCalls, nodeID+":"+modelName)
return f.unloadErr
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
var _ = Describe("SmartRouter", func() {
// -----------------------------------------------------------------------
// Unit tests using mock interfaces (no DB required)
// -----------------------------------------------------------------------
Describe("Route (mock-based)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
reg = &fakeModelRouter{}
backend = &stubBackend{}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
Context("model already loaded on a healthy node", func() {
It("returns the client and a release function", func() {
node := &BackendNode{ID: "n1", Name: "node-1", Address: "10.0.0.1:50051"}
nm := &NodeModel{NodeID: "n1", ModelName: "my-model", Address: "10.0.0.1:9001"}
reg.findAndLockNode = node
reg.findAndLockNM = nm
backend.healthResult = true
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "my-model", "models/my-model.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("n1"))
// TouchNodeModel should have been called
Expect(reg.touchCalls).To(ContainElement("n1:my-model"))
// The initial in-flight reservation from FindAndLockNodeWithModel is released
// after the first inference call completes via OnFirstComplete callback.
// Release only closes the client.
result.Release()
// No decrement on Release — it happens via OnFirstComplete after first Predict
Expect(reg.decrementCalls).To(BeEmpty())
})
})
Context("model not loaded, falls through to scheduling", func() {
It("schedules on an idle node and records the model", func() {
// FindAndLockNodeWithModel always fails — simulates no cached model
// (equivalent to the health-check-failure fallthrough path).
idleNode := &BackendNode{ID: "n2", Name: "idle-node", Address: "10.0.0.2:50051"}
reg2 := &fakeModelRouter{
findAndLockErr: errors.New("not found"),
findIdleNode: idleNode,
}
backend.loadResult = &pb.Result{Success: true}
router := NewSmartRouter(reg2, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "some-model", "models/some-model.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.ID).To(Equal("n2"))
// SetNodeModel should record the model as loaded on the node
Expect(reg2.setCalls).To(HaveLen(1))
Expect(reg2.setCalls[0]).To(ContainSubstring("n2:some-model:loaded"))
})
})
Context("model not loaded, no DB (advisory lock bypassed)", func() {
It("schedules on an available node via FindIdleNode", func() {
reg.findAndLockErr = errors.New("not found")
idleNode := &BackendNode{ID: "n3", Name: "idle", Address: "10.0.0.3:50051"}
reg.findIdleNode = idleNode
backend.loadResult = &pb.Result{Success: true}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
// DB is nil — no advisory lock
})
result, err := router.Route(context.Background(), "new-model", "models/new.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("n3"))
})
})
})
Describe("scheduleNewModel (mock-based, via Route)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
unloader *fakeUnloader
)
BeforeEach(func() {
reg = &fakeModelRouter{
findAndLockErr: errors.New("not found"),
}
backend = &stubBackend{
loadResult: &pb.Result{Success: true},
}
factory = &stubClientFactory{client: backend}
unloader = &fakeUnloader{
installReply: &messaging.BackendInstallReply{
Success: true,
Address: "10.0.0.1:9001",
},
}
})
It("skips the VRAM search and uses an idle node when the VRAM estimate is zero", func() {
// estimateModelVRAM returns 0 when the model files don't exist on disk,
// so FindNodeWithVRAM is not called here and scheduling goes straight to
// the idle/least-loaded path. Exercising the VRAM branch itself would
// require mocking estimateModelVRAM; this spec only verifies the fallback.
reg.findVRAMErr = errors.New("no vram nodes")
idleNode := &BackendNode{ID: "idle-vram", Name: "idle", Address: "10.0.0.11:50051"}
reg.findIdleNode = idleNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
// Pass non-nil ModelOptions so estimateModelVRAM actually runs.
result, err := router.Route(context.Background(), "m1", "models/m1.gguf", "llama-cpp", &pb.ModelOptions{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("idle-vram"))
})
It("falls back to idle when VRAM search fails", func() {
reg.findVRAMErr = errors.New("no vram")
idleNode := &BackendNode{ID: "idle-1", Name: "idle-node", Address: "10.0.0.20:50051"}
reg.findIdleNode = idleNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "m2", "models/m2.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("idle-1"))
})
It("falls back to least-loaded when both VRAM and idle fail", func() {
reg.findVRAMErr = errors.New("no vram")
reg.findIdleErr = errors.New("no idle")
llNode := &BackendNode{ID: "ll-1", Name: "least-loaded", Address: "10.0.0.30:50051"}
reg.findLeastLoadedNode = llNode
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
})
result, err := router.Route(context.Background(), "m3", "models/m3.gguf", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.ID).To(Equal("ll-1"))
})
It("returns error when no nodes are available and no DB for eviction", func() {
reg.findVRAMErr = errors.New("no vram")
reg.findIdleErr = errors.New("no idle")
reg.findLeastLoadedErr = errors.New("no nodes")
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
ClientFactory: factory,
// DB is nil — evictLRUAndFreeNode will fail because r.db is nil
})
_, err := router.Route(context.Background(), "m4", "models/m4.gguf", "llama-cpp", nil, false)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no available nodes"))
})
})
Describe("UnloadModel (mock-based)", func() {
It("calls StopBackend and removes the model from the registry", func() {
reg := &fakeModelRouter{}
unloader := &fakeUnloader{}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
})
err := router.UnloadModel(context.Background(), "node-1", "model-a")
Expect(err).ToNot(HaveOccurred())
Expect(unloader.stopCalls).To(ContainElement("node-1:model-a"))
Expect(reg.removeCalls).To(ContainElement("node-1:model-a"))
})
It("returns error when no unloader is configured", func() {
reg := &fakeModelRouter{}
router := NewSmartRouter(reg, SmartRouterOptions{})
err := router.UnloadModel(context.Background(), "node-1", "model-a")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no remote unloader"))
})
})
Describe("EvictLRU (mock-based)", func() {
It("finds LRU model and unloads it", func() {
reg := &fakeModelRouter{
findLRUModel: &NodeModel{NodeID: "n1", ModelName: "old-model"},
}
unloader := &fakeUnloader{}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: unloader,
})
evicted, err := router.EvictLRU(context.Background(), "n1")
Expect(err).ToNot(HaveOccurred())
Expect(evicted).To(Equal("old-model"))
Expect(unloader.stopCalls).To(ContainElement("n1:old-model"))
Expect(reg.removeCalls).To(ContainElement("n1:old-model"))
})
It("returns error when no LRU model is found", func() {
reg := &fakeModelRouter{
findLRUErr: errors.New("no models loaded"),
}
router := NewSmartRouter(reg, SmartRouterOptions{
Unloader: &fakeUnloader{},
})
_, err := router.EvictLRU(context.Background(), "n1")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("finding LRU model"))
})
})
Describe("scheduleNewModel with node selector (mock-based, via Route)", func() {
var (
reg *fakeModelRouter
backend *stubBackend
factory *stubClientFactory
			unloader *fakeUnloader
		)

		BeforeEach(func() {
			reg = &fakeModelRouter{
				findAndLockErr: errors.New("not found"),
			}
			backend = &stubBackend{
				loadResult: &pb.Result{Success: true},
			}
			factory = &stubClientFactory{client: backend}
			unloader = &fakeUnloader{
				installReply: &messaging.BackendInstallReply{
					Success: true,
					Address: "10.0.0.1:9001",
				},
			}
		})

		It("uses *FromSet methods when model has a node selector", func() {
			gpuNode := &BackendNode{ID: "gpu-1", Name: "gpu-node", Address: "10.0.0.50:50051"}
			reg.getModelScheduling = &ModelSchedulingConfig{
				ModelName:    "selector-model",
				NodeSelector: `{"gpu.vendor":"nvidia"}`,
			}
			reg.findBySelectorNodes = []BackendNode{*gpuNode}
			reg.findIdleFromSetNode = gpuNode
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})
			result, err := router.Route(context.Background(), "selector-model", "models/selector.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			Expect(result.Node.ID).To(Equal("gpu-1"))
		})

		It("returns error when no nodes match selector", func() {
			reg.getModelScheduling = &ModelSchedulingConfig{
				ModelName:    "no-match-model",
				NodeSelector: `{"gpu.vendor":"tpu"}`,
			}
			reg.findBySelectorNodes = nil
			reg.findBySelectorErr = nil
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})
			_, err := router.Route(context.Background(), "no-match-model", "models/nomatch.gguf", "llama-cpp", nil, false)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no healthy nodes match selector"))
		})

		It("uses regular methods when model has no scheduling config", func() {
			reg.getModelScheduling = nil
			idleNode := &BackendNode{ID: "regular-1", Name: "regular-node", Address: "10.0.0.60:50051"}
			reg.findIdleNode = idleNode
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})
			result, err := router.Route(context.Background(), "regular-model", "models/regular.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			Expect(result.Node.ID).To(Equal("regular-1"))
		})
	})
	Describe("Route with selector validation on cached model (mock-based)", func() {
		It("falls through when cached node no longer matches selector", func() {
			cachedNode := &BackendNode{ID: "n-old", Name: "old-node", Address: "10.0.0.70:50051"}
			newNode := &BackendNode{ID: "n-new", Name: "new-node", Address: "10.0.0.71:50051"}
			backend := &stubBackend{
				healthResult: true,
				loadResult:   &pb.Result{Success: true},
			}
			factory := &stubClientFactory{client: backend}
			unloader := &fakeUnloader{
				installReply: &messaging.BackendInstallReply{
					Success: true,
					Address: "10.0.0.71:9001",
				},
			}
			reg := &fakeModelRouter{
				// Step 1: cached model found on old node
				findAndLockNode: cachedNode,
				findAndLockNM:   &NodeModel{NodeID: "n-old", ModelName: "sel-model", Address: "10.0.0.70:9001"},
				// Scheduling config with selector that old node does NOT match
				getModelScheduling: &ModelSchedulingConfig{
					ModelName:    "sel-model",
					NodeSelector: `{"gpu.vendor":"nvidia"}`,
				},
				// Old node has no labels matching the selector
				getNodeLabels: []NodeLabel{
					{NodeID: "n-old", Key: "gpu.vendor", Value: "amd"},
				},
				// For scheduling fallthrough: selector matches new node
				findBySelectorNodes: []BackendNode{*newNode},
				findIdleFromSetNode: newNode,
			}
			router := NewSmartRouter(reg, SmartRouterOptions{
				Unloader:      unloader,
				ClientFactory: factory,
			})
			result, err := router.Route(context.Background(), "sel-model", "models/sel.gguf", "llama-cpp", nil, false)
			Expect(err).ToNot(HaveOccurred())
			Expect(result).ToNot(BeNil())
			// Should have fallen through to the new node
			Expect(result.Node.ID).To(Equal("n-new"))
			// Old node should have had its in-flight decremented
			Expect(reg.decrementCalls).To(ContainElement("n-old:sel-model"))
		})
	})
	// -----------------------------------------------------------------------
	// Integration tests using real PostgreSQL (existing)
	// -----------------------------------------------------------------------
	Describe("evictLRUAndFreeNode (integration)", func() {
		var (
			db       *gorm.DB
			registry *NodeRegistry
		)

		BeforeEach(func() {
			if runtime.GOOS == "darwin" {
				Skip("testcontainers requires Docker, not available on macOS CI")
			}
			db = testutil.SetupTestDB()
			var err error
			registry, err = NewNodeRegistry(db)
			Expect(err).ToNot(HaveOccurred())
		})

		It("returns ErrEvictionBusy in under 5 seconds when all models are busy", func() {
			node := &BackendNode{
				Name:     "busy-evict",
				NodeType: NodeTypeBackend,
				Address:  "10.0.0.100:50051",
			}
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			// Load a model and give it in-flight requests so it cannot be evicted
			Expect(registry.SetNodeModel(context.Background(), node.ID, "busy-model", 0, "loaded", "", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "busy-model", 0)).To(Succeed())
			router := NewSmartRouter(registry, SmartRouterOptions{DB: db})
			start := time.Now()
			_, err := router.evictLRUAndFreeNode(context.Background())
			elapsed := time.Since(start)
			Expect(err).To(MatchError(ErrEvictionBusy))
			// 5 retries * 500ms = 2.5s nominal; allow generous upper bound
			Expect(elapsed).To(BeNumerically("<", 5*time.Second))
		})

		It("respects context cancellation", func() {
			node := &BackendNode{
				Name:     "cancel-evict",
				NodeType: NodeTypeBackend,
				Address:  "10.0.0.101:50051",
			}
			Expect(registry.Register(context.Background(), node, true)).To(Succeed())
			Expect(registry.SetNodeModel(context.Background(), node.ID, "cancel-model", 0, "loaded", "", 0)).To(Succeed())
			Expect(registry.IncrementInFlight(context.Background(), node.ID, "cancel-model", 0)).To(Succeed())
			router := NewSmartRouter(registry, SmartRouterOptions{DB: db})
			ctx, cancel := context.WithCancel(context.Background())
			cancel() // cancel immediately
			start := time.Now()
			_, err := router.evictLRUAndFreeNode(ctx)
			elapsed := time.Since(start)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("context cancelled"))
			// Should return very quickly since context is already done
			Expect(elapsed).To(BeNumerically("<", 2*time.Second))
		})
	})
	Describe("stageModelFiles (integration)", func() {
		var (
			db       *gorm.DB
			registry *NodeRegistry
		)

		BeforeEach(func() {
			if runtime.GOOS == "darwin" {
				Skip("testcontainers requires Docker, not available on macOS CI")
			}
			db = testutil.SetupTestDB()
			var err error
			registry, err = NewNodeRegistry(db)
			Expect(err).ToNot(HaveOccurred())
		})

		It("does not mutate the original ModelOptions", func() {
			stager := &fakeFileStager{}
			router := NewSmartRouter(registry, SmartRouterOptions{
				FileStager: stager,
				DB:         db,
			})
			node := &BackendNode{
				ID:      "stage-node-id",
				Name:    "stage-node",
				Address: "10.0.0.200:50051",
			}
			original := &pb.ModelOptions{
				Model:     "test-backend/models/test.gguf",
				ModelFile: "/models/test-backend/models/test.gguf",
				MMProj:    "",
			}
			// Capture original values before staging
			origModel := original.Model
			origModelFile := original.ModelFile
			origMMProj := original.MMProj
			// stageModelFiles creates temp files for os.Stat checks.
			// Since none of our test paths exist on disk, stageModelFiles will
			// skip them (clearing non-existent optional fields). The key property
			// is that the original proto pointer is not modified.
			_, _ = router.stageModelFiles(context.Background(), node, original, "test-model")
			// Verify the original proto was not mutated
			Expect(original.Model).To(Equal(origModel))
			Expect(original.ModelFile).To(Equal(origModelFile))
			Expect(original.MMProj).To(Equal(origMMProj))
		})
	})
})