feat(distributed): support multiple replicas of one model on the same node (#9583)

* feat(distributed): support multiple replicas of one model on the same node

The distributed scheduler implicitly assumed `(node_id, model_name)` was
unique, but the schema didn't enforce it and the worker keyed all gRPC
processes by model name alone. With `MinReplicas=2` against a single
worker, the reconciler "scaled up" every 30s but the registry never
advanced past 1 row — the worker re-loaded the model in-place every tick
until VRAM fragmented and the gRPC process died.

This change introduces multi-replica-per-node as a first-class concept,
with capacity-aware scheduling, a circuit breaker, and VRAM
soft-reservation. Operators can declare per-node capacity via the worker
flag `--max-replicas-per-model` (mirrored as auto-label
`node.replica-slots=N`) or override per-node from the UI.

* Schema: BackendNode gains MaxReplicasPerModel (default 1) and
  ReservedVRAM. NodeModel gains ReplicaIndex (composite with node_id +
  model_name). ModelSchedulingConfig gains UnsatisfiableUntil/Ticks for
  the reconciler circuit breaker.

* Registry: replica_index threaded through SetNodeModel, RemoveNodeModel,
  IncrementInFlight, DecrementInFlight, TouchNodeModel, GetNodeModel,
  SetNodeModelLoadInfo and the InFlightTrackingClient. New helpers:
  CountReplicasOnNode, NextFreeReplicaIndex (with ErrNoFreeSlot),
  RemoveAllNodeModelReplicas, FindNodesWithFreeSlot,
  ClusterCapacityForModel, ReserveVRAM/ReleaseVRAM (atomic UPDATE with
  ErrInsufficientVRAM), and the unsatisfiable-flag CRUD.

* Worker: processKey now `<modelID>#<replicaIndex>` so concurrent loads
  of the same model land on distinct ports. Adds CLI flag
  --max-replicas-per-model (env LOCALAI_MAX_REPLICAS_PER_MODEL, default 1)
  and emits the auto-label.

* Router: scheduleNewModel filters candidates by free slot, allocates the
  replica index, and soft-reserves VRAM before installing the backend.
  evictLRUAndFreeNode now deletes the targeted row by ID instead of all
  replicas of the model on the node — fixes a latent bug where evicting
  one replica orphaned its siblings.

* Reconciler: caps scale-up at ClusterCapacityForModel so a misconfig
  (MinReplicas > capacity) doesn't loop forever. After 3 consecutive
  ticks of capacity==0 it sets UnsatisfiableUntil for a 5m cooldown and
  emits a warning. ClearAllUnsatisfiable fires from Register,
  ApproveNode, SetNodeLabel(s), RemoveNodeLabel and
  UpdateMaxReplicasPerModel so a new node joining or label changes wake
  the reconciler immediately. scaleDownIdle removes highest-replica-index
  first to keep slots compact.

* Heartbeat resets reserved_vram to 0 — worker is the source of truth
  for actual free VRAM; the reservation is only for the in-tick race
  window between two scheduling decisions.

* Probe path (reconciler.probeLoadedModels and health.doCheckAll) now
  pass the row's replica_index to RemoveNodeModel so an unreachable
  replica doesn't orphan healthy siblings.

* Admin override: PUT /api/nodes/:id/max-replicas-per-model sets a
  sticky override (preserved across worker re-registration). DELETE
  clears the override so the worker's flag applies again on next
  register. Required because Kong defaults the worker flag to 1, so
  every worker restart would have silently reverted the UI value.

* React UI: always-visible slot badge on the node row (muted at default
  1, accented when >1); inline editor in the expanded drawer with
  pencil-to-edit, Save/Cancel, Esc/Enter, "(override)" indicator when
  the value is admin-set, and a "Reset" button to hand control back to
  the worker. Soft confirm when shrinking the cap below the count of
  loaded replicas. Scheduling rules table gets an "Unsatisfiable until
  HH:MM" status badge surfacing the cooldown.

* node.replica-slots filtered out of the labels strip on the row to
  avoid duplicating the slot badge.

23 new Ginkgo specs (registry, reconciler, inflight, health) cover:
multi-replica row independence, RemoveNodeModel of one replica
preserving siblings, NextFreeReplicaIndex slot allocation including
ErrNoFreeSlot, capacity-gated scale-up with circuit breaker tripping
and recovery on Register, scheduleDownIdle ordering, ClusterCapacity
math, ReserveVRAM admission gating, Heartbeat reset, override survival
across worker re-registration, and ResetMaxReplicasPerModel handing
control back. Plus 8 stdlib tests for the worker processKey / CLI /
auto-label.

Closes the flap reproduced on Qwen3.6-35B against the nvidia-thor
worker (single 128 GiB node, MinReplicas=2): the reconciler now caps
the scale-up at the cluster's actual capacity instead of looping.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Read] [Edit] [Bash] [Skill:critique] [Skill:audit] [Skill:polish] [Skill:golang-testing]

* refactor(react-ui/nodes): tighten capacity editor copy + adopt ActionMenu for row actions

* Capacity editor hint trimmed from operator-doc-style ("Sourced from
  the worker's `--max-replicas-per-model` flag. Changing it here makes it
  a sticky admin override that survives worker restarts." → "Saved
  values stick across worker restarts.") and the override-state copy
  similarly compressed. The full mechanic is no longer needed in the UI
  — the override pill carries the meaning and the docs cover the rest.

* Node row actions migrated from an inline cluster of icon buttons
  (Drain / Resume / Trash) to the kebab ActionMenu used by /manage for
  per-row model actions, so dense Nodes tables stay clean. Approve
  stays as a prominent primary button — it's a stateful admission gate,
  not a routine action, and elevating it matches how /manage surfaces
  install-time decisions outside the menu.

* The expanded drawer's Labels section now filters node.replica-slots
  out of the editable label list. The label is owned by the Capacity
  editor above; surfacing it again as an editable label invited
  confusion (the Capacity save would clobber any direct edit).

Both backend and agent workers benefit — they share the row rendering
path, so the action menu and label filter apply to both.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:critique] [Skill:audit] [Skill:polish]

* fix(react-ui/nodes): suppress slot badge on agent workers

Agent workers don't load models, so the per-node replica capacity is
inapplicable to them. Showing "1× slots" on agent rows was a tiny
inconsistency from the unified rendering path — gate the badge on
node_type !== 'agent' so it only appears on backend workers.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp]

* refactor(react-ui/nodes): distill expanded drawer + restyle scheduling form

The expanded node drawer used to stack five panels — slot badge,
filled capacity box, Loaded Models h4+empty-state, Installed Backends
h4+empty-state, Labels h4+chips+form — making routine inspections feel
like a control panel. The scheduling rule form wrapped its mode toggle
as two 50%-width filled buttons that competed visually with the actual
primary action.

* Drawer: collapse three rarely-touched config zones (Capacity,
  Backends, Labels) into one `<details>` "Manage" disclosure (closed by
  default) with small uppercase eyebrow labels for each zone instead of
  parallel h4 sub-headings. Loaded Models stays as the at-a-glance
  headline with a single-line empty hint instead of a boxed empty state.
  CapacityEditor renders flat (no filled background) — the Manage
  disclosure provides framing.

* Scheduling form: replace the chunky 50%-width button-tabs with the
  project's existing `.segmented` control (icon + label, sized to
  content). Mode hint becomes a single tied line below. Fields stack
  vertically with helper text under inputs and a hairline divider above
  the right-aligned Save / Cancel.

The empty drawer collapses from ~5 stacked sections (~280px tall) to
two lines (~80px). The scheduling form now reads as a designed dialog
instead of raw building blocks. Both surfaces now match the typographic
density and weight of the rest of the admin pages.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:distill] [Skill:audit] [Skill:polish]

* feat(react-ui/nodes): replace scheduling form's model picker with searchable combobox

The native <select> made operators scroll through every gallery entry to
find a model name. The project already has SearchableModelSelect (used
in Studio/Talk/etc.) which combines free-text search with the gallery
list and accepts typed model names that aren't installed yet — useful
for pre-staging a scheduling rule before the node it'll run on has
finished bootstrapping.

Also drops the now-unused useModels import (the combobox manages the
gallery hook internally).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit]

* refactor(react-ui/nodes): consolidate key/value chip editor + add replica preset chips

The Nodes page was rendering the same key=value chip pattern in two
places with subtly different markup: the Labels editor in the expanded
drawer and (post-distill) the Node Selector input in the scheduling
form. The form's input was also a comma-separated string that operators
were getting wrong.

* Extract <KeyValueChips> as a fully controlled chip-builder. Parent
  owns the map and decides what onAdd/onRemove does — form state for the
  scheduling form, API calls for the live drawer Labels editor. Same
  visuals everywhere; one component to change when polish needs apply.

* Replace the comma-separated Node Selector text input with KeyValueChips.
  Operators were copying syntax from docs and missing commas; the chip
  vocabulary makes the key=value structure self-documenting.

* Add <ReplicaInput>: numeric input + quick-pick preset chips for Min/Max
  replicas. Picked over a slider because replica counts are exact specs
  derived from VRAM math (operator decision, not a fuzzy estimate). The
  chips give one-click access to common values (1/2/3/4 for Min,
  0=no-limit/2/4/8 for Max) without the slider's special-value problem
  (MaxReplicas=0 is categorical, not a position on a continuum).

* Drop the now-unused labelInputs state in the Nodes page (the inline
  label editor's per-node draft state lived there and is now owned by
  KeyValueChips).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Skill:distill]

* test: fix CI fallout from multi-replica refactor (e2e/distributed + playwright)

Two breakages caught by CI that didn't surface in the local run:

* tests/e2e/distributed/*.go — multiple files used the pre-PR2 registry
  signatures for SetNodeModel / IncrementInFlight / DecrementInFlight /
  RemoveNodeModel / TouchNodeModel / GetNodeModel / SetNodeModelLoadInfo
  and one stale adapter.InstallBackend call in node_lifecycle_test.go.
  All updated to pass replicaIndex=0 — these tests don't exercise
  multi-replica behavior, they just need to compile against the new
  signatures. The chip-builder tests in core/services/nodes/ already
  cover the multi-replica logic.

* core/http/react-ui/e2e/nodes-per-node-backend-actions.spec.js — the
  drawer's distill refactor moved Backends inside a "Manage" <details>
  disclosure that's collapsed by default. The test helper expanded the
  node row but never opened Manage, so the per-node backend table was
  never in the DOM. Helper now clicks `.node-manage > summary` after
  expanding the row.

All 100 playwright tests pass locally; tests/e2e/distributed compiles
clean.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Bash]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-04-27 21:20:05 +02:00
committed by GitHub
parent f4036fa83f
commit 6b63b47f61
34 changed files with 2491 additions and 569 deletions

View File

@@ -90,6 +90,14 @@ type WorkerCMD struct {
RegistrationToken string `env:"LOCALAI_REGISTRATION_TOKEN" help:"Token for authenticating with the frontend" group:"registration"`
HeartbeatInterval string `env:"LOCALAI_HEARTBEAT_INTERVAL" default:"10s" help:"Interval between heartbeats" group:"registration"`
NodeLabels string `env:"LOCALAI_NODE_LABELS" help:"Comma-separated key=value labels for this node (e.g. tier=fast,gpu=a100)" group:"registration"`
// MaxReplicasPerModel caps how many replicas of any one model can run on
// this worker concurrently. Default 1 = historical single-replica
// behavior. Set higher when a node has enough VRAM to host multiple
// copies of the same model (e.g. a fat 128 GiB box running 4× of a
// 24 GiB model for throughput). The auto-label `node.replica-slots=N`
// is published so model schedulers can target high-capacity nodes via
// the existing label selector.
MaxReplicasPerModel int `env:"LOCALAI_MAX_REPLICAS_PER_MODEL" default:"1" help:"Max replicas of any single model on this worker. Default 1 preserves single-replica behavior; set higher to allow stacking replicas on a fat node." group:"registration"`
// NATS (required)
NatsURL string `env:"LOCALAI_NATS_URL" required:"" help:"NATS server URL" group:"distributed"`
@@ -567,22 +575,35 @@ func (s *backendSupervisor) getAddr(backend string) string {
return ""
}
// buildProcessKey is the supervisor's stable identifier for a backend gRPC
// process. It includes the replica index so the same model can run multiple
// processes on a worker simultaneously without colliding on the same map slot
// or port. The "#N" suffix is purely internal — the controller never reads it.
func buildProcessKey(modelID, backend string, replicaIndex int) string {
base := modelID
if base == "" {
base = backend
}
return fmt.Sprintf("%s#%d", base, replicaIndex)
}
// installBackend handles the backend.install flow:
// 1. If already running for this model, return existing address
// 1. If already running for this (model, replica) slot, return existing address
// 2. Install backend from gallery (if not already installed)
// 3. Find backend binary
// 4. Start gRPC process on a new port
// Returns the gRPC address of the backend process.
//
// ProcessKey includes the replica index so a worker with MaxReplicasPerModel>1
// can host multiple processes for the same model on distinct ports. Old
// controllers (no replica_index in the request) implicitly target replica 0,
// which preserves single-replica behavior.
func (s *backendSupervisor) installBackend(req messaging.BackendInstallRequest) (string, error) {
// Process key: use ModelID if provided (per-model process), else backend name
processKey := req.ModelID
if processKey == "" {
processKey = req.Backend
}
processKey := buildProcessKey(req.ModelID, req.Backend, int(req.ReplicaIndex))
// If already running for this model, return its address
// If already running for this model+replica, return its address
if addr := s.getAddr(processKey); addr != "" {
xlog.Info("Backend already running for model", "backend", req.Backend, "model", req.ModelID, "addr", addr)
xlog.Info("Backend already running for model replica", "backend", req.Backend, "model", req.ModelID, "replica", req.ReplicaIndex, "addr", addr)
return addr, nil
}
@@ -886,13 +907,18 @@ func (cmd *WorkerCMD) registrationBody() map[string]any {
totalVRAM, _ := xsysinfo.TotalAvailableVRAM()
gpuVendor, _ := xsysinfo.DetectGPUVendor()
maxReplicas := cmd.MaxReplicasPerModel
if maxReplicas < 1 {
maxReplicas = 1
}
body := map[string]any{
"name": nodeName,
"address": cmd.advertiseAddr(),
"http_address": cmd.advertiseHTTPAddr(),
"total_vram": totalVRAM,
"available_vram": totalVRAM, // initially all VRAM is available
"gpu_vendor": gpuVendor,
"name": nodeName,
"address": cmd.advertiseAddr(),
"http_address": cmd.advertiseHTTPAddr(),
"total_vram": totalVRAM,
"available_vram": totalVRAM, // initially all VRAM is available
"gpu_vendor": gpuVendor,
"max_replicas_per_model": maxReplicas,
}
// If no GPU detected, report system RAM so the scheduler/UI has capacity info
@@ -906,19 +932,20 @@ func (cmd *WorkerCMD) registrationBody() map[string]any {
body["token"] = cmd.RegistrationToken
}
// Parse and add static node labels
// Parse and add static node labels. Always include the auto-label
// `node.replica-slots=N` so AND-selectors in ModelSchedulingConfig can
// target high-capacity nodes (e.g. {"node.replica-slots":"4"}).
labels := make(map[string]string)
if cmd.NodeLabels != "" {
labels := make(map[string]string)
for _, pair := range strings.Split(cmd.NodeLabels, ",") {
pair = strings.TrimSpace(pair)
if k, v, ok := strings.Cut(pair, "="); ok {
labels[strings.TrimSpace(k)] = strings.TrimSpace(v)
}
}
if len(labels) > 0 {
body["labels"] = labels
}
}
labels["node.replica-slots"] = strconv.Itoa(maxReplicas)
body["labels"] = labels
return body
}

View File

@@ -0,0 +1,70 @@
package cli
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Worker per-replica process keying", func() {
Describe("buildProcessKey", func() {
// Pin the supervisor's keying contract: distinct replica indexes for
// the same modelID produce distinct process keys, so the supervisor
// map can hold multiple processes for one model. Dropping the suffix
// would re-introduce the original flap (one model, one slot, churn).
DescribeTable("produces stable, distinct keys",
func(modelID, backend string, replica int, want string) {
Expect(buildProcessKey(modelID, backend, replica)).To(Equal(want))
},
Entry("modelID present, replica 0", "Qwen3-35B", "llama-cpp", 0, "Qwen3-35B#0"),
Entry("modelID present, replica 1", "Qwen3-35B", "llama-cpp", 1, "Qwen3-35B#1"),
Entry("falls back to backend when modelID empty", "", "llama-cpp", 0, "llama-cpp#0"),
Entry("backend fallback with replica 2", "", "llama-cpp", 2, "llama-cpp#2"),
)
It("makes replicas distinguishable", func() {
r0 := buildProcessKey("model-a", "llama-cpp", 0)
r1 := buildProcessKey("model-a", "llama-cpp", 1)
Expect(r0).ToNot(Equal(r1), "replicas of the same model must produce distinct keys")
})
})
Describe("registrationBody", func() {
It("includes max_replicas_per_model and the auto-label", func() {
cmd := &WorkerCMD{
Addr: "worker.example.com:50051",
MaxReplicasPerModel: 4,
}
body := cmd.registrationBody()
Expect(body).To(HaveKey("max_replicas_per_model"))
Expect(body["max_replicas_per_model"]).To(Equal(4))
labels, ok := body["labels"].(map[string]string)
Expect(ok).To(BeTrue(), "labels must be present so selectors can target the slot count")
Expect(labels).To(HaveKeyWithValue("node.replica-slots", "4"))
})
It("coerces zero/unset MaxReplicasPerModel to 1", func() {
cmd := &WorkerCMD{Addr: "worker.example.com:50051"}
body := cmd.registrationBody()
Expect(body["max_replicas_per_model"]).To(Equal(1),
"unset must default to single-replica behavior, not capacity 0")
labels := body["labels"].(map[string]string)
Expect(labels).To(HaveKeyWithValue("node.replica-slots", "1"))
})
It("preserves user-provided labels alongside the auto-label", func() {
cmd := &WorkerCMD{
Addr: "worker.example.com:50051",
MaxReplicasPerModel: 2,
NodeLabels: "tier=fast,gpu=a100",
}
body := cmd.registrationBody()
labels := body["labels"].(map[string]string)
Expect(labels).To(HaveKeyWithValue("tier", "fast"))
Expect(labels).To(HaveKeyWithValue("gpu", "a100"))
Expect(labels).To(HaveKeyWithValue("node.replica-slots", "2"))
})
})
})

View File

@@ -73,6 +73,10 @@ type RegisterNodeRequest struct {
AvailableRAM uint64 `json:"available_ram,omitempty"`
GPUVendor string `json:"gpu_vendor,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
// MaxReplicasPerModel is the per-node cap on replicas of any single model.
// Workers older than this field omit it; we coerce 0 → 1 below to preserve
// historical single-replica behavior.
MaxReplicasPerModel int `json:"max_replicas_per_model,omitempty"`
}
// RegisterNodeEndpoint registers a new backend node.
@@ -131,17 +135,26 @@ func RegisterNodeEndpoint(registry *nodes.NodeRegistry, expectedToken string, au
tokenHash = hex.EncodeToString(h[:])
}
// Coerce 0 → 1 for backward compat with workers that don't send the field.
// GORM's `default:1` only fires for a missing column; once Go zero-values
// reach the struct field they're written as 0 unless explicitly set here.
maxReplicasPerModel := req.MaxReplicasPerModel
if maxReplicasPerModel < 1 {
maxReplicasPerModel = 1
}
node := &nodes.BackendNode{
Name: req.Name,
NodeType: nodeType,
Address: req.Address,
HTTPAddress: req.HTTPAddress,
TokenHash: tokenHash,
TotalVRAM: req.TotalVRAM,
AvailableVRAM: req.AvailableVRAM,
TotalRAM: req.TotalRAM,
AvailableRAM: req.AvailableRAM,
GPUVendor: req.GPUVendor,
Name: req.Name,
NodeType: nodeType,
Address: req.Address,
HTTPAddress: req.HTTPAddress,
TokenHash: tokenHash,
TotalVRAM: req.TotalVRAM,
AvailableVRAM: req.AvailableVRAM,
TotalRAM: req.TotalRAM,
AvailableRAM: req.AvailableRAM,
GPUVendor: req.GPUVendor,
MaxReplicasPerModel: maxReplicasPerModel,
}
ctx := c.Request().Context()
@@ -386,7 +399,10 @@ func InstallBackendOnNodeEndpoint(unloader nodes.NodeCommandSender) echo.Handler
if req.Backend == "" && req.URI == "" {
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "backend name or uri required"))
}
reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias)
// Admin-driven backend install: not tied to a specific replica slot
// (no model is being loaded). Pass replica 0 to match the worker's
// admin process-key convention (`backend#0`).
reply, err := unloader.InstallBackend(nodeID, req.Backend, "", req.BackendGalleries, req.URI, req.Name, req.Alias, 0)
if err != nil {
xlog.Error("Failed to install backend on node", "node", nodeID, "backend", req.Backend, "uri", req.URI, "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to install backend on node"))
@@ -467,8 +483,8 @@ func UnloadModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes
xlog.Error("Failed to stop backend after model unload", "node", nodeID, "model", req.ModelName, "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "model unloaded but backend stop failed"))
}
// Remove from registry
registry.RemoveNodeModel(c.Request().Context(), nodeID, req.ModelName)
// Remove every replica of this model on the node from the registry.
registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, req.ModelName)
return c.JSON(http.StatusOK, map[string]string{"message": "model unloaded"})
}
}
@@ -494,7 +510,7 @@ func DeleteModelOnNodeEndpoint(unloader nodes.NodeCommandSender, registry *nodes
// Non-fatal — backend process may not be running
xlog.Warn("StopBackend failed during model deletion (non-fatal)", "node", nodeID, "model", req.ModelName, "error", err)
}
registry.RemoveNodeModel(c.Request().Context(), nodeID, req.ModelName)
registry.RemoveAllNodeModelReplicas(c.Request().Context(), nodeID, req.ModelName)
return c.JSON(http.StatusOK, map[string]string{"message": "model deleted from node"})
}
}
@@ -669,6 +685,78 @@ func GetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
}
}
// UpdateMaxReplicasPerModelRequest is the body for the per-node replica cap endpoint.
type UpdateMaxReplicasPerModelRequest struct {
// Value is the new per-model replica cap on this node. Must be >= 1.
Value int `json:"value"`
}
// UpdateMaxReplicasPerModelEndpoint sets the per-node cap on how many replicas
// of any one model can be loaded concurrently. The corresponding
// `node.replica-slots` auto-label is refreshed so existing AND-selectors keep
// matching, and any unsatisfiable scheduling cooldowns are cleared so the
// reconciler retries on the next tick.
//
// This is a transient admin override — a worker re-registration restores the
// value the worker was started with (--max-replicas-per-model). For permanent
// fleet changes, change the worker flag.
//
// @Summary Update a node's max replicas per model
// @Tags Nodes
// @Param id path string true "Node ID"
// @Param request body UpdateMaxReplicasPerModelRequest true "New value"
// @Success 200 {object} map[string]int
// @Failure 400 {object} map[string]any "value must be >= 1"
// @Failure 404 {object} map[string]any "node not found"
// @Router /api/nodes/{id}/max-replicas-per-model [put]
func UpdateMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
nodeID := c.Param("id")
if _, err := registry.Get(ctx, nodeID); err != nil {
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
}
var req UpdateMaxReplicasPerModelRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "invalid request body"))
}
if req.Value < 1 {
return c.JSON(http.StatusBadRequest, nodeError(http.StatusBadRequest, "value must be >= 1"))
}
if err := registry.UpdateMaxReplicasPerModel(ctx, nodeID, req.Value); err != nil {
xlog.Error("Failed to update max_replicas_per_model", "node", nodeID, "value", req.Value, "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to update max replicas per model"))
}
return c.JSON(http.StatusOK, map[string]int{"max_replicas_per_model": req.Value})
}
}
// ResetMaxReplicasPerModelEndpoint clears the admin override on a node, so
// the next worker re-registration is allowed to update the value from its
// CLI flag again. The current value is left in place until the worker calls
// register.
//
// @Summary Reset a node's max replicas per model to the worker default
// @Tags Nodes
// @Param id path string true "Node ID"
// @Success 200 {object} map[string]bool
// @Failure 404 {object} map[string]any "node not found"
// @Router /api/nodes/{id}/max-replicas-per-model [delete]
func ResetMaxReplicasPerModelEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
nodeID := c.Param("id")
if _, err := registry.Get(ctx, nodeID); err != nil {
return c.JSON(http.StatusNotFound, nodeError(http.StatusNotFound, "node not found"))
}
if err := registry.ResetMaxReplicasPerModel(ctx, nodeID); err != nil {
xlog.Error("Failed to reset max_replicas_per_model override", "node", nodeID, "error", err)
return c.JSON(http.StatusInternalServerError, nodeError(http.StatusInternalServerError, "failed to reset override"))
}
return c.JSON(http.StatusOK, map[string]bool{"reset": true})
}
}
// SetNodeLabelsEndpoint replaces all labels for a node.
func SetNodeLabelsEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
return func(c echo.Context) error {

View File

@@ -85,6 +85,12 @@ async function expandNodeAndWaitForBackends(page) {
// Click the row to expand it. The chevron toggle and the row both work,
// but clicking the name cell is the most user-like.
await page.getByText(NODE_NAME).first().click()
// Backends, Capacity and Labels live behind a "Manage" <details>
// disclosure (the drawer was distilled to keep at-a-glance content
// lean — see distill refactor in the multi-replica branch). Open it
// by clicking the summary inside the .node-manage scope so the
// per-node backend table is in the DOM before assertions run.
await page.locator('.node-manage > summary').first().click()
await expect(page.getByRole('cell', { name: BACKEND_NAME, exact: true })).toBeVisible({ timeout: 10_000 })
}

View File

@@ -1977,6 +1977,41 @@ select.input {
opacity: 0.8;
}
/* Small caps eyebrow inside the drawer's "Manage" disclosure. Replaces the
h4 sub-headings that used to stack inside the drawer — at this depth, an
eyebrow keeps the typographic hierarchy from feeling parallel to the
page-level h1/h2 stack. */
.drawer-eyebrow {
font-size: 0.6875rem;
font-weight: var(--font-weight-semibold);
letter-spacing: 0.06em;
text-transform: uppercase;
color: var(--color-text-muted);
margin-bottom: var(--spacing-xs);
}
/* "Manage" disclosure inside the node drawer. The chevron rotates with the
open state so the affordance reads as an accordion, not a link. */
.node-manage > summary {
user-select: none;
outline: none;
}
.node-manage > summary::-webkit-details-marker {
display: none;
}
.node-manage > summary:focus-visible {
outline: 2px solid var(--color-primary);
outline-offset: 2px;
border-radius: var(--radius-sm);
}
.node-manage__chevron {
font-size: 0.625rem;
transition: transform var(--duration-fast) ease-out;
}
.node-manage[open] > summary .node-manage__chevron {
transform: rotate(90deg);
}
/* Node-status indicator — replaces the tiny bullet with a proper LED-style
dot next to a bold status label. Colors are applied inline from statusConfig
so one primitive handles healthy/unhealthy/draining/pending in one shape. */

View File

File diff suppressed because it is too large Load Diff

View File

@@ -481,6 +481,17 @@ export const nodesApi = {
getLabels: (id) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id)),
mergeLabels: (id, labels) => fetchJSON(API_CONFIG.endpoints.nodeLabels(id), { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(labels) }),
deleteLabel: (id, key) => fetchJSON(API_CONFIG.endpoints.nodeLabelKey(id, key), { method: 'DELETE' }),
// Set a sticky admin override for the per-node replica cap. The override
// is preserved across worker restarts; call resetMaxReplicasPerModel to
// hand control back to the worker's CLI flag.
updateMaxReplicasPerModel: (id, value) => fetchJSON(API_CONFIG.endpoints.nodeMaxReplicasPerModel(id), {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ value }),
}),
resetMaxReplicasPerModel: (id) => fetchJSON(API_CONFIG.endpoints.nodeMaxReplicasPerModel(id), {
method: 'DELETE',
}),
listScheduling: () => fetchJSON(API_CONFIG.endpoints.nodesScheduling),
setScheduling: (config) => postJSON(API_CONFIG.endpoints.nodesScheduling, config),
deleteScheduling: (model) => fetchJSON(API_CONFIG.endpoints.nodesSchedulingModel(model), { method: 'DELETE' }),

View File

@@ -138,6 +138,7 @@ export const API_CONFIG = {
nodeModelsUnload: (id) => `/api/nodes/${id}/models/unload`,
nodeLabels: (id) => `/api/nodes/${id}/labels`,
nodeLabelKey: (id, key) => `/api/nodes/${id}/labels/${key}`,
nodeMaxReplicasPerModel: (id) => `/api/nodes/${id}/max-replicas-per-model`,
nodesScheduling: '/api/nodes/scheduling',
nodesSchedulingModel: (model) => `/api/nodes/scheduling/${encodeURIComponent(model)}`,
},

View File

@@ -95,6 +95,12 @@ func RegisterNodeAdminRoutes(e *echo.Echo, registry *nodes.NodeRegistry, unloade
admin.PATCH("/:id/labels", localai.MergeNodeLabelsEndpoint(registry))
admin.DELETE("/:id/labels/:key", localai.DeleteNodeLabelEndpoint(registry))
// Per-node replica capacity. PUT sets a sticky admin override that
// survives worker restarts. DELETE clears the override so the worker's
// CLI flag takes over again at the next re-registration.
admin.PUT("/:id/max-replicas-per-model", localai.UpdateMaxReplicasPerModelEndpoint(registry))
admin.DELETE("/:id/max-replicas-per-model", localai.ResetMaxReplicasPerModelEndpoint(registry))
// WebSocket proxy for real-time log streaming from workers
e.GET("/ws/nodes/:id/backend-logs/:modelId", localai.NodeBackendLogsWSEndpoint(registry, registrationToken), readyMw, adminMw)
}

View File

@@ -131,6 +131,12 @@ type BackendInstallRequest struct {
URI string `json:"uri,omitempty"`
Name string `json:"name,omitempty"`
Alias string `json:"alias,omitempty"`
// ReplicaIndex selects which slot on the worker this load occupies, so two
// concurrent backend.install requests for the same model land on distinct
// gRPC processes and ports. Workers older than this field treat it as 0
// (single-replica behavior — no collision because the controller never
// asks for replica > 0 on a node whose MaxReplicasPerModel is 1).
ReplicaIndex int32 `json:"replica_index,omitempty"`
}
// BackendInstallReply is the response from a backend.install NATS request.

View File

@@ -159,8 +159,8 @@ func (hm *HealthMonitor) doCheckAll(ctx context.Context) {
mCheckCtx, mCancel := context.WithTimeout(ctx, 5*time.Second)
if ok, _ := mClient.HealthCheck(mCheckCtx); !ok {
xlog.Warn("Model backend unhealthy, removing from registry",
"node", node.ID, "model", m.ModelName, "address", m.Address)
hm.registry.RemoveNodeModel(ctx, node.ID, m.ModelName)
"node", node.ID, "model", m.ModelName, "replica", m.ReplicaIndex, "address", m.Address)
hm.registry.RemoveNodeModel(ctx, node.ID, m.ModelName, m.ReplicaIndex)
}
mCancel()
if closer, ok := mClient.(io.Closer); ok {

View File

@@ -2,6 +2,7 @@ package nodes
import (
"context"
"fmt"
"sync"
"time"
@@ -122,8 +123,8 @@ func (f *fakeNodeHealthStore) FindStaleNodes(_ context.Context, _ time.Duration)
return nil, nil
}
func (f *fakeNodeHealthStore) RemoveNodeModel(_ context.Context, nodeID, modelName string) error {
f.record("RemoveNodeModel:" + nodeID + ":" + modelName)
func (f *fakeNodeHealthStore) RemoveNodeModel(_ context.Context, nodeID, modelName string, replicaIndex int) error {
f.record(fmt.Sprintf("RemoveNodeModel:%s:%s:%d", nodeID, modelName, replicaIndex))
return nil
}

View File

@@ -270,9 +270,9 @@ var _ = Describe("HealthMonitor (mock-based)", func() {
hm.doCheckAll(context.Background())
// Node should remain healthy — only the model record is removed
// Node should remain healthy — only the specific replica record is removed.
Expect(store.getNode("node-model").Status).To(Equal(StatusHealthy))
Expect(store.getCalls()).To(ContainElement("RemoveNodeModel:node-model:piper-model"))
Expect(store.getCalls()).To(ContainElement("RemoveNodeModel:node-model:piper-model:0"))
Expect(store.getCalls()).NotTo(ContainElement(ContainSubstring("MarkUnhealthy")))
})
})

View File

@@ -14,23 +14,29 @@ import (
// InFlightTrackingClient wraps a grpc.Backend and tracks active inference requests
// in the NodeRegistry. This allows the router's eviction logic to know which models
// are actively serving and should not be unloaded.
//
// Per-replica: a single tracker instance is bound to (nodeID, modelName, replicaIndex).
// The router constructs one tracker per Route() result, so each in-flight tick lands
// on the correct row even when multiple replicas of the same model live on the same node.
type InFlightTrackingClient struct {
grpc.Backend // embed for passthrough of untracked methods
registry InFlightTracker
nodeID string
modelName string
replicaIndex int
firstOnce sync.Once // guards onFirstComplete
onFirstComplete func() // called once after the first tracked inference call completes
firstOnce sync.Once // guards onFirstComplete
onFirstComplete func() // called once after the first tracked inference call completes
}
// NewInFlightTrackingClient wraps a gRPC backend client with in-flight tracking.
func NewInFlightTrackingClient(inner grpc.Backend, registry InFlightTracker, nodeID, modelName string) *InFlightTrackingClient {
func NewInFlightTrackingClient(inner grpc.Backend, registry InFlightTracker, nodeID, modelName string, replicaIndex int) *InFlightTrackingClient {
return &InFlightTrackingClient{
Backend: inner,
registry: registry,
nodeID: nodeID,
modelName: modelName,
Backend: inner,
registry: registry,
nodeID: nodeID,
modelName: modelName,
replicaIndex: replicaIndex,
}
}
@@ -43,14 +49,14 @@ func (c *InFlightTrackingClient) OnFirstComplete(fn func()) {
}
func (c *InFlightTrackingClient) track(ctx context.Context) func() {
if err := c.registry.IncrementInFlight(ctx, c.nodeID, c.modelName); err != nil {
xlog.Warn("Failed to increment in-flight counter", "node", c.nodeID, "model", c.modelName, "error", err)
if err := c.registry.IncrementInFlight(ctx, c.nodeID, c.modelName, c.replicaIndex); err != nil {
xlog.Warn("Failed to increment in-flight counter", "node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", err)
return func() {}
}
return func() {
decCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c.registry.DecrementInFlight(decCtx, c.nodeID, c.modelName)
c.registry.DecrementInFlight(decCtx, c.nodeID, c.modelName, c.replicaIndex)
// Release the initial reservation after the first inference call completes
if c.onFirstComplete != nil {
c.firstOnce.Do(c.onFirstComplete)

View File

@@ -22,14 +22,14 @@ type fakeInFlightTracker struct {
incrementErr error
}
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string) error {
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.increments++
return f.incrementErr
}
func (f *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string) error {
func (f *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.decrements++
@@ -218,7 +218,7 @@ var _ = Describe("InFlightTrackingClient", func() {
predictReply: &pb.Reply{Message: []byte("hello")},
streamReplies: []*pb.Reply{{Message: []byte("chunk")}},
}
client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama")
client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama", 0)
})
Describe("track", func() {

View File

@@ -10,13 +10,16 @@ import (
// ModelRouter is used by SmartRouter for routing decisions and model lifecycle.
type ModelRouter interface {
FindAndLockNodeWithModel(ctx context.Context, modelName string) (*BackendNode, *NodeModel, error)
DecrementInFlight(ctx context.Context, nodeID, modelName string) error
IncrementInFlight(ctx context.Context, nodeID, modelName string) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string) error
TouchNodeModel(ctx context.Context, nodeID, modelName string)
SetNodeModel(ctx context.Context, nodeID, modelName, state, address string, initialInFlight int) error
SetNodeModelLoadInfo(ctx context.Context, nodeID, modelName, backendType string, optsBlob []byte) error
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
RemoveAllNodeModelReplicas(ctx context.Context, nodeID, modelName string) error
TouchNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int)
SetNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int, state, address string, initialInFlight int) error
SetNodeModelLoadInfo(ctx context.Context, nodeID, modelName string, replicaIndex int, backendType string, optsBlob []byte) error
GetModelLoadInfo(ctx context.Context, modelName string) (backendType string, optsBlob []byte, err error)
NextFreeReplicaIndex(ctx context.Context, nodeID, modelName string, maxSlots int) (int, error)
CountReplicasOnNode(ctx context.Context, nodeID, modelName string) (int, error)
FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*BackendNode, error)
FindIdleNode(ctx context.Context) (*BackendNode, error)
FindLeastLoadedNode(ctx context.Context) (*BackendNode, error)
@@ -25,6 +28,9 @@ type ModelRouter interface {
Get(ctx context.Context, nodeID string) (*BackendNode, error)
GetModelScheduling(ctx context.Context, modelName string) (*ModelSchedulingConfig, error)
FindNodesBySelector(ctx context.Context, selector map[string]string) ([]BackendNode, error)
FindNodesWithFreeSlot(ctx context.Context, modelName string, candidateNodeIDs []string) ([]BackendNode, error)
ReserveVRAM(ctx context.Context, nodeID string, bytes uint64) error
ReleaseVRAM(ctx context.Context, nodeID string, bytes uint64) error
FindNodeWithVRAMFromSet(ctx context.Context, minBytes uint64, nodeIDs []string) (*BackendNode, error)
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
@@ -40,13 +46,14 @@ type NodeHealthStore interface {
MarkHealthy(ctx context.Context, nodeID string) error
Heartbeat(ctx context.Context, nodeID string, update *HeartbeatUpdate) error
FindStaleNodes(ctx context.Context, threshold time.Duration) ([]BackendNode, error)
RemoveNodeModel(ctx context.Context, nodeID, modelName string) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
}
// ModelLocator is used by RemoteUnloaderAdapter for model discovery.
type ModelLocator interface {
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
RemoveNodeModel(ctx context.Context, nodeID, modelName string) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
RemoveAllNodeModelReplicas(ctx context.Context, nodeID, modelName string) error
}
// ModelLookup is used by DistributedModelStore for model existence queries.
@@ -58,8 +65,8 @@ type ModelLookup interface {
// InFlightTracker is used by InFlightTrackingClient for request counting.
type InFlightTracker interface {
IncrementInFlight(ctx context.Context, nodeID, modelName string) error
DecrementInFlight(ctx context.Context, nodeID, modelName string) error
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
}
// NodeManager is used by HTTP endpoints for node registration and lifecycle.
@@ -76,7 +83,8 @@ type NodeManager interface {
Heartbeat(ctx context.Context, nodeID string, update *HeartbeatUpdate) error
GetNodeModels(ctx context.Context, nodeID string) ([]NodeModel, error)
UpdateAuthRefs(ctx context.Context, nodeID, authUserID, apiKeyID string) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string) error
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
RemoveAllNodeModelReplicas(ctx context.Context, nodeID, modelName string) error
}
// BackendClientFactory creates gRPC backend clients.

View File

@@ -328,7 +328,10 @@ func (d *DistributedBackendManager) InstallBackend(ctx context.Context, op *gall
backendName := op.GalleryElementName
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendInstall, backendName, galleriesJSON, func(node BackendNode) error {
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias)
// Admin-driven backend install: not tied to a specific replica slot.
// Pass replica 0 — the worker's processKey is "backend#0" when no
// modelID is supplied, matching pre-PR4 behavior.
reply, err := d.adapter.InstallBackend(node.ID, backendName, "", string(galleriesJSON), op.ExternalURI, op.ExternalName, op.ExternalAlias, 0)
if err != nil {
return err
}
@@ -349,7 +352,7 @@ func (d *DistributedBackendManager) UpgradeBackend(ctx context.Context, name str
galleriesJSON, _ := json.Marshal(d.backendGalleries)
result, err := d.enqueueAndDrainBackendOp(ctx, OpBackendUpgrade, name, galleriesJSON, func(node BackendNode) error {
reply, err := d.adapter.InstallBackend(node.ID, name, "", string(galleriesJSON), "", "", "")
reply, err := d.adapter.InstallBackend(node.ID, name, "", string(galleriesJSON), "", "", "", 0)
if err != nil {
return err
}

View File

@@ -33,29 +33,38 @@ func (f *fakeModelRouterForSmartRouter) FindAndLockNodeWithModel(_ context.Conte
return f.node, f.nodeModel, f.findErr
}
func (f *fakeModelRouterForSmartRouter) DecrementInFlight(_ context.Context, nodeID, modelName string) error {
func (f *fakeModelRouterForSmartRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.decrementCalled[nodeID+":"+modelName]++
return nil
}
func (f *fakeModelRouterForSmartRouter) IncrementInFlight(_ context.Context, _, _ string) error {
func (f *fakeModelRouterForSmartRouter) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) RemoveNodeModel(_ context.Context, _, _ string) error {
func (f *fakeModelRouterForSmartRouter) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) TouchNodeModel(_ context.Context, _, _ string) {}
func (f *fakeModelRouterForSmartRouter) SetNodeModel(_ context.Context, _, _, _, _ string, _ int) error {
func (f *fakeModelRouterForSmartRouter) RemoveAllNodeModelReplicas(_ context.Context, _, _ string) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) SetNodeModelLoadInfo(_ context.Context, _, _, _ string, _ []byte) error {
func (f *fakeModelRouterForSmartRouter) TouchNodeModel(_ context.Context, _, _ string, _ int) {}
func (f *fakeModelRouterForSmartRouter) SetNodeModel(_ context.Context, _, _ string, _ int, _, _ string, _ int) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) GetModelLoadInfo(_ context.Context, _ string) (string, []byte, error) {
return "", nil, fmt.Errorf("not found")
}
func (f *fakeModelRouterForSmartRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
return 0, nil
}
func (f *fakeModelRouterForSmartRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
return 0, nil
}
func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
return nil, nil
}
@@ -85,6 +94,15 @@ func (f *fakeModelRouterForSmartRouter) GetModelScheduling(_ context.Context, _
func (f *fakeModelRouterForSmartRouter) FindNodesBySelector(_ context.Context, _ map[string]string) ([]BackendNode, error) {
return nil, nil
}
func (f *fakeModelRouterForSmartRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
return nil, nil
}
func (f *fakeModelRouterForSmartRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouterForSmartRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
return nil, nil
}

View File

@@ -188,7 +188,9 @@ func (rc *ReplicaReconciler) drainPendingBackendOps(ctx context.Context) {
case OpBackendDelete:
_, applyErr = rc.adapter.DeleteBackend(op.NodeID, op.Backend)
case OpBackendInstall, OpBackendUpgrade:
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "")
// Pending-op drain for admin install/upgrade — not a per-replica
// load. Replica 0 is the conventional admin slot.
reply, err := rc.adapter.InstallBackend(op.NodeID, op.Backend, "", string(op.Galleries), "", "", "", 0)
if err != nil {
applyErr = err
} else if !reply.Success {
@@ -276,12 +278,12 @@ func (rc *ReplicaReconciler) probeLoadedModels(ctx context.Context) {
Where("id = ?", m.ID).Update("updated_at", time.Now()).Error
continue
}
if err := rc.registry.RemoveNodeModel(ctx, m.NodeID, m.ModelName); err != nil {
xlog.Warn("Reconciler: failed to remove unreachable model", "node", m.NodeID, "model", m.ModelName, "error", err)
if err := rc.registry.RemoveNodeModel(ctx, m.NodeID, m.ModelName, m.ReplicaIndex); err != nil {
xlog.Warn("Reconciler: failed to remove unreachable model", "node", m.NodeID, "model", m.ModelName, "replica", m.ReplicaIndex, "error", err)
continue
}
xlog.Warn("Reconciler: model unreachable, removed from registry",
"node", m.NodeID, "model", m.ModelName, "address", m.Address)
"node", m.NodeID, "model", m.ModelName, "replica", m.ReplicaIndex, "address", m.Address)
}
}
@@ -300,25 +302,112 @@ func (rc *ReplicaReconciler) reconcile(ctx context.Context) {
}
}
// unsatisfiableTickThreshold is how many consecutive ticks of "capacity == 0
// && need > 0" must elapse before the reconciler stops trying. Three ticks at
// the default 30s interval gives ~90s of grace before logging a warning and
// entering cooldown — enough to ride out a transient race between Register
// and the next tick, but short enough that a misconfig (MinReplicas above
// cluster capacity) doesn't churn the worker forever like it did pre-PR4.
const unsatisfiableTickThreshold = 3
// unsatisfiableCooldown is the duration the reconciler waits before retrying
// after the threshold trips. ClearAllUnsatisfiable on cluster events shortens
// this in practice — the cooldown is the worst-case, not the steady-state.
const unsatisfiableCooldown = 5 * time.Minute
// candidateNodeIDsForSelector resolves the model's NodeSelector to a slice
// of node IDs, or returns nil if no selector is configured (meaning "any
// healthy node" — registry helpers interpret nil as no candidate filter).
// Returns ok=false if a non-empty selector matched zero nodes, in which case
// the caller should skip — there's nothing to schedule on.
func (rc *ReplicaReconciler) candidateNodeIDsForSelector(ctx context.Context, cfg ModelSchedulingConfig) (ids []string, ok bool) {
if cfg.NodeSelector == "" {
return nil, true
}
sel := parseSelector(cfg.NodeSelector)
if len(sel) == 0 {
return nil, true
}
nodes, err := rc.registry.FindNodesBySelector(ctx, sel)
if err != nil || len(nodes) == 0 {
return nil, false
}
ids = make([]string, len(nodes))
for i, n := range nodes {
ids[i] = n.ID
}
return ids, true
}
func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedulingConfig) {
// Cooldown gate: if we previously decided this config is unsatisfiable,
// don't even bother checking until the cooldown expires. ClearAllUnsatisfiable
// (fired by node lifecycle events) bypasses this by zeroing the column.
if cfg.UnsatisfiableUntil != nil && cfg.UnsatisfiableUntil.After(time.Now()) {
return
}
current, err := rc.registry.CountLoadedReplicas(ctx, cfg.ModelName)
if err != nil {
xlog.Warn("Reconciler: failed to count replicas", "model", cfg.ModelName, "error", err)
return
}
// 1. Ensure minimum replicas
// 1. Ensure minimum replicas, but only up to what the cluster can host.
// Without this cap, a MinReplicas above cluster capacity would loop
// forever (the original flap: every tick "scaling up", but the registry
// never grows because all nodes are full).
if cfg.MinReplicas > 0 && int(current) < cfg.MinReplicas {
candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg)
if !selectorMatched {
xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName, "selector", cfg.NodeSelector)
rc.markCapacityProblem(ctx, cfg.ModelName, "no nodes match selector")
return
}
capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs)
if capErr != nil {
xlog.Warn("Reconciler: failed to compute cluster capacity", "model", cfg.ModelName, "error", capErr)
return
}
needed := cfg.MinReplicas - int(current)
if capacity == 0 {
// No capacity right now. Bump hysteresis; trip cooldown if it
// crosses the threshold. ClearAllUnsatisfiable resets this on
// any plausible capacity-changing event.
rc.markCapacityProblem(ctx, cfg.ModelName, "cluster capacity exhausted")
return
}
// Cap to actual capacity so we don't try harder than possible.
if needed > capacity {
xlog.Info("Reconciler: capping scale-up at cluster capacity", "model", cfg.ModelName,
"need", needed, "capacity", capacity)
needed = capacity
}
xlog.Info("Reconciler: scaling up to meet minimum", "model", cfg.ModelName,
"current", current, "min", cfg.MinReplicas, "adding", needed)
rc.scaleUp(ctx, cfg, needed)
// Successful (or partial) scale-up clears the hysteresis so a future
// dip starts fresh.
_ = rc.registry.ClearUnsatisfiable(ctx, cfg.ModelName)
return
}
// 2. Auto-scale up if all replicas are busy
if current > 0 && (cfg.MaxReplicas == 0 || int(current) < cfg.MaxReplicas) {
if rc.allReplicasBusy(ctx, cfg.ModelName) {
candidateNodeIDs, selectorMatched := rc.candidateNodeIDsForSelector(ctx, cfg)
if !selectorMatched {
return
}
capacity, capErr := rc.registry.ClusterCapacityForModel(ctx, cfg.ModelName, candidateNodeIDs)
if capErr != nil || capacity == 0 {
// All busy AND no slot available — burst load above capacity.
// Don't enter cooldown for this case (it's transient demand,
// not a misconfig); the next tick will retry naturally.
return
}
xlog.Info("Reconciler: all replicas busy, scaling up", "model", cfg.ModelName,
"current", current)
rc.scaleUp(ctx, cfg, 1)
@@ -335,29 +424,42 @@ func (rc *ReplicaReconciler) reconcileModel(ctx context.Context, cfg ModelSchedu
}
}
// scaleUp schedules additional replicas of the model.
// markCapacityProblem advances the hysteresis counter and sets the cooldown
// timestamp once it crosses the threshold. Centralized so the two scale-up
// paths (MinReplicas and busy-burst) report capacity exhaustion the same way.
func (rc *ReplicaReconciler) markCapacityProblem(ctx context.Context, modelName, reason string) {
ticks, err := rc.registry.BumpUnsatisfiableTicks(ctx, modelName)
if err != nil {
xlog.Warn("Reconciler: failed to bump unsatisfiable counter", "model", modelName, "error", err)
return
}
if ticks >= unsatisfiableTickThreshold {
until := time.Now().Add(unsatisfiableCooldown)
if err := rc.registry.MarkUnsatisfiable(ctx, modelName, until); err != nil {
xlog.Warn("Reconciler: failed to mark unsatisfiable", "model", modelName, "error", err)
return
}
xlog.Warn("Reconciler: scheduling unsatisfiable, entering cooldown",
"model", modelName, "reason", reason,
"cooldown", unsatisfiableCooldown, "retry_after", until.Format(time.RFC3339))
}
}
// scaleUp schedules additional replicas of the model. Callers in
// reconcileModel are expected to have already capped `count` against
// ClusterCapacityForModel so this function never tries to overshoot.
func (rc *ReplicaReconciler) scaleUp(ctx context.Context, cfg ModelSchedulingConfig, count int) {
if rc.scheduler == nil {
xlog.Warn("Reconciler: no scheduler available, cannot scale up")
return
}
// Determine candidate nodes from selector
var candidateNodeIDs []string
if cfg.NodeSelector != "" {
selector := parseSelector(cfg.NodeSelector)
if len(selector) > 0 {
candidates, err := rc.registry.FindNodesBySelector(ctx, selector)
if err != nil || len(candidates) == 0 {
xlog.Warn("Reconciler: no nodes match selector", "model", cfg.ModelName,
"selector", cfg.NodeSelector)
return
}
candidateNodeIDs = make([]string, len(candidates))
for i, n := range candidates {
candidateNodeIDs[i] = n.ID
}
}
// Resolve selector → candidate node IDs (nil when no selector → "any
// healthy node"). The selector mismatch case is handled upstream in
// reconcileModel, but defensively short-circuit here too.
candidateNodeIDs, ok := rc.candidateNodeIDsForSelector(ctx, cfg)
if !ok {
return
}
for i := 0; i < count; i++ {
@@ -377,13 +479,17 @@ func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedul
return
}
// Find idle replicas that have been unused for longer than scaleDownDelay
// Find idle replicas that have been unused for longer than scaleDownDelay.
// Order by replica_index DESC first, then last_used ASC: trim the
// highest-indexed replicas first so subsequent scale-ups can reuse the
// low indexes via NextFreeReplicaIndex, keeping slot allocation compact
// and matching the worker supervisor's port-recycling behavior.
cutoff := time.Now().Add(-rc.scaleDownDelay)
var idleModels []NodeModel
rc.registry.db.WithContext(ctx).
Where("model_name = ? AND state = ? AND in_flight = 0 AND last_used < ?",
cfg.ModelName, "loaded", cutoff).
Order("last_used ASC").
Order("replica_index DESC, last_used ASC").
Find(&idleModels)
toRemove := current - floor
@@ -392,16 +498,17 @@ func (rc *ReplicaReconciler) scaleDownIdle(ctx context.Context, cfg ModelSchedul
if removed >= toRemove {
break
}
// Remove from registry
if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName); err != nil {
xlog.Warn("Reconciler: failed to remove model record", "error", err)
// Remove this specific replica row from registry (sibling replicas of
// the same model on the same node, if any, are unaffected).
if err := rc.registry.RemoveNodeModel(ctx, nm.NodeID, nm.ModelName, nm.ReplicaIndex); err != nil {
xlog.Warn("Reconciler: failed to remove model record", "node", nm.NodeID, "model", nm.ModelName, "replica", nm.ReplicaIndex, "error", err)
continue
}
// Unload from worker
if err := rc.unloader.UnloadModelOnNode(nm.NodeID, nm.ModelName); err != nil {
xlog.Warn("Reconciler: unload failed (model already removed from registry)", "error", err)
}
xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID)
xlog.Info("Reconciler: scaled down idle replica", "model", cfg.ModelName, "node", nm.NodeID, "replica", nm.ReplicaIndex)
removed++
}
}

View File

@@ -48,12 +48,18 @@ var _ = Describe("ReplicaReconciler", func() {
Expect(err).ToNot(HaveOccurred())
})
// Helper to register a healthy node.
// Helper to register a healthy node with enough replica capacity for
// most tests. Pre-PR4 the reconciler ignored capacity, so existing
// fixtures didn't bother setting MaxReplicasPerModel — bumping the
// default here keeps the test intent ("scale up enough") working under
// the new capacity-aware logic. Tests that specifically exercise the
// circuit breaker should register nodes with a tighter cap.
registerNode := func(name, address string) *BackendNode {
node := &BackendNode{
Name: name,
NodeType: NodeTypeBackend,
Address: address,
Name: name,
NodeType: NodeTypeBackend,
Address: address,
MaxReplicasPerModel: 4,
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
return node
@@ -99,12 +105,12 @@ var _ = Describe("ReplicaReconciler", func() {
setSchedulingConfig("model-b", 1, 4, "")
// Load 2 replicas, both busy (in_flight > 0)
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", "loaded", "addr1", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-b", 0, "loaded", "addr1", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-b", 0)).To(Succeed())
node2 := registerNode("node-busy-2", "10.0.0.3:50051")
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", "loaded", "addr2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-b", 0, "loaded", "addr2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-b", 0)).To(Succeed())
scheduler := &fakeScheduler{
scheduleNode: node,
@@ -128,12 +134,12 @@ var _ = Describe("ReplicaReconciler", func() {
setSchedulingConfig("model-c", 1, 2, "")
// Load 2 replicas (at max), both busy
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", "loaded", "addr1", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-c", 0, "loaded", "addr1", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "model-c", 0)).To(Succeed())
node2 := registerNode("node-max-2", "10.0.0.5:50051")
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", "loaded", "addr2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-c", 0, "loaded", "addr2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "model-c", 0)).To(Succeed())
scheduler := &fakeScheduler{
scheduleNode: node,
@@ -160,7 +166,7 @@ var _ = Describe("ReplicaReconciler", func() {
// Load 3 replicas, all idle with last_used in the past
pastTime := time.Now().Add(-10 * time.Minute)
for _, n := range []*BackendNode{node1, node2, node3} {
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-d", 0, "loaded", "", 0)).To(Succeed())
// Set last_used to past time to trigger scale-down
db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-d").
Update("last_used", pastTime)
@@ -190,7 +196,7 @@ var _ = Describe("ReplicaReconciler", func() {
// Load exactly 2 replicas (at min), both idle with past last_used
pastTime := time.Now().Add(-10 * time.Minute)
for _, n := range []*BackendNode{node1, node2} {
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-e", 0, "loaded", "", 0)).To(Succeed())
db.Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", n.ID, "model-e").
Update("last_used", pastTime)
}
@@ -238,6 +244,185 @@ var _ = Describe("ReplicaReconciler", func() {
Expect(scheduler.scheduleCalls[0].candidateIDs).ToNot(ContainElement(node2.ID))
})
})
Describe("Capacity gating + circuit breaker (PR4)", func() {
// Helper: register a node with an explicit per-model replica cap.
// Tests in this Describe block want to exercise both "fits" and
// "doesn't fit" capacity scenarios precisely.
registerCappedNode := func(name, address string, cap int) *BackendNode {
node := &BackendNode{
Name: name,
NodeType: NodeTypeBackend,
Address: address,
MaxReplicasPerModel: cap,
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
return node
}
It("caps scale-up at cluster capacity instead of looping forever", func() {
// 1 node × 1 slot = capacity 1, but MinReplicas=2.
// Pre-PR4 this looped: every 30s "scaling up to meet minimum"
// because the registry never grew to 2. Post-PR4 the reconciler
// does the math up front and only schedules 1 (the achievable
// target), then flags unsatisfiable on the next ticks.
node := registerCappedNode("cap-1-slot", "10.0.0.40:50051", 1)
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "tight-model",
MinReplicas: 2,
})).To(Succeed())
scheduler := &fakeScheduler{scheduleNode: node}
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
Registry: registry,
Scheduler: scheduler,
DB: db,
})
reconciler.reconcile(context.Background())
Expect(scheduler.scheduleCalls).To(HaveLen(1),
"only 1 schedule call: capacity is 1, not the requested 2 — must not loop")
})
It("flags unsatisfiable after threshold consecutive ticks at capacity 0", func() {
// 1 node × 1 slot, already loaded. Capacity=0, but MinReplicas=2.
// Each tick increments UnsatisfiableTicks; once we cross the
// threshold the cooldown timestamp is set and further ticks
// short-circuit (the scheduler is no longer called).
node := registerCappedNode("cb-node", "10.0.0.41:50051", 1)
Expect(registry.SetNodeModel(context.Background(), node.ID, "cb-model", 0, "loaded", "addr1", 0)).To(Succeed())
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "cb-model",
MinReplicas: 2,
})).To(Succeed())
scheduler := &fakeScheduler{scheduleNode: node}
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
Registry: registry,
Scheduler: scheduler,
DB: db,
})
// Drive enough ticks to cross the threshold, plus a couple more
// to confirm the cooldown holds.
for i := 0; i < unsatisfiableTickThreshold+2; i++ {
reconciler.reconcile(context.Background())
}
cfg, err := registry.GetModelScheduling(context.Background(), "cb-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.UnsatisfiableUntil).ToNot(BeNil(),
"reconciler must flag the config after threshold ticks of capacity exhaustion")
Expect(cfg.UnsatisfiableUntil.After(time.Now())).To(BeTrue(),
"cooldown must point to the future")
// Capacity 0 + cooldown active means the scheduler shouldn't have
// been invoked at all — capacity was 0 from the first tick.
Expect(scheduler.scheduleCalls).To(BeEmpty(),
"capacity was always 0 — no schedule attempts should have been made")
})
It("clears unsatisfiable on a successful scale-up", func() {
// Pre-flag the config (simulate a prior unsatisfiable run), then
// register enough capacity and tick — the reconciler must clear
// the flag and proceed.
node := registerCappedNode("clear-node", "10.0.0.42:50051", 4)
until := time.Now().Add(-1 * time.Second) // already-expired cooldown
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "clear-model",
MinReplicas: 1,
UnsatisfiableTicks: 5,
UnsatisfiableUntil: &until,
})).To(Succeed())
scheduler := &fakeScheduler{scheduleNode: node}
reconciler := NewReplicaReconciler(ReplicaReconcilerOptions{
Registry: registry,
Scheduler: scheduler,
DB: db,
})
reconciler.reconcile(context.Background())
Expect(scheduler.scheduleCalls).To(HaveLen(1),
"expired cooldown should not block scheduling")
cfg, err := registry.GetModelScheduling(context.Background(), "clear-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.UnsatisfiableUntil).To(BeNil(), "successful scale-up must clear the cooldown")
Expect(cfg.UnsatisfiableTicks).To(Equal(0), "successful scale-up must reset the counter")
})
It("recovers when a new node joins (ClearAllUnsatisfiable on Register)", func() {
// One full node, then config flagged unsatisfiable. Adding a
// second node simulates the user's recovery question: capacity
// returns, cooldown clears, the next tick schedules.
node1 := registerCappedNode("rec-node-1", "10.0.0.43:50051", 1)
Expect(registry.SetNodeModel(context.Background(), node1.ID, "rec-model", 0, "loaded", "addr1", 0)).To(Succeed())
until := time.Now().Add(unsatisfiableCooldown)
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "rec-model",
MinReplicas: 2,
UnsatisfiableTicks: unsatisfiableTickThreshold,
UnsatisfiableUntil: &until,
})).To(Succeed())
// New node registers — this is the recovery event.
registerCappedNode("rec-node-2", "10.0.0.44:50051", 1)
cfg, err := registry.GetModelScheduling(context.Background(), "rec-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.UnsatisfiableUntil).To(BeNil(),
"Register must clear unsatisfiable flags so the reconciler retries")
Expect(cfg.UnsatisfiableTicks).To(Equal(0))
})
It("recovers when node labels change (ClearAllUnsatisfiable on label ops)", func() {
node := registerCappedNode("lbl-node", "10.0.0.45:50051", 1)
until := time.Now().Add(unsatisfiableCooldown)
Expect(registry.SetModelScheduling(context.Background(), &ModelSchedulingConfig{
ModelName: "lbl-model",
MinReplicas: 2,
UnsatisfiableTicks: unsatisfiableTickThreshold,
UnsatisfiableUntil: &until,
})).To(Succeed())
// Adding a label could change which models the node matches via
// a NodeSelector, so capacity for some config may have just
// changed. ClearAllUnsatisfiable lets the reconciler re-check.
Expect(registry.SetNodeLabel(context.Background(), node.ID, "tier", "fast")).To(Succeed())
cfg, err := registry.GetModelScheduling(context.Background(), "lbl-model")
Expect(err).ToNot(HaveOccurred())
Expect(cfg.UnsatisfiableUntil).To(BeNil())
})
})
Describe("ClusterCapacityForModel", func() {
It("sums (max_replicas_per_model - replicas[node, model]) over candidates", func() {
// Three nodes with caps 4, 2, 1. Loaded counts: 1, 0, 1 → free
// slots: 3, 2, 0 → total capacity 5.
a := &BackendNode{Name: "cap-a", NodeType: NodeTypeBackend, Address: "10.0.0.50:50051", MaxReplicasPerModel: 4}
b := &BackendNode{Name: "cap-b", NodeType: NodeTypeBackend, Address: "10.0.0.51:50051", MaxReplicasPerModel: 2}
c := &BackendNode{Name: "cap-c", NodeType: NodeTypeBackend, Address: "10.0.0.52:50051", MaxReplicasPerModel: 1}
Expect(registry.Register(context.Background(), a, true)).To(Succeed())
Expect(registry.Register(context.Background(), b, true)).To(Succeed())
Expect(registry.Register(context.Background(), c, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), a.ID, "cap-model", 0, "loaded", "x", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), c.ID, "cap-model", 0, "loaded", "y", 0)).To(Succeed())
cap, err := registry.ClusterCapacityForModel(context.Background(), "cap-model", nil)
Expect(err).ToNot(HaveOccurred())
Expect(cap).To(Equal(5))
// Restricting to {b, c}: b free=2, c free=0 → capacity 2.
cap, err = registry.ClusterCapacityForModel(context.Background(), "cap-model", []string{b.ID, c.ID})
Expect(err).ToNot(HaveOccurred())
Expect(cap).To(Equal(2))
})
})
})
// fakeProber lets tests control whether a model's gRPC address "responds".

View File

@@ -26,14 +26,30 @@ type BackendNode struct {
TokenHash string `gorm:"size:64" json:"-"` // SHA-256 of registration token
TotalVRAM uint64 `gorm:"column:total_vram" json:"total_vram"` // Total GPU VRAM in bytes
AvailableVRAM uint64 `gorm:"column:available_vram" json:"available_vram"` // Available GPU VRAM in bytes
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// ReservedVRAM is a soft, in-tick reservation deducted by the scheduler when
// it picks this node to load a model. Workers reset it back to 0 on each
// heartbeat (the worker is the source of truth for actual free VRAM); the
// reservation is only here to keep two scheduling decisions within the
// same heartbeat window from over-committing the same node.
ReservedVRAM uint64 `gorm:"column:reserved_vram;default:0" json:"reserved_vram"`
TotalRAM uint64 `gorm:"column:total_ram" json:"total_ram"` // Total system RAM in bytes (fallback when no GPU)
AvailableRAM uint64 `gorm:"column:available_ram" json:"available_ram"` // Available system RAM in bytes
GPUVendor string `gorm:"column:gpu_vendor;size:32" json:"gpu_vendor"` // nvidia, amd, intel, vulkan, unknown
// MaxReplicasPerModel caps how many replicas of any one model can run on
// this node concurrently. Default 1 preserves the historical "one
// (node, model)" assumption; set higher (via worker --max-replicas-per-model)
// to allow stacking replicas on a fat node.
MaxReplicasPerModel int `gorm:"column:max_replicas_per_model;default:1" json:"max_replicas_per_model"`
// MaxReplicasPerModelManuallySet flags the value above as a UI-set
// admin override. When true, the worker's CLI value is ignored on
// re-registration so the override survives worker restarts. Cleared
// by an explicit "reset to worker default" action.
MaxReplicasPerModelManuallySet bool `gorm:"column:max_replicas_per_model_manually_set;default:false" json:"max_replicas_per_model_manually_set"`
APIKeyID string `gorm:"size:36" json:"-"` // auto-provisioned API key ID (for cleanup)
AuthUserID string `gorm:"size:36" json:"-"` // auto-provisioned user ID (for cleanup)
LastHeartbeat time.Time `gorm:"column:last_heartbeat" json:"last_heartbeat"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
const (
@@ -47,23 +63,31 @@ const (
StatusUnhealthy = "unhealthy"
// Column names (must match gorm:"column:" tags on BackendNode)
ColAvailableVRAM = "available_vram"
ColTotalVRAM = "total_vram"
ColAvailableRAM = "available_ram"
ColGPUVendor = "gpu_vendor"
ColLastHeartbeat = "last_heartbeat"
ColAvailableVRAM = "available_vram"
ColTotalVRAM = "total_vram"
ColReservedVRAM = "reserved_vram"
ColAvailableRAM = "available_ram"
ColGPUVendor = "gpu_vendor"
ColLastHeartbeat = "last_heartbeat"
ColMaxReplicasPerModel = "max_replicas_per_model"
)
// NodeModel tracks which models are loaded on which nodes.
//
// Multiple replicas of the same model on the same node are allowed; each
// replica has its own ReplicaIndex (0..MaxReplicasPerModel-1), its own
// gRPC Address (each replica is a separate worker process on its own port),
// and its own InFlight counter.
type NodeModel struct {
ID string `gorm:"primaryKey;size:36" json:"id"`
NodeID string `gorm:"index;size:36" json:"node_id"`
ModelName string `gorm:"index;size:255" json:"model_name"`
Address string `gorm:"size:255" json:"address"` // gRPC address for this model's backend process
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
InFlight int `json:"in_flight"` // number of active requests
LastUsed time.Time `json:"last_used"`
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
ID string `gorm:"primaryKey;size:36" json:"id"`
NodeID string `gorm:"index;size:36" json:"node_id"`
ModelName string `gorm:"index;size:255" json:"model_name"`
ReplicaIndex int `gorm:"column:replica_index;default:0;index" json:"replica_index"`
Address string `gorm:"size:255" json:"address"` // gRPC address for this replica's backend process
State string `gorm:"size:32;default:idle" json:"state"` // loading, loaded, unloading, idle
InFlight int `json:"in_flight"` // number of active requests on this replica
LastUsed time.Time `json:"last_used"`
LoadingBy string `gorm:"size:36" json:"loading_by,omitempty"` // frontend ID that triggered loading
BackendType string `gorm:"size:128" json:"backend_type,omitempty"` // e.g. "llama-cpp"; used by reconciler to replicate loads
ModelOptsBlob []byte `gorm:"type:bytea" json:"-"` // serialized pb.ModelOptions for replica scale-ups
CreatedAt time.Time `json:"created_at"`
@@ -87,13 +111,23 @@ type NodeLabel struct {
//
// Auto-scaling is enabled when MinReplicas > 0 or MaxReplicas > 0.
type ModelSchedulingConfig struct {
ID string `gorm:"primaryKey;size:36" json:"id"`
ModelName string `gorm:"uniqueIndex;size:255" json:"model_name"`
NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
MinReplicas int `gorm:"default:0" json:"min_replicas"`
MaxReplicas int `gorm:"default:0" json:"max_replicas"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID string `gorm:"primaryKey;size:36" json:"id"`
ModelName string `gorm:"uniqueIndex;size:255" json:"model_name"`
NodeSelector string `gorm:"type:text" json:"node_selector,omitempty"` // JSON {"key":"value",...}
MinReplicas int `gorm:"default:0" json:"min_replicas"`
MaxReplicas int `gorm:"default:0" json:"max_replicas"`
// UnsatisfiableUntil is set by the reconciler when no candidate node has
// free capacity for this model; while in the future, the reconciler skips
// scale-up attempts for this model. Cleared on cluster events that could
// change capacity (new node registers, node approved, labels change,
// max-replicas-per-model changes) or when the cooldown expires.
UnsatisfiableUntil *time.Time `gorm:"column:unsatisfiable_until" json:"unsatisfiable_until,omitempty"`
// UnsatisfiableTicks is hysteresis: incremented each tick capacity==0,
// promoted to UnsatisfiableUntil once it crosses a small threshold to
// avoid one-tick flaps. Reset on any successful scale-up.
UnsatisfiableTicks int `gorm:"column:unsatisfiable_ticks;default:0" json:"unsatisfiable_ticks"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// NodeWithExtras extends BackendNode with computed fields for list views.
@@ -196,7 +230,19 @@ func (r *NodeRegistry) Register(ctx context.Context, node *BackendNode, autoAppr
// Node was never approved — keep pending
node.Status = StatusPending
}
if err := r.db.WithContext(ctx).Model(&existing).Updates(node).Error; err != nil {
// Preserve admin overrides from re-registration. Without this,
// every worker restart silently reverts the UI-set value back to
// the worker's CLI flag (default 1) — a footgun for operators who
// configure capacity from the UI without touching the worker flag.
updateDB := r.db.WithContext(ctx).Model(&existing)
if existing.MaxReplicasPerModelManuallySet {
updateDB = updateDB.Omit("max_replicas_per_model", "max_replicas_per_model_manually_set")
// Reflect the persisted value back so the caller sees what the
// scheduler will actually use.
node.MaxReplicasPerModel = existing.MaxReplicasPerModel
node.MaxReplicasPerModelManuallySet = true
}
if err := updateDB.Updates(node).Error; err != nil {
return fmt.Errorf("updating node %s: %w", node.Name, err)
}
// Preserve auth references from existing record.
@@ -231,6 +277,13 @@ func (r *NodeRegistry) Register(ctx context.Context, node *BackendNode, autoAppr
}
xlog.Info("Node registered", "name", node.Name, "address", node.Address, "status", node.Status)
// Cluster capacity may have changed: a new healthy node, a returning
// node, or one with different MaxReplicasPerModel. Wake any configs the
// reconciler put in cooldown — the next tick will re-flag if still
// unsatisfiable. Best-effort; logged but non-fatal.
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags on register", "error", err)
}
return nil
}
@@ -253,6 +306,11 @@ func (r *NodeRegistry) ApproveNode(ctx context.Context, nodeID string) error {
if result.RowsAffected == 0 {
return fmt.Errorf("node %s not found or not in pending status", nodeID)
}
// pending → healthy adds cluster capacity; clear any cooldown flags so
// the next reconciler tick can use the new node.
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags on approve", "error", err)
}
return nil
}
@@ -283,8 +341,11 @@ func (r *NodeRegistry) MarkOffline(ctx context.Context, nodeID string) error {
return nil
}
// FindNodeWithVRAM returns healthy nodes with at least minBytes available VRAM,
// ordered idle-first then least-loaded.
// FindNodeWithVRAM returns healthy nodes with at least minBytes effectively-
// available VRAM (available_vram - reserved_vram), ordered idle-first then
// least-loaded. The reserved_vram subtraction is the in-tick soft reservation
// that prevents two scheduling decisions in the same heartbeat window from
// over-committing the same node.
func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*BackendNode, error) {
db := r.db.WithContext(ctx)
@@ -297,19 +358,22 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
Group("node_id")
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
// Try idle nodes with enough effectively-free VRAM first, prefer the one
// with most free VRAM (after deducting the in-tick reservation).
var node BackendNode
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?)", StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
Order("available_vram DESC").
err := db.Where("status = ? AND node_type = ? AND (available_vram - reserved_vram) >= ? AND id NOT IN (?)",
StatusHealthy, NodeTypeBackend, minBytes, loadedModels).
Order("(available_vram - reserved_vram) DESC").
First(&node).Error
if err == nil {
return &node, nil
}
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
err = db.Where("status = ? AND node_type = ? AND available_vram >= ?", StatusHealthy, NodeTypeBackend, minBytes).
// Fall back to least-loaded nodes with enough effectively-free VRAM
err = db.Where("status = ? AND node_type = ? AND (available_vram - reserved_vram) >= ?",
StatusHealthy, NodeTypeBackend, minBytes).
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
Order("COALESCE(load.total_inflight, 0) ASC, (backend_nodes.available_vram - backend_nodes.reserved_vram) DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("no healthy nodes with %d bytes available VRAM: %w", minBytes, err)
@@ -317,6 +381,59 @@ func (r *NodeRegistry) FindNodeWithVRAM(ctx context.Context, minBytes uint64) (*
return &node, nil
}
// ErrInsufficientVRAM signals that ReserveVRAM could not deduct the requested
// amount because the node's effectively-free VRAM has dropped below it
// (raced with another scheduler tick or with a heartbeat reset).
var ErrInsufficientVRAM = errors.New("insufficient effectively-free VRAM on node")
// ReserveVRAM atomically deducts `bytes` from the node's effectively-free
// VRAM (available_vram - reserved_vram). The UPDATE's WHERE clause does the
// admission check inside the database so two concurrent scheduling ticks
// can't both succeed when only one fits — whichever lands first reserves
// the slot, the other gets ErrInsufficientVRAM and falls through to the
// next candidate node.
//
// `bytes` may be 0 (e.g. when the model size estimator declines), in which
// case ReserveVRAM is a no-op — leaving accounting alone is preferable to
// reserving 0 (which would still bump no rows but is conceptually wrong).
//
// Worker heartbeats reset reserved_vram to 0 because the worker is the
// authoritative source for actual free VRAM. This is what makes the
// "soft" in soft-reservation: it's only honored within one heartbeat
// window; longer-term accounting comes from the worker's own readings.
func (r *NodeRegistry) ReserveVRAM(ctx context.Context, nodeID string, bytes uint64) error {
if bytes == 0 {
return nil
}
res := r.db.WithContext(ctx).Model(&BackendNode{}).
Where("id = ? AND (available_vram - reserved_vram) >= ?", nodeID, bytes).
UpdateColumn(ColReservedVRAM, gorm.Expr("reserved_vram + ?", bytes))
if res.Error != nil {
return fmt.Errorf("reserving %d bytes on node %s: %w", bytes, nodeID, res.Error)
}
if res.RowsAffected == 0 {
return ErrInsufficientVRAM
}
return nil
}
// ReleaseVRAM returns previously-reserved bytes to the pool. Called from the
// scheduler's deferred rollback path when LoadModel fails after a successful
// reservation, so the failed in-flight reservation doesn't linger until the
// next heartbeat.
//
// Guarded by `reserved_vram >= bytes` so a duplicate Release can't underflow
// past zero (the column is uint64 — wrap-around would be catastrophic for
// scheduler decisions).
func (r *NodeRegistry) ReleaseVRAM(ctx context.Context, nodeID string, bytes uint64) error {
if bytes == 0 {
return nil
}
return r.db.WithContext(ctx).Model(&BackendNode{}).
Where("id = ? AND reserved_vram >= ?", nodeID, bytes).
UpdateColumn(ColReservedVRAM, gorm.Expr("reserved_vram - ?", bytes)).Error
}
// Deregister removes a backend node, its model associations, and any auto-provisioned auth credentials.
func (r *NodeRegistry) Deregister(ctx context.Context, nodeID string) error {
db := r.db.WithContext(ctx)
@@ -365,6 +482,10 @@ func (r *NodeRegistry) Heartbeat(ctx context.Context, nodeID string, update *Hea
if update != nil {
if update.AvailableVRAM != nil {
updates[ColAvailableVRAM] = *update.AvailableVRAM
// The worker is the source of truth for actual free VRAM.
// Whenever it sends us a fresh reading, the in-tick soft
// reservation is no longer needed — clear it. (See ReserveVRAM.)
updates[ColReservedVRAM] = uint64(0)
}
if update.TotalVRAM != nil {
updates[ColTotalVRAM] = *update.TotalVRAM
@@ -455,16 +576,18 @@ func (r *NodeRegistry) FindStaleNodes(ctx context.Context, threshold time.Durati
// --- NodeModel operations ---
// SetNodeModel records that a model is loaded on a node.
func (r *NodeRegistry) SetNodeModel(ctx context.Context, nodeID, modelName, state, address string, initialInFlight int) error {
// SetNodeModel records that a replica of a model is loaded on a node.
// replicaIndex identifies which slot on the node this replica occupies
// (0..MaxReplicasPerModel-1). Pass 0 for single-replica scheduling.
func (r *NodeRegistry) SetNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int, state, address string, initialInFlight int) error {
now := time.Now()
// Use Attrs for creation-only fields (ID) and Assign for update-only fields.
// Attrs is applied only when creating a new record. Assign is applied on
// both create and update. This prevents overwriting the primary key on
// subsequent calls for the same node+model.
// subsequent calls for the same (node, model, replica_index).
var nm NodeModel
result := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
Attrs(NodeModel{ID: uuid.New().String(), NodeID: nodeID, ModelName: modelName}).
result := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
Attrs(NodeModel{ID: uuid.New().String(), NodeID: nodeID, ModelName: modelName, ReplicaIndex: replicaIndex}).
Assign(map[string]any{"address": address, "state": state, "last_used": now, "in_flight": initialInFlight}).
FirstOrCreate(&nm)
return result.Error
@@ -473,9 +596,9 @@ func (r *NodeRegistry) SetNodeModel(ctx context.Context, nodeID, modelName, stat
// SetNodeModelLoadInfo stores the backend type and serialized model options on
// an existing NodeModel record. This metadata is used by the reconciler to
// replicate model loads during scale-up.
func (r *NodeRegistry) SetNodeModelLoadInfo(ctx context.Context, nodeID, modelName, backendType string, optsBlob []byte) error {
func (r *NodeRegistry) SetNodeModelLoadInfo(ctx context.Context, nodeID, modelName string, replicaIndex int, backendType string, optsBlob []byte) error {
return r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ?", nodeID, modelName).
Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
Updates(map[string]any{"backend_type": backendType, "model_opts_blob": optsBlob}).Error
}
@@ -493,8 +616,21 @@ func (r *NodeRegistry) GetModelLoadInfo(ctx context.Context, modelName string) (
return nm.BackendType, nm.ModelOptsBlob, nil
}
// RemoveNodeModel removes a model association from a node.
func (r *NodeRegistry) RemoveNodeModel(ctx context.Context, nodeID, modelName string) error {
// RemoveNodeModel removes a single replica of a model from a node.
// replicaIndex must match the row to delete; passing 0 for single-replica
// scheduling preserves historical behavior. Removing siblings requires
// separate calls per index — there is no "remove all replicas" shortcut here
// to keep the contract explicit (probeLoadedModels and scaleDownIdle iterate
// per-row and must not orphan healthy siblings).
func (r *NodeRegistry) RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error {
return r.db.WithContext(ctx).Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
Delete(&NodeModel{}).Error
}
// RemoveAllNodeModelReplicas removes every replica of modelName on nodeID.
// Used by callers (e.g. node deregistration, full backend stop) that genuinely
// want to clear all replicas, not just one.
func (r *NodeRegistry) RemoveAllNodeModelReplicas(ctx context.Context, nodeID, modelName string) error {
return r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).
Delete(&NodeModel{}).Error
}
@@ -552,22 +688,72 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
return &node, &nm, nil
}
// TouchNodeModel updates the last_used timestamp for LRU tracking.
func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string) {
r.db.WithContext(ctx).Model(&NodeModel{}).Where("node_id = ? AND model_name = ?", nodeID, modelName).
// TouchNodeModel updates the last_used timestamp for LRU tracking on a single
// replica row.
func (r *NodeRegistry) TouchNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) {
r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
Update("last_used", time.Now())
}
// GetNodeModel returns the NodeModel record for a specific node+model combination.
func (r *NodeRegistry) GetNodeModel(ctx context.Context, nodeID, modelName string) (*NodeModel, error) {
// GetNodeModel returns the NodeModel record for a specific (node, model, replica_index) combination.
func (r *NodeRegistry) GetNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) (*NodeModel, error) {
var nm NodeModel
err := r.db.WithContext(ctx).Where("node_id = ? AND model_name = ?", nodeID, modelName).First(&nm).Error
err := r.db.WithContext(ctx).
Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
First(&nm).Error
if err != nil {
return nil, err
}
return &nm, nil
}
// CountReplicasOnNode returns how many replicas of modelName are currently
// recorded for nodeID (across all states). Used by NextFreeReplicaIndex and
// by capacity checks.
func (r *NodeRegistry) CountReplicasOnNode(ctx context.Context, nodeID, modelName string) (int, error) {
var count int64
if err := r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ?", nodeID, modelName).
Count(&count).Error; err != nil {
return 0, err
}
return int(count), nil
}
// ErrNoFreeSlot is returned by NextFreeReplicaIndex when the node already has
// MaxReplicasPerModel replicas of this model and cannot host another.
var ErrNoFreeSlot = errors.New("no free replica slot on node")
// NextFreeReplicaIndex returns the lowest replica_index in [0, maxSlots) that
// is not currently occupied by a row for (nodeID, modelName). Returns
// ErrNoFreeSlot if every index is taken.
//
// Allocating the lowest free index (rather than always appending) keeps slot
// numbers compact across scale-down/scale-up cycles, which matches the worker
// supervisor's port-recycling behavior in core/cli/worker.go (freePorts).
func (r *NodeRegistry) NextFreeReplicaIndex(ctx context.Context, nodeID, modelName string, maxSlots int) (int, error) {
if maxSlots <= 0 {
return 0, ErrNoFreeSlot
}
var taken []int
if err := r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ?", nodeID, modelName).
Pluck("replica_index", &taken).Error; err != nil {
return 0, err
}
occupied := make(map[int]struct{}, len(taken))
for _, idx := range taken {
occupied[idx] = struct{}{}
}
for idx := 0; idx < maxSlots; idx++ {
if _, ok := occupied[idx]; !ok {
return idx, nil
}
}
return 0, ErrNoFreeSlot
}
// FindLeastLoadedNode returns the healthy node with the fewest in-flight requests.
func (r *NodeRegistry) FindLeastLoadedNode(ctx context.Context) (*BackendNode, error) {
db := r.db.WithContext(ctx)
@@ -607,10 +793,10 @@ func (r *NodeRegistry) FindIdleNode(ctx context.Context) (*BackendNode, error) {
return &node, nil
}
// IncrementInFlight atomically increments the in-flight counter for a model on a node.
func (r *NodeRegistry) IncrementInFlight(ctx context.Context, nodeID, modelName string) error {
// IncrementInFlight atomically increments the in-flight counter on a single replica row.
func (r *NodeRegistry) IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error {
result := r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ?", nodeID, modelName).
Where("node_id = ? AND model_name = ? AND replica_index = ?", nodeID, modelName, replicaIndex).
Updates(map[string]any{
"in_flight": gorm.Expr("in_flight + 1"),
"last_used": time.Now(),
@@ -619,21 +805,22 @@ func (r *NodeRegistry) IncrementInFlight(ctx context.Context, nodeID, modelName
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("node model %s/%s not found", nodeID, modelName)
return fmt.Errorf("node model %s/%s replica %d not found", nodeID, modelName, replicaIndex)
}
return nil
}
// DecrementInFlight atomically decrements the in-flight counter for a model on a node.
func (r *NodeRegistry) DecrementInFlight(ctx context.Context, nodeID, modelName string) error {
// DecrementInFlight atomically decrements the in-flight counter on a single replica row.
// Guarded by `in_flight > 0` so that double-decrements don't go negative.
func (r *NodeRegistry) DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error {
result := r.db.WithContext(ctx).Model(&NodeModel{}).
Where("node_id = ? AND model_name = ? AND in_flight > 0", nodeID, modelName).
Where("node_id = ? AND model_name = ? AND replica_index = ? AND in_flight > 0", nodeID, modelName, replicaIndex).
UpdateColumn("in_flight", gorm.Expr("in_flight - 1"))
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
xlog.Warn("DecrementInFlight: no matching row or already zero", "node", nodeID, "model", modelName)
xlog.Warn("DecrementInFlight: no matching row or already zero", "node", nodeID, "model", modelName, "replica", replicaIndex)
}
return nil
}
@@ -700,6 +887,10 @@ func (r *NodeRegistry) FindGlobalLRUModelWithZeroInFlight(ctx context.Context) (
// --- NodeLabel operations ---
// SetNodeLabel upserts a single label on a node.
//
// A label change can change which models match a NodeSelector, so any
// scheduling cooldown flag is cleared as a side effect — the next reconciler
// tick will re-flag if the new label set still doesn't satisfy capacity.
func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value string) error {
label := NodeLabel{
ID: uuid.New().String(),
@@ -707,17 +898,23 @@ func (r *NodeRegistry) SetNodeLabel(ctx context.Context, nodeID, key, value stri
Key: key,
Value: value,
}
return r.db.WithContext(ctx).
if err := r.db.WithContext(ctx).
Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "node_id"}, {Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).
Create(&label).Error
Create(&label).Error; err != nil {
return err
}
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags on SetNodeLabel", "error", err)
}
return nil
}
// SetNodeLabels replaces all labels for a node with the given map.
func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("node_id = ?", nodeID).Delete(&NodeLabel{}).Error; err != nil {
return err
}
@@ -728,12 +925,24 @@ func (r *NodeRegistry) SetNodeLabels(ctx context.Context, nodeID string, labels
}
}
return nil
})
}); err != nil {
return err
}
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags on SetNodeLabels", "error", err)
}
return nil
}
// RemoveNodeLabel removes a single label from a node.
func (r *NodeRegistry) RemoveNodeLabel(ctx context.Context, nodeID, key string) error {
return r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error
if err := r.db.WithContext(ctx).Where("node_id = ? AND key = ?", nodeID, key).Delete(&NodeLabel{}).Error; err != nil {
return err
}
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags on RemoveNodeLabel", "error", err)
}
return nil
}
// GetNodeLabels returns all labels for a node.
@@ -793,19 +1002,21 @@ func (r *NodeRegistry) FindNodeWithVRAMFromSet(ctx context.Context, minBytes uin
Select("node_id, COALESCE(SUM(in_flight), 0) as total_inflight").
Group("node_id")
// Try idle nodes with enough VRAM first, prefer the one with most free VRAM
// Try idle nodes with enough effectively-free VRAM first.
var node BackendNode
err := db.Where("status = ? AND node_type = ? AND available_vram >= ? AND id NOT IN (?) AND id IN ?", StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
Order("available_vram DESC").
err := db.Where("status = ? AND node_type = ? AND (available_vram - reserved_vram) >= ? AND id NOT IN (?) AND id IN ?",
StatusHealthy, NodeTypeBackend, minBytes, loadedModels, nodeIDs).
Order("(available_vram - reserved_vram) DESC").
First(&node).Error
if err == nil {
return &node, nil
}
// Fall back to least-loaded nodes with enough VRAM, prefer most free VRAM as tiebreaker
err = db.Where("status = ? AND node_type = ? AND available_vram >= ? AND backend_nodes.id IN ?", StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
// Fall back to least-loaded nodes with enough effectively-free VRAM
err = db.Where("status = ? AND node_type = ? AND (available_vram - reserved_vram) >= ? AND backend_nodes.id IN ?",
StatusHealthy, NodeTypeBackend, minBytes, nodeIDs).
Joins("LEFT JOIN (?) AS load ON load.node_id = backend_nodes.id", subquery).
Order("COALESCE(load.total_inflight, 0) ASC, backend_nodes.available_vram DESC").
Order("COALESCE(load.total_inflight, 0) ASC, (backend_nodes.available_vram - backend_nodes.reserved_vram) DESC").
First(&node).Error
if err != nil {
return nil, fmt.Errorf("no healthy nodes in set with %d bytes available VRAM: %w", minBytes, err)
@@ -905,6 +1116,184 @@ func (r *NodeRegistry) CountLoadedReplicas(ctx context.Context, modelName string
return count, err
}
// FindNodesWithFreeSlot returns healthy backend nodes that have at least one
// free replica slot for modelName (i.e. count(node_models.*) for this model
// is strictly less than the node's MaxReplicasPerModel cap). When
// candidateNodeIDs is non-empty, only those nodes are considered.
//
// This is the candidate-pool used by SmartRouter.scheduleNewModel — without
// it, the scheduler would happily pick the same node for replica #2 even
// when that node already hosts replica #1, re-creating the original flap.
func (r *NodeRegistry) FindNodesWithFreeSlot(ctx context.Context, modelName string, candidateNodeIDs []string) ([]BackendNode, error) {
q := r.db.WithContext(ctx).Model(&BackendNode{}).
Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
if len(candidateNodeIDs) > 0 {
q = q.Where("id IN ?", candidateNodeIDs)
}
// Subquery: per-node count of loaded+loading replicas of this model.
// We count any non-removed row (state != deleted) so a load in progress
// counts against the cap and a second concurrent scale-up can't overshoot.
subq := r.db.Model(&NodeModel{}).
Select("node_id, COUNT(*) as cnt").
Where("model_name = ?", modelName).
Group("node_id")
var out []BackendNode
err := q.Joins("LEFT JOIN (?) AS rc ON rc.node_id = backend_nodes.id", subq).
Where("COALESCE(rc.cnt, 0) < backend_nodes.max_replicas_per_model").
Find(&out).Error
if err != nil {
return nil, fmt.Errorf("finding nodes with free slot for %s: %w", modelName, err)
}
return out, nil
}
// ClusterCapacityForModel returns the total free replica capacity for
// modelName across the candidate node set: Σ (max_replicas_per_model
// current_replicas[n,m]). When candidateNodeIDs is empty all healthy backend
// nodes are considered.
//
// The reconciler uses this to bound MinReplicas at what the cluster can
// actually host, preventing the "scale-up forever" loop from #9XXX where a
// MinReplicas=2 with one worker × one slot churned the model every 30s.
func (r *NodeRegistry) ClusterCapacityForModel(ctx context.Context, modelName string, candidateNodeIDs []string) (int, error) {
q := r.db.WithContext(ctx).Model(&BackendNode{}).
Where("status = ? AND node_type = ?", StatusHealthy, NodeTypeBackend)
if len(candidateNodeIDs) > 0 {
q = q.Where("id IN ?", candidateNodeIDs)
}
subq := r.db.Model(&NodeModel{}).
Select("node_id, COUNT(*) as cnt").
Where("model_name = ?", modelName).
Group("node_id")
var nodes []struct {
MaxReplicasPerModel int
Loaded int
}
err := q.Select("backend_nodes.max_replicas_per_model AS max_replicas_per_model, COALESCE(rc.cnt, 0) AS loaded").
Joins("LEFT JOIN (?) AS rc ON rc.node_id = backend_nodes.id", subq).
Scan(&nodes).Error
if err != nil {
return 0, fmt.Errorf("computing cluster capacity for %s: %w", modelName, err)
}
total := 0
for _, n := range nodes {
free := n.MaxReplicasPerModel - n.Loaded
if free > 0 {
total += free
}
}
return total, nil
}
// BumpUnsatisfiableTicks increments the per-config hysteresis counter when
// the reconciler tries to scale up but cluster capacity is exhausted.
// Returns the new value.
func (r *NodeRegistry) BumpUnsatisfiableTicks(ctx context.Context, modelName string) (int, error) {
res := r.db.WithContext(ctx).Model(&ModelSchedulingConfig{}).
Where("model_name = ?", modelName).
UpdateColumn("unsatisfiable_ticks", gorm.Expr("unsatisfiable_ticks + 1"))
if res.Error != nil {
return 0, res.Error
}
var cfg ModelSchedulingConfig
if err := r.db.WithContext(ctx).Where("model_name = ?", modelName).First(&cfg).Error; err != nil {
return 0, err
}
return cfg.UnsatisfiableTicks, nil
}
// MarkUnsatisfiable sets UnsatisfiableUntil to a future time, so the
// reconciler skips scale-up attempts for this model until the cooldown
// expires (or a cluster event clears the flag — see ClearAllUnsatisfiable).
func (r *NodeRegistry) MarkUnsatisfiable(ctx context.Context, modelName string, until time.Time) error {
return r.db.WithContext(ctx).Model(&ModelSchedulingConfig{}).
Where("model_name = ?", modelName).
Update("unsatisfiable_until", until).Error
}
// ClearUnsatisfiable resets both the cooldown timestamp and the hysteresis
// counter for a single model. Called on a successful scale-up so the next
// transient capacity dip starts the hysteresis from zero.
func (r *NodeRegistry) ClearUnsatisfiable(ctx context.Context, modelName string) error {
return r.db.WithContext(ctx).Model(&ModelSchedulingConfig{}).
Where("model_name = ?", modelName).
Updates(map[string]any{
"unsatisfiable_until": gorm.Expr("NULL"),
"unsatisfiable_ticks": 0,
}).Error
}
// UpdateMaxReplicasPerModel sets a node's per-model replica cap as an admin
// override (sticky across worker restarts) and refreshes the mirrored
// `node.replica-slots` auto-label so selectors reflect the new value.
// Capacity may have just changed, so cooldown flags are cleared too — the
// next reconciler tick will re-flag if still unsatisfiable.
//
// The override is preserved on worker re-registration (see Register). To
// hand control back to the worker flag, call ResetMaxReplicasPerModel.
func (r *NodeRegistry) UpdateMaxReplicasPerModel(ctx context.Context, nodeID string, n int) error {
if n < 1 {
return fmt.Errorf("max_replicas_per_model must be >= 1, got %d", n)
}
res := r.db.WithContext(ctx).Model(&BackendNode{}).
Where("id = ?", nodeID).
Updates(map[string]any{
ColMaxReplicasPerModel: n,
"max_replicas_per_model_manually_set": true,
})
if res.Error != nil {
return fmt.Errorf("updating max_replicas_per_model on %s: %w", nodeID, res.Error)
}
if res.RowsAffected == 0 {
return fmt.Errorf("node %s not found", nodeID)
}
// Keep the auto-label in sync so existing AND-selectors keep matching.
if err := r.SetNodeLabel(ctx, nodeID, "node.replica-slots", fmt.Sprintf("%d", n)); err != nil {
xlog.Warn("Failed to refresh node.replica-slots label", "node", nodeID, "error", err)
}
if err := r.ClearAllUnsatisfiable(ctx); err != nil {
xlog.Warn("Failed to clear unsatisfiable scheduling flags after capacity update", "error", err)
}
return nil
}
// ResetMaxReplicasPerModel clears the admin override flag so the next worker
// re-registration is allowed to update the value again. The current value is
// left in place — the worker will overwrite it on its next register call.
//
// This is the "Reset to worker default" affordance in the UI: it doesn't
// require knowing what the worker flag is set to (the worker tells us on
// re-register), it just hands ownership back.
func (r *NodeRegistry) ResetMaxReplicasPerModel(ctx context.Context, nodeID string) error {
res := r.db.WithContext(ctx).Model(&BackendNode{}).
Where("id = ?", nodeID).
Update("max_replicas_per_model_manually_set", false)
if res.Error != nil {
return fmt.Errorf("clearing max_replicas_per_model override on %s: %w", nodeID, res.Error)
}
if res.RowsAffected == 0 {
return fmt.Errorf("node %s not found", nodeID)
}
return nil
}
// ClearAllUnsatisfiable clears the cooldown flag on every scheduling config.
// Called from cluster-events that could plausibly increase capacity (new
// node registers, node approves pending→healthy, node labels change,
// MaxReplicasPerModel changes). The reconciler's own loop will re-flag any
// config whose target is still unsatisfiable, so over-clearing is cheap and
// correct.
func (r *NodeRegistry) ClearAllUnsatisfiable(ctx context.Context) error {
return r.db.WithContext(ctx).Model(&ModelSchedulingConfig{}).
Where("unsatisfiable_until IS NOT NULL OR unsatisfiable_ticks > 0").
Updates(map[string]any{
"unsatisfiable_until": gorm.Expr("NULL"),
"unsatisfiable_ticks": 0,
}).Error
}
// --- Composite queries ---
// ListWithExtras returns all nodes with model counts and labels.
@@ -999,6 +1388,15 @@ func (r *NodeRegistry) ApplyAutoLabels(ctx context.Context, nodeID string, node
if node.Name != "" {
_ = r.SetNodeLabel(ctx, nodeID, "node.name", node.Name)
}
// Mirror the typed MaxReplicasPerModel field as a label so the existing
// AND-selector machinery in ModelSchedulingConfig can target high-capacity
// nodes (e.g. {"node.replica-slots": "4"}). Always set it (default 1) so
// selectors don't have to special-case missing labels.
slots := node.MaxReplicasPerModel
if slots < 1 {
slots = 1
}
_ = r.SetNodeLabel(ctx, nodeID, "node.replica-slots", fmt.Sprintf("%d", slots))
}
// UpsertPendingBackendOp records or refreshes a pending backend operation for

View File

@@ -125,7 +125,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Load a model on the node
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama-7b", "loaded", "10.0.0.7:50052", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama-7b", 0, "loaded", "10.0.0.7:50052", 0)).To(Succeed())
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(1))
@@ -155,13 +155,13 @@ var _ = Describe("NodeRegistry", func() {
node := makeNode("stable-id-node", "10.0.0.99:50051", 8_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", "loaded", "10.0.0.99:50052", 0)).To(Succeed())
nm1, err := registry.GetNodeModel(context.Background(), node.ID, "my-model")
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50052", 0)).To(Succeed())
nm1, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
Expect(err).ToNot(HaveOccurred())
// Call again with different state/address
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", "loaded", "10.0.0.99:50053", 0)).To(Succeed())
nm2, err := registry.GetNodeModel(context.Background(), node.ID, "my-model")
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.99:50053", 0)).To(Succeed())
nm2, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm2.ID).To(Equal(nm1.ID), "ID should remain stable across SetNodeModel calls")
@@ -199,7 +199,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), idle, true)).To(Succeed())
// Load a model on the busy node
Expect(registry.SetNodeModel(context.Background(), busy.ID, "model-a", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), busy.ID, "model-a", 0, "loaded", "", 0)).To(Succeed())
found, err := registry.FindIdleNode(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -209,7 +209,7 @@ var _ = Describe("NodeRegistry", func() {
It("returns error when all nodes have models loaded", func() {
n := makeNode("all-busy", "10.0.0.22:50051", 8_000_000_000)
Expect(registry.Register(context.Background(), n, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-x", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n.ID, "model-x", 0, "loaded", "", 0)).To(Succeed())
_, err := registry.FindIdleNode(context.Background())
Expect(err).To(HaveOccurred())
@@ -224,13 +224,13 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), light, true)).To(Succeed())
// Set up models with different in-flight counts
Expect(registry.SetNodeModel(context.Background(), heavy.ID, "model-a", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), heavy.ID, "model-a", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), heavy.ID, "model-a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), light.ID, "model-b", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), light.ID, "model-b")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), light.ID, "model-b", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), light.ID, "model-b", 0)).To(Succeed())
found, err := registry.FindLeastLoadedNode(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -242,7 +242,7 @@ var _ = Describe("NodeRegistry", func() {
It("returns the correct node and increments in-flight", func() {
node := makeNode("lock-node", "10.0.0.40:50051", 8_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", "loaded", "10.0.0.40:50052", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "my-model", 0, "loaded", "10.0.0.40:50052", 0)).To(Succeed())
foundNode, foundNM, err := registry.FindAndLockNodeWithModel(context.Background(), "my-model")
Expect(err).ToNot(HaveOccurred())
@@ -250,7 +250,7 @@ var _ = Describe("NodeRegistry", func() {
Expect(foundNM.ModelName).To(Equal("my-model"))
// Verify in-flight was incremented
nm, err := registry.GetNodeModel(context.Background(), node.ID, "my-model")
nm, err := registry.GetNodeModel(context.Background(), node.ID, "my-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(1))
})
@@ -266,12 +266,12 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "shared-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "shared-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed())
// Add in-flight to n1
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), n1.ID, "shared-model", 0)).To(Succeed())
foundNode, _, err := registry.FindAndLockNodeWithModel(context.Background(), "shared-model")
Expect(err).ToNot(HaveOccurred())
@@ -528,8 +528,8 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "counted-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "counted-model", 0, "loaded", "", 0)).To(Succeed())
count, err := registry.CountLoadedReplicas(context.Background(), "counted-model")
Expect(err).ToNot(HaveOccurred())
@@ -542,8 +542,8 @@ var _ = Describe("NodeRegistry", func() {
Expect(registry.Register(context.Background(), n1, true)).To(Succeed())
Expect(registry.Register(context.Background(), n2, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", "loading", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n1.ID, "state-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), n2.ID, "state-model", 0, "loading", "", 0)).To(Succeed())
count, err := registry.CountLoadedReplicas(context.Background(), "state-model")
Expect(err).ToNot(HaveOccurred())
@@ -555,12 +555,12 @@ var _ = Describe("NodeRegistry", func() {
It("does not go below zero", func() {
node := makeNode("dec-node", "10.0.0.50:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model", 0, "loaded", "", 0)).To(Succeed())
// in_flight starts at 0 — decrement should be a no-op
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model")).To(Succeed())
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model")
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(0))
})
@@ -568,20 +568,382 @@ var _ = Describe("NodeRegistry", func() {
It("decrements correctly from a positive value", func() {
node := makeNode("dec-node-2", "10.0.0.51:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model-2", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "dec-model-2", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model-2")
nm, err := registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(2))
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model-2")).To(Succeed())
Expect(registry.DecrementInFlight(context.Background(), node.ID, "dec-model-2", 0)).To(Succeed())
nm, err = registry.GetNodeModel(context.Background(), node.ID, "dec-model-2")
nm, err = registry.GetNodeModel(context.Background(), node.ID, "dec-model-2", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm.InFlight).To(Equal(1))
})
})
Describe("Schema defaults", func() {
// These tests pin the GORM defaults that the multi-replica refactor
// relies on. If a future migration changes a default, the
// reconciler/router will silently misbehave (e.g. capacity 0 instead
// of 1) — these assertions catch that at the migration boundary.
It("BackendNode.MaxReplicasPerModel defaults to 1", func() {
node := makeNode("schema-default-mrpm", "10.0.0.200:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
fetched, err := registry.Get(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(fetched.MaxReplicasPerModel).To(Equal(1),
"old workers don't send the field; default must preserve single-replica behavior")
})
It("BackendNode.ReservedVRAM defaults to 0", func() {
node := makeNode("schema-default-reserved", "10.0.0.201:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
fetched, err := registry.Get(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
})
It("NodeModel.ReplicaIndex defaults to 0", func() {
node := makeNode("schema-default-replica", "10.0.0.202:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "default-replica-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed())
nm, err := registry.GetNodeModel(context.Background(), node.ID, "default-replica-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(nm).ToNot(BeNil())
Expect(nm.ReplicaIndex).To(Equal(0))
})
It("ModelSchedulingConfig.UnsatisfiableUntil is nullable and defaults to nil", func() {
cfg := &ModelSchedulingConfig{
ModelName: "schema-default-unsat",
MinReplicas: 1,
}
Expect(registry.SetModelScheduling(context.Background(), cfg)).To(Succeed())
fetched, err := registry.GetModelScheduling(context.Background(), "schema-default-unsat")
Expect(err).ToNot(HaveOccurred())
Expect(fetched).ToNot(BeNil())
Expect(fetched.UnsatisfiableUntil).To(BeNil())
Expect(fetched.UnsatisfiableTicks).To(Equal(0))
})
})
Describe("Multi-replica registry", func() {
// PR2 tests: SetNodeModel with distinct replica indexes creates distinct
// rows; per-row mutations (Remove, Increment, Decrement, Touch) target
// only their indexed row so siblings are not orphaned.
It("SetNodeModel(replicaIndex=0) then SetNodeModel(replicaIndex=1) creates two distinct rows", func() {
node := makeNode("multi-1", "10.0.0.210:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 0, "loaded", "127.0.0.1:50100", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "multi-model", 1, "loaded", "127.0.0.1:50101", 0)).To(Succeed())
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(2))
byIdx := map[int]NodeModel{}
for _, m := range models {
byIdx[m.ReplicaIndex] = m
}
Expect(byIdx[0].Address).To(Equal("127.0.0.1:50100"))
Expect(byIdx[1].Address).To(Equal("127.0.0.1:50101"))
Expect(byIdx[0].ID).ToNot(Equal(byIdx[1].ID))
})
It("RemoveNodeModel(replicaIndex=0) leaves replica 1 intact", func() {
node := makeNode("multi-2", "10.0.0.211:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 0, "loaded", "127.0.0.1:50110", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "kept-model", 1, "loaded", "127.0.0.1:50111", 0)).To(Succeed())
Expect(registry.RemoveNodeModel(context.Background(), node.ID, "kept-model", 0)).To(Succeed())
// Sibling replica must still exist — this was the latent bug pre-PR2:
// the WHERE clause matched both rows and orphaned the healthy sibling.
survivor, err := registry.GetNodeModel(context.Background(), node.ID, "kept-model", 1)
Expect(err).ToNot(HaveOccurred())
Expect(survivor).ToNot(BeNil())
Expect(survivor.Address).To(Equal("127.0.0.1:50111"))
// Replica 0 is gone
_, err = registry.GetNodeModel(context.Background(), node.ID, "kept-model", 0)
Expect(err).To(HaveOccurred())
})
It("RemoveAllNodeModelReplicas deletes every replica of the model on the node", func() {
node := makeNode("multi-3", "10.0.0.212:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "purge-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.RemoveAllNodeModelReplicas(context.Background(), node.ID, "purge-model")).To(Succeed())
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(BeEmpty())
})
It("IncrementInFlight only updates the targeted replica row", func() {
node := makeNode("multi-4", "10.0.0.213:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "infl-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "infl-model", 1)).To(Succeed())
r0, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 0)
Expect(err).ToNot(HaveOccurred())
Expect(r0.InFlight).To(Equal(0), "replica 0 must not have been incremented")
r1, err := registry.GetNodeModel(context.Background(), node.ID, "infl-model", 1)
Expect(err).ToNot(HaveOccurred())
Expect(r1.InFlight).To(Equal(2))
})
It("CountReplicasOnNode returns the per-(node, model) row count", func() {
node := makeNode("multi-5", "10.0.0.214:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "count-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "count-model", 1, "loaded", "b", 0)).To(Succeed())
n, err := registry.CountReplicasOnNode(context.Background(), node.ID, "count-model")
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(2))
})
It("NextFreeReplicaIndex returns the lowest unused index < maxSlots", func() {
node := makeNode("multi-6", "10.0.0.215:50051", 16_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Slot 0 free initially
idx, err := registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(Equal(0))
// Occupy 0 and 2 — next free is 1 (lowest gap)
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 0, "loaded", "a", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 2, "loaded", "c", 0)).To(Succeed())
idx, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(Equal(1), "must allocate the lowest free index for compactness")
// Fill all 4 — must return ErrNoFreeSlot
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 1, "loaded", "b", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "slot-model", 3, "loaded", "d", 0)).To(Succeed())
_, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "slot-model", 4)
Expect(err).To(MatchError(ErrNoFreeSlot))
// maxSlots=0 always returns ErrNoFreeSlot
_, err = registry.NextFreeReplicaIndex(context.Background(), node.ID, "no-slots-model", 0)
Expect(err).To(MatchError(ErrNoFreeSlot))
})
})
Describe("ApplyAutoLabels", func() {
It("mirrors MaxReplicasPerModel as the node.replica-slots label", func() {
node := makeNode("auto-label-replicas", "10.0.0.220:50051", 16_000_000_000)
node.MaxReplicasPerModel = 4
node.GPUVendor = "nvidia"
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
registry.ApplyAutoLabels(context.Background(), node.ID, node)
labels, err := registry.GetNodeLabels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
byKey := map[string]string{}
for _, l := range labels {
byKey[l.Key] = l.Value
}
Expect(byKey).To(HaveKeyWithValue("node.replica-slots", "4"),
"selectors targeting fat nodes need this auto-label")
Expect(byKey).To(HaveKeyWithValue("gpu.vendor", "nvidia"))
})
It("defaults node.replica-slots to 1 when MaxReplicasPerModel is unset", func() {
node := makeNode("auto-label-default", "10.0.0.221:50051", 4_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Fetch back; default should be 1 (PR1 schema test)
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.MaxReplicasPerModel).To(Equal(1))
registry.ApplyAutoLabels(context.Background(), node.ID, fetched)
labels, _ := registry.GetNodeLabels(context.Background(), node.ID)
byKey := map[string]string{}
for _, l := range labels {
byKey[l.Key] = l.Value
}
Expect(byKey).To(HaveKeyWithValue("node.replica-slots", "1"))
})
})
Describe("VRAM soft-reservation (PR5)", func() {
// These tests pin the soft-reservation contract: ReserveVRAM is the
// admission gate that prevents two concurrent scheduling decisions
// from over-committing the same node within one heartbeat window.
It("ReserveVRAM atomically deducts from effectively-free VRAM", func() {
node := makeNode("reserve-1", "10.0.0.230:50051", 10_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.ReserveVRAM(context.Background(), node.ID, 3_000_000_000)).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000)))
})
It("ReserveVRAM rejects when effectively-free VRAM is insufficient", func() {
node := makeNode("reserve-2", "10.0.0.231:50051", 5_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// First reservation fits.
Expect(registry.ReserveVRAM(context.Background(), node.ID, 4_000_000_000)).To(Succeed())
// Second is too big — only 1 GB effectively free.
err := registry.ReserveVRAM(context.Background(), node.ID, 2_000_000_000)
Expect(err).To(MatchError(ErrInsufficientVRAM))
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.ReservedVRAM).To(Equal(uint64(4_000_000_000)),
"failed reservation must not bump the column")
})
It("ReserveVRAM with bytes=0 is a no-op", func() {
node := makeNode("reserve-3", "10.0.0.232:50051", 1_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.ReserveVRAM(context.Background(), node.ID, 0)).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
})
It("ReleaseVRAM returns reserved bytes to the pool", func() {
node := makeNode("release-1", "10.0.0.233:50051", 10_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.ReserveVRAM(context.Background(), node.ID, 4_000_000_000)).To(Succeed())
Expect(registry.ReleaseVRAM(context.Background(), node.ID, 1_000_000_000)).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.ReservedVRAM).To(Equal(uint64(3_000_000_000)))
})
It("ReleaseVRAM cannot underflow past zero", func() {
node := makeNode("release-underflow", "10.0.0.234:50051", 1_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// No reservation; release is a guarded no-op rather than wrapping
// uint64 to a huge positive number.
Expect(registry.ReleaseVRAM(context.Background(), node.ID, 5_000_000_000)).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.ReservedVRAM).To(Equal(uint64(0)))
})
It("Heartbeat with available_vram resets reserved_vram to 0", func() {
node := makeNode("heartbeat-reset", "10.0.0.235:50051", 10_000_000_000)
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.ReserveVRAM(context.Background(), node.ID, 5_000_000_000)).To(Succeed())
fresh := uint64(8_000_000_000)
Expect(registry.Heartbeat(context.Background(), node.ID, &HeartbeatUpdate{AvailableVRAM: &fresh})).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.AvailableVRAM).To(Equal(fresh),
"heartbeat must overwrite available_vram with the worker's reading")
Expect(fetched.ReservedVRAM).To(Equal(uint64(0)),
"heartbeat must clear the soft reservation — worker is the source of truth")
})
It("UpdateMaxReplicasPerModel marks the value as a sticky override", func() {
// The original UX bug: workers default the flag to 1, so every
// re-registration silently reverted the admin's UI value. This
// test pins the fix.
node := &BackendNode{
Name: "override-survives",
NodeType: NodeTypeBackend,
Address: "10.0.0.240:50051",
MaxReplicasPerModel: 1,
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Admin sets capacity to 4 via the UI.
Expect(registry.UpdateMaxReplicasPerModel(context.Background(), node.ID, 4)).To(Succeed())
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.MaxReplicasPerModel).To(Equal(4))
Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue())
// Worker re-registers with its default of 1 (operator never set the flag).
restart := &BackendNode{
Name: "override-survives",
NodeType: NodeTypeBackend,
Address: "10.0.0.240:50051",
MaxReplicasPerModel: 1,
}
Expect(registry.Register(context.Background(), restart, true)).To(Succeed())
// Override must have survived.
fetched, _ = registry.Get(context.Background(), node.ID)
Expect(fetched.MaxReplicasPerModel).To(Equal(4),
"admin override must not be overwritten by worker re-registration")
Expect(fetched.MaxReplicasPerModelManuallySet).To(BeTrue())
})
It("ResetMaxReplicasPerModel hands control back to the worker", func() {
node := &BackendNode{
Name: "override-reset",
NodeType: NodeTypeBackend,
Address: "10.0.0.241:50051",
MaxReplicasPerModel: 1,
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.UpdateMaxReplicasPerModel(context.Background(), node.ID, 4)).To(Succeed())
Expect(registry.ResetMaxReplicasPerModel(context.Background(), node.ID)).To(Succeed())
// Reset only flips the flag; the value stays until the worker
// re-registers (we don't presume to know what the worker wants).
fetched, _ := registry.Get(context.Background(), node.ID)
Expect(fetched.MaxReplicasPerModelManuallySet).To(BeFalse())
// Now worker re-registers with 8.
restart := &BackendNode{
Name: "override-reset",
NodeType: NodeTypeBackend,
Address: "10.0.0.241:50051",
MaxReplicasPerModel: 8,
}
Expect(registry.Register(context.Background(), restart, true)).To(Succeed())
fetched, _ = registry.Get(context.Background(), node.ID)
Expect(fetched.MaxReplicasPerModel).To(Equal(8),
"after reset, the worker's value should apply")
})
It("FindNodeWithVRAM honors the reservation", func() {
small := makeNode("find-vram-small", "10.0.0.236:50051", 5_000_000_000)
big := makeNode("find-vram-big", "10.0.0.237:50051", 20_000_000_000)
Expect(registry.Register(context.Background(), small, true)).To(Succeed())
Expect(registry.Register(context.Background(), big, true)).To(Succeed())
// Reserve almost all of the big node so its effective free
// drops below the request — small isn't big enough either —
// the call must return an error.
Expect(registry.ReserveVRAM(context.Background(), big.ID, 18_000_000_000)).To(Succeed())
_, err := registry.FindNodeWithVRAM(context.Background(), 8_000_000_000)
Expect(err).To(HaveOccurred(),
"reserved capacity must remove a node from VRAM-aware candidates")
})
})
})

View File

@@ -77,19 +77,23 @@ func (r *SmartRouter) StagingTracker() *StagingTracker { return r.stagingTracker
// scheduleLoadResult holds the result of scheduling and loading a model on a node.
type scheduleLoadResult struct {
Node *BackendNode
Client grpc.Backend
BackendAddr string
Node *BackendNode
Client grpc.Backend
BackendAddr string
ReplicaIndex int
}
// scheduleAndLoad is the shared core for loading a model on a new node.
// Used by both Route() (for first-time loads) and ScheduleAndLoadModel() (for reconciler scale-ups).
//
// Steps: pick node → install backend → stage files → LoadModel → SetNodeModel.
// Steps: pick node + replica slot → install backend → stage files → LoadModel → SetNodeModel.
//
// scheduleNewModel allocates the replica index internally so the worker's
// processKey, port, and the registry row all agree.
func (r *SmartRouter) scheduleAndLoad(ctx context.Context, backendType, trackingKey, modelName string,
modelOpts *pb.ModelOptions, parallel bool, initialInFlight int) (*scheduleLoadResult, error) {
node, backendAddr, err := r.scheduleNewModel(ctx, backendType, trackingKey, modelOpts)
node, backendAddr, replicaIndex, err := r.scheduleNewModel(ctx, backendType, trackingKey, modelOpts)
if err != nil {
return nil, fmt.Errorf("no available nodes: %w", err)
}
@@ -122,21 +126,21 @@ func (r *SmartRouter) scheduleAndLoad(ctx context.Context, backendType, tracking
}
}
// Record the model as loaded on this node
if err := r.registry.SetNodeModel(ctx, node.ID, trackingKey, "loaded", backendAddr, initialInFlight); err != nil {
xlog.Warn("Failed to record model on node", "node", node.Name, "model", trackingKey, "error", err)
// Record the model as loaded on this node (specific replica slot).
if err := r.registry.SetNodeModel(ctx, node.ID, trackingKey, replicaIndex, "loaded", backendAddr, initialInFlight); err != nil {
xlog.Warn("Failed to record model on node", "node", node.Name, "model", trackingKey, "replica", replicaIndex, "error", err)
}
// Store load metadata for future replica scale-ups by the reconciler
if modelOpts != nil {
if optsBlob, marshalErr := proto.Marshal(modelOpts); marshalErr == nil {
if storeErr := r.registry.SetNodeModelLoadInfo(ctx, node.ID, trackingKey, backendType, optsBlob); storeErr != nil {
xlog.Warn("Failed to store model load info", "node", node.Name, "model", trackingKey, "error", storeErr)
if storeErr := r.registry.SetNodeModelLoadInfo(ctx, node.ID, trackingKey, replicaIndex, backendType, optsBlob); storeErr != nil {
xlog.Warn("Failed to store model load info", "node", node.Name, "model", trackingKey, "replica", replicaIndex, "error", storeErr)
}
}
}
return &scheduleLoadResult{Node: node, Client: client, BackendAddr: backendAddr}, nil
return &scheduleLoadResult{Node: node, Client: client, BackendAddr: backendAddr, ReplicaIndex: replicaIndex}, nil
}
// ScheduleAndLoadModel implements ModelScheduler for the reconciler.
@@ -150,7 +154,7 @@ func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string
// This happens on the very first load (before Route() has stored opts).
xlog.Warn("No stored model load info for reconciler scale-up, falling back to backend install only",
"model", modelName, "error", err)
node, _, schedErr := r.scheduleNewModel(ctx, "", modelName, nil)
node, _, _, schedErr := r.scheduleNewModel(ctx, "", modelName, nil)
return node, schedErr
}
@@ -160,7 +164,8 @@ func (r *SmartRouter) ScheduleAndLoadModel(ctx context.Context, modelName string
return nil, fmt.Errorf("unmarshalling stored model options for %s: %w", modelName, err)
}
// initialInFlight=0: reconciler is pre-loading, not serving a request
// initialInFlight=0: reconciler is pre-loading, not serving a request.
// scheduleAndLoad picks both the node and the replica slot internally.
result, err := r.scheduleAndLoad(ctx, backendType, modelName, modelName, &modelOpts, false, 0)
if err != nil {
return nil, err
@@ -200,31 +205,32 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
if nm.Address != "" {
modelAddr = nm.Address
}
replicaIdx := nm.ReplicaIndex
// Verify the backend process is still alive via gRPC health check
if !r.probeHealth(ctx, node, modelAddr) {
// Stale — roll back the increment, remove the model record, fall through
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
r.registry.RemoveNodeModel(ctx, node.ID, trackingKey)
// Stale — roll back the increment, remove the specific replica row, fall through
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
r.registry.RemoveNodeModel(ctx, node.ID, trackingKey, replicaIdx)
xlog.Warn("Backend not reachable for cached model, falling through to reload",
"node", node.Name, "model", modelName)
"node", node.Name, "model", modelName, "replica", replicaIdx)
} else {
// Verify node still matches scheduling constraints
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey)
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
// Fall through to step 2 (scheduleNewModel)
} else {
// Node is alive — FindAndLockNodeWithModel already incremented in-flight as a
// reservation. InFlightTrackingClient handles per-inference tracking, and its
// onFirstComplete callback releases the reservation after the first inference
// call finishes, so in-flight returns to 0 when idle.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey, replicaIdx)
})
return &RouteResult{
Node: node,
@@ -246,29 +252,30 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
if nm.Address != "" {
modelAddr = nm.Address
}
replicaIdx := nm.ReplicaIndex
// Verify the backend process is still alive via gRPC health check
if !r.probeHealth(ctx, node, modelAddr) {
// Stale — roll back the increment, remove the model record, continue loading
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
r.registry.RemoveNodeModel(ctx, node.ID, trackingKey)
// Stale — roll back the increment, remove the specific replica row, continue loading
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
r.registry.RemoveNodeModel(ctx, node.ID, trackingKey, replicaIdx)
xlog.Warn("Backend not reachable for cached model inside lock, proceeding to load",
"node", node.Name, "model", modelName)
"node", node.Name, "model", modelName, "replica", replicaIdx)
} else {
// Verify node still matches scheduling constraints
if !r.nodeMatchesScheduling(ctx, node, trackingKey) {
r.registry.DecrementInFlight(ctx, node.ID, trackingKey)
r.registry.DecrementInFlight(ctx, node.ID, trackingKey, replicaIdx)
xlog.Info("Cached model on node that no longer matches selector, falling through",
"node", node.Name, "model", trackingKey)
"node", node.Name, "model", trackingKey, "replica", replicaIdx)
// Fall through to scheduling below
} else {
// Model loaded while we waited — FindAndLockNodeWithModel already incremented
// in-flight as a reservation. Release it after the first inference completes.
r.registry.TouchNodeModel(ctx, node.ID, trackingKey)
r.registry.TouchNodeModel(ctx, node.ID, trackingKey, replicaIdx)
grpcClient := r.buildClientForAddr(node, modelAddr, parallel)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey)
tracked := NewInFlightTrackingClient(grpcClient, r.registry, node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey)
r.registry.DecrementInFlight(context.Background(), node.ID, trackingKey, replicaIdx)
})
return &RouteResult{
Node: node,
@@ -281,15 +288,17 @@ func (r *SmartRouter) Route(ctx context.Context, modelID, modelName, backendType
}
}
// Still not loaded — use shared schedule-and-load logic
// Still not loaded — use shared schedule-and-load logic, which picks
// both the node and the replica slot.
result, err := r.scheduleAndLoad(ctx, backendType, trackingKey, modelName, modelOpts, parallel, 1)
if err != nil {
return nil, err
}
tracked := NewInFlightTrackingClient(result.Client, r.registry, result.Node.ID, trackingKey)
replicaIdx := result.ReplicaIndex
tracked := NewInFlightTrackingClient(result.Client, r.registry, result.Node.ID, trackingKey, replicaIdx)
tracked.OnFirstComplete(func() {
r.registry.DecrementInFlight(context.Background(), result.Node.ID, trackingKey)
r.registry.DecrementInFlight(context.Background(), result.Node.ID, trackingKey, replicaIdx)
})
return &RouteResult{
Node: result.Node,
@@ -370,33 +379,55 @@ func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNo
return true
}
// scheduleNewModel picks the best node for loading a new model.
// Strategy: VRAM-aware → idle-first → least-loaded.
// scheduleNewModel picks the best node for loading a new model and allocates
// the replica slot.
// Strategy: filter to nodes with a free slot for this model → VRAM-aware →
// idle-first → least-loaded → eviction.
// Sends backend.install via NATS so the chosen node has the right backend running.
func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID string, modelOpts *pb.ModelOptions) (*BackendNode, string, error) {
//
// Returns (node, gRPC address, replicaIndex, err). replicaIndex is the slot
// the worker has been told to use; the caller must pass the same index into
// SetNodeModel so the registry row matches the live process.
func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID string, modelOpts *pb.ModelOptions) (*BackendNode, string, int, error) {
// Estimate VRAM required for the model
var estimatedVRAM uint64
if modelOpts != nil {
estimatedVRAM = r.estimateModelVRAM(ctx, modelOpts)
}
// Check for scheduling constraints (node selector)
// Check for scheduling constraints (node selector). If a selector is set,
// we restrict the candidate pool to matching nodes; otherwise nil means
// "any healthy node".
sched, _ := r.registry.GetModelScheduling(ctx, modelID)
var candidateNodeIDs []string // nil = all nodes eligible
var candidateNodeIDs []string
if sched != nil && sched.NodeSelector != "" {
selector := parseSelectorJSON(sched.NodeSelector)
if len(selector) > 0 {
candidates, err := r.registry.FindNodesBySelector(ctx, selector)
if err != nil || len(candidates) == 0 {
return nil, "", fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector)
return nil, "", 0, fmt.Errorf("no healthy nodes match selector for model %s: %v", modelID, sched.NodeSelector)
}
candidateNodeIDs = extractNodeIDs(candidates)
}
}
// Narrow candidates to nodes that still have a free replica slot for this
// model. Without this filter, the scheduler would happily pick a node
// already at capacity for this model (e.g. when MinReplicas > free
// cluster capacity), which is what caused the original 30s flap loop.
freeSlotNodes, err := r.registry.FindNodesWithFreeSlot(ctx, modelID, candidateNodeIDs)
if err != nil {
xlog.Warn("Failed to query nodes with free slot; falling back to selector-only filtering",
"model", modelID, "error", err)
} else if len(freeSlotNodes) > 0 {
// Replace the candidate set with only those that have capacity.
candidateNodeIDs = extractNodeIDs(freeSlotNodes)
}
// If freeSlotNodes is empty (everyone full), candidateNodeIDs is whatever
// it was — we'll fall through to eviction below.
var node *BackendNode
var err error
if estimatedVRAM > 0 {
if candidateNodeIDs != nil {
@@ -429,20 +460,71 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID
evictedNode, evictErr := r.evictLRUAndFreeNode(ctx)
if evictErr != nil {
if errors.Is(evictErr, ErrEvictionBusy) {
return nil, "", fmt.Errorf("no healthy nodes available: %w", evictErr)
return nil, "", 0, fmt.Errorf("no healthy nodes available: %w", evictErr)
}
return nil, "", fmt.Errorf("no healthy nodes available and eviction failed: %w", evictErr)
return nil, "", 0, fmt.Errorf("no healthy nodes available and eviction failed: %w", evictErr)
}
node = evictedNode
}
// Send backend.install — the worker installs the backend if needed and starts the gRPC process
addr, err := r.installBackendOnNode(ctx, node, backendType, modelID)
if err != nil {
return nil, "", fmt.Errorf("installing backend on node %s: %w", node.Name, err)
// Allocate the replica slot before sending backend.install so the worker
// uses the same slot for its processKey + port. Default to 0 when the
// node's MaxReplicasPerModel is 1 (preserves single-replica behavior).
maxSlots := node.MaxReplicasPerModel
if maxSlots < 1 {
maxSlots = 1
}
replicaIdx, slotErr := r.registry.NextFreeReplicaIndex(ctx, node.ID, modelID, maxSlots)
if slotErr != nil {
// All slots on this node are taken — fall back to eviction. This is
// rare in practice because FindNodesWithFreeSlot already filtered;
// it can race with another concurrent scheduler.
xlog.Warn("Chosen node has no free replica slot, evicting LRU",
"node", node.Name, "model", modelID, "max_slots", maxSlots)
evictedNode, evictErr := r.evictLRUAndFreeNode(ctx)
if evictErr != nil {
return nil, "", 0, fmt.Errorf("no replica slot on %s and eviction failed: %w", node.Name, evictErr)
}
node = evictedNode
replicaIdx, slotErr = r.registry.NextFreeReplicaIndex(ctx, node.ID, modelID, node.MaxReplicasPerModel)
if slotErr != nil {
return nil, "", 0, fmt.Errorf("no replica slot on %s after eviction: %w", node.Name, slotErr)
}
}
return node, addr, nil
// Soft-reserve VRAM up front so a second scheduling tick within the same
// heartbeat window can't pick this node based on stale free-VRAM
// numbers. The worker's next heartbeat resets reserved_vram to the
// authoritative reading; explicit rollback below covers the failure
// window between reservation and a successful install.
reserved := false
if estimatedVRAM > 0 {
reserveErr := r.registry.ReserveVRAM(ctx, node.ID, estimatedVRAM)
if reserveErr != nil {
// ErrInsufficientVRAM races with another scheduler — log and
// proceed without a reservation rather than failing the load.
// FindNodeWithVRAM already accounted for reserved_vram, so this
// is a tight race window; the worker will reconcile via heartbeat.
xlog.Warn("Failed to reserve VRAM, proceeding without reservation",
"node", node.Name, "bytes", estimatedVRAM, "error", reserveErr)
} else {
reserved = true
}
}
// Send backend.install — the worker installs the backend if needed and
// starts the gRPC process bound to a port for this (model, replica) slot.
addr, installErr := r.installBackendOnNode(ctx, node, backendType, modelID, replicaIdx)
if installErr != nil {
// Roll back the reservation explicitly so the column is accurate
// before the next heartbeat. Best-effort.
if reserved {
_ = r.registry.ReleaseVRAM(ctx, node.ID, estimatedVRAM)
}
return nil, "", 0, fmt.Errorf("installing backend on node %s: %w", node.Name, installErr)
}
return node, addr, replicaIdx, nil
}
// estimateModelVRAM estimates the VRAM required for a model using the unified estimator.
@@ -499,19 +581,19 @@ func (r *SmartRouter) estimateModelVRAM(ctx context.Context, opts *pb.ModelOptio
// The worker installs the backend from gallery (if not already installed),
// starts the gRPC process, and replies when ready.
// installBackendOnNode installs a backend on a node and returns the gRPC address.
func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNode, backendType, modelID string) (string, error) {
func (r *SmartRouter) installBackendOnNode(ctx context.Context, node *BackendNode, backendType, modelID string, replicaIndex int) (string, error) {
if r.unloader == nil {
return "", fmt.Errorf("no NATS connection for backend installation")
}
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "")
reply, err := r.unloader.InstallBackend(node.ID, backendType, modelID, r.galleriesJSON, "", "", "", replicaIndex)
if err != nil {
return "", err
}
if !reply.Success {
return "", fmt.Errorf("worker replied with error: %s", reply.Error)
}
// Return the backend's gRPC address (new: per-process port from worker)
// Return the backend's gRPC address (per-replica port from worker)
addr := reply.Address
if addr == "" {
addr = node.Address // fallback to node base address
@@ -789,7 +871,8 @@ func closeClient(client grpc.Backend) {
}
}
// UnloadModel sends a NATS unload event to a specific node for the given model.
// UnloadModel sends a NATS unload event to a specific node for the given model
// and removes every replica row for (nodeID, modelName).
// The worker process handles Free() + kill + deregister.
func (r *SmartRouter) UnloadModel(ctx context.Context, nodeID, modelName string) error {
if r.unloader == nil {
@@ -799,7 +882,7 @@ func (r *SmartRouter) UnloadModel(ctx context.Context, nodeID, modelName string)
if err := r.unloader.StopBackend(nodeID, modelName); err != nil {
return fmt.Errorf("failed to stop backend on node %s: %w", nodeID, err)
}
r.registry.RemoveNodeModel(ctx, nodeID, modelName)
r.registry.RemoveAllNodeModelReplicas(ctx, nodeID, modelName)
return nil
}
@@ -851,9 +934,10 @@ func (r *SmartRouter) evictLRUAndFreeNode(ctx context.Context) (*BackendNode, er
First(&lru).Error; err != nil {
return err
}
// Remove inside the same transaction
return tx.Where("node_id = ? AND model_name = ?", lru.NodeID, lru.ModelName).
Delete(&NodeModel{}).Error
// Remove inside the same transaction. Target the specific replica row
// by ID so we don't accidentally delete sibling replicas of the same
// model on the same node.
return tx.Where("id = ?", lru.ID).Delete(&NodeModel{}).Error
})
if err == nil {

View File

@@ -118,31 +118,39 @@ func (f *fakeModelRouter) FindAndLockNodeWithModel(_ context.Context, modelName
return f.findAndLockNode, f.findAndLockNM, f.findAndLockErr
}
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string) error {
func (f *fakeModelRouter) DecrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.decrementCalls = append(f.decrementCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string) error {
func (f *fakeModelRouter) IncrementInFlight(_ context.Context, nodeID, modelName string, _ int) error {
f.incrementCalls = append(f.incrementCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string) error {
func (f *fakeModelRouter) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string) {
func (f *fakeModelRouter) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
// Same recorded key as RemoveNodeModel so existing tests that assert "the
// model was removed" don't need to know whether the production code used
// the per-replica or all-replicas variant.
f.removeCalls = append(f.removeCalls, nodeID+":"+modelName)
return nil
}
func (f *fakeModelRouter) TouchNodeModel(_ context.Context, nodeID, modelName string, _ int) {
f.touchCalls = append(f.touchCalls, nodeID+":"+modelName)
}
func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName, state, address string, _ int) error {
func (f *fakeModelRouter) SetNodeModel(_ context.Context, nodeID, modelName string, _ int, state, address string, _ int) error {
f.setCalls = append(f.setCalls, fmt.Sprintf("%s:%s:%s:%s", nodeID, modelName, state, address))
return nil
}
func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _, _ string, _ []byte) error {
func (f *fakeModelRouter) SetNodeModelLoadInfo(_ context.Context, _, _ string, _ int, _ string, _ []byte) error {
return nil
}
@@ -150,6 +158,14 @@ func (f *fakeModelRouter) GetModelLoadInfo(_ context.Context, _ string) (string,
return "", nil, fmt.Errorf("not found")
}
func (f *fakeModelRouter) NextFreeReplicaIndex(_ context.Context, _, _ string, _ int) (int, error) {
return 0, nil
}
func (f *fakeModelRouter) CountReplicasOnNode(_ context.Context, _, _ string) (int, error) {
return 0, nil
}
func (f *fakeModelRouter) FindNodeWithVRAM(_ context.Context, _ uint64) (*BackendNode, error) {
return f.findVRAMNode, f.findVRAMErr
}
@@ -182,6 +198,20 @@ func (f *fakeModelRouter) FindNodesBySelector(_ context.Context, _ map[string]st
return f.findBySelectorNodes, f.findBySelectorErr
}
func (f *fakeModelRouter) FindNodesWithFreeSlot(_ context.Context, _ string, _ []string) ([]BackendNode, error) {
// Default: same answer as FindNodesBySelector. Tests that need a
// specific filter can override by reusing findBySelectorNodes.
return f.findBySelectorNodes, f.findBySelectorErr
}
func (f *fakeModelRouter) ReserveVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouter) ReleaseVRAM(_ context.Context, _ string, _ uint64) error {
return nil
}
func (f *fakeModelRouter) FindNodeWithVRAMFromSet(_ context.Context, _ uint64, _ []string) (*BackendNode, error) {
return f.findVRAMFromSetNode, f.findVRAMFromSetErr
}
@@ -244,7 +274,7 @@ type fakeUnloader struct {
unloadErr error
}
func (f *fakeUnloader) InstallBackend(_, _, _, _, _, _, _ string) (*messaging.BackendInstallReply, error) {
func (f *fakeUnloader) InstallBackend(_, _, _, _, _, _, _ string, _ int) (*messaging.BackendInstallReply, error) {
return f.installReply, f.installErr
}
@@ -690,8 +720,8 @@ var _ = Describe("SmartRouter", func() {
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Load a model and give it in-flight requests so it cannot be evicted
Expect(registry.SetNodeModel(context.Background(), node.ID, "busy-model", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "busy-model")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "busy-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "busy-model", 0)).To(Succeed())
router := NewSmartRouter(registry, SmartRouterOptions{DB: db})
@@ -711,8 +741,8 @@ var _ = Describe("SmartRouter", func() {
Address: "10.0.0.101:50051",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "cancel-model", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "cancel-model")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "cancel-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "cancel-model", 0)).To(Succeed())
router := NewSmartRouter(registry, SmartRouterOptions{DB: db})

View File

@@ -17,7 +17,7 @@ type backendStopRequest struct {
// NodeCommandSender abstracts NATS-based commands to worker nodes.
// Used by HTTP endpoint handlers to avoid coupling to the concrete RemoteUnloaderAdapter.
type NodeCommandSender interface {
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string) (*messaging.BackendInstallReply, error)
InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error)
DeleteBackend(nodeID, backendName string) (*messaging.BackendDeleteReply, error)
ListBackends(nodeID string) (*messaging.BackendListReply, error)
StopBackend(nodeID, backend string) error
@@ -61,8 +61,9 @@ func (a *RemoteUnloaderAdapter) UnloadRemoteModel(modelName string) error {
xlog.Warn("Failed to send backend.stop", "node", node.Name, "error", err)
continue
}
// Remove model from registry — the node will handle the actual cleanup
a.registry.RemoveNodeModel(ctx, node.ID, modelName)
// Remove every replica of this model on the node — the worker will
// handle the actual process cleanup.
a.registry.RemoveAllNodeModelReplicas(ctx, node.ID, modelName)
}
return nil
@@ -71,10 +72,15 @@ func (a *RemoteUnloaderAdapter) UnloadRemoteModel(modelName string) error {
// InstallBackend sends a backend.install request-reply to a worker node.
// The worker installs the backend from gallery (if not already installed),
// starts the gRPC process, and replies when ready.
//
// replicaIndex selects which replica slot the worker should use as its
// process key — distinct slots run on distinct ports so multiple replicas of
// the same model can coexist on a fat node. Pass 0 for single-replica.
//
// Timeout: 5 minutes (gallery install can take a while).
func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string) (*messaging.BackendInstallReply, error) {
func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, galleriesJSON, uri, name, alias string, replicaIndex int) (*messaging.BackendInstallReply, error) {
subject := messaging.SubjectNodeBackendInstall(nodeID)
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID)
xlog.Info("Sending NATS backend.install", "nodeID", nodeID, "backend", backendType, "modelID", modelID, "replica", replicaIndex)
return messaging.RequestJSON[messaging.BackendInstallRequest, messaging.BackendInstallReply](a.nats, subject, messaging.BackendInstallRequest{
Backend: backendType,
@@ -83,6 +89,7 @@ func (a *RemoteUnloaderAdapter) InstallBackend(nodeID, backendType, modelID, gal
URI: uri,
Name: name,
Alias: alias,
ReplicaIndex: int32(replicaIndex),
}, 5*time.Minute)
}

View File

@@ -31,7 +31,12 @@ func (f *fakeModelLocator) FindNodesWithModel(_ context.Context, _ string) ([]Ba
return f.nodes, f.findErr
}
func (f *fakeModelLocator) RemoveNodeModel(_ context.Context, nodeID, modelName string) error {
func (f *fakeModelLocator) RemoveNodeModel(_ context.Context, nodeID, modelName string, _ int) error {
f.removedPairs = append(f.removedPairs, modelNodePair{nodeID, modelName})
return nil
}
func (f *fakeModelLocator) RemoveAllNodeModelReplicas(_ context.Context, nodeID, modelName string) error {
f.removedPairs = append(f.removedPairs, modelNodePair{nodeID, modelName})
return nil
}

View File

@@ -329,14 +329,14 @@ var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func()
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
// Set both as having the model loaded
Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())
// Set node-1 with high in-flight (5), node-2 with low in-flight (1)
for range 5 {
Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model", 0)).To(Succeed())
}
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model", 0)).To(Succeed())
// Route should pick node-2 (least loaded) thanks to ORDER BY in_flight ASC
router := newTestSmartRouter(registry)
@@ -380,7 +380,7 @@ var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func()
// Register a node with a loaded model
node := &nodes.BackendNode{Name: "gpu-unload", Address: "127.0.0.1:50099"}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", 0, "loaded", "", 0)).To(Succeed())
// Subscribe to NATS backend.stop for this node
stopSubject := messaging.SubjectNodeBackendStop(node.ID)

View File

@@ -97,7 +97,7 @@ var _ = Describe("DistributedModelStore", Label("Distributed"), func() {
Name: "range-node", Address: "range:9000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "db-only-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "db-only-model", 0, "loaded", "", 0)).To(Succeed())
visited := map[string]bool{}
dStore.Range(func(id string, m *model.Model) bool {
@@ -112,7 +112,7 @@ var _ = Describe("DistributedModelStore", Label("Distributed"), func() {
Name: "dup-node", Address: "dup:9000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "shared-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "shared-model", 0, "loaded", "", 0)).To(Succeed())
// Also in local store
localStore.Set("shared-model", model.NewModel("shared-model", "dup:9000", nil))

View File

@@ -132,8 +132,8 @@ var _ = Describe("Model and Backend Managers", Label("Distributed"), func() {
node2 := &nodes.BackendNode{Name: "dm-n2", Address: "h2:50051"}
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "big-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "big-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "big-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "big-model", 0, "loaded", "", 0)).To(Succeed())
// Subscribe to model.delete on both node subjects, track receipt
var deleteCount atomic.Int32

View File

@@ -53,9 +53,9 @@ var _ = Describe("Model Routing", Label("Distributed"), func() {
Name: "gpu-1", Address: "h1:50051",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama3", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama3", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3", 0)).To(Succeed())
// Verify in-flight count
models, err := registry.GetNodeModels(context.Background(), node.ID)
@@ -75,7 +75,7 @@ var _ = Describe("Model Routing", Label("Distributed"), func() {
Expect(models[0].InFlight).To(Equal(3))
// Simulate decrement (what Release does)
Expect(registry.DecrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
Expect(registry.DecrementInFlight(context.Background(), node.ID, "llama3", 0)).To(Succeed())
models, _ = registry.GetNodeModels(context.Background(), node.ID)
Expect(models[0].InFlight).To(Equal(2))
@@ -99,7 +99,7 @@ var _ = Describe("Model Routing", Label("Distributed"), func() {
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
// Load model on node1
Expect(registry.SetNodeModel(context.Background(), node1.ID, "llama3", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "llama3", 0, "loaded", "", 0)).To(Succeed())
// Verify routing can find the model
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "llama3")

View File

@@ -57,7 +57,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
FlushNATS(infra.NC)
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "")
installReply, err := adapter.InstallBackend(node.ID, "llama-cpp", "", "", "", "", "", 0)
Expect(err).ToNot(HaveOccurred())
Expect(installReply.Success).To(BeTrue())
})
@@ -78,7 +78,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
FlushNATS(infra.NC)
adapter := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "")
installReply, err := adapter.InstallBackend(node.ID, "nonexistent", "", "", "", "", "", 0)
Expect(err).ToNot(HaveOccurred())
Expect(installReply.Success).To(BeFalse())
Expect(installReply.Error).To(ContainSubstring("backend not found"))
@@ -91,7 +91,7 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
Name: "gpu-node-2", Address: "h2:50051",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "whisper-large", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "whisper-large", 0, "loaded", "", 0)).To(Succeed())
var stopReceived atomic.Int32
sub, err := infra.NC.Subscribe(messaging.SubjectNodeBackendStop(node.ID), func(data []byte) {
@@ -118,8 +118,8 @@ var _ = Describe("Node Backend Lifecycle (NATS-driven)", Label("Distributed"), f
node2 := &nodes.BackendNode{Name: "n2", Address: "h2:50051"}
registry.Register(context.Background(), node1, true)
registry.Register(context.Background(), node2, true)
registry.SetNodeModel(context.Background(), node1.ID, "shared-model", "loaded", "", 0)
registry.SetNodeModel(context.Background(), node2.ID, "shared-model", "loaded", "", 0)
registry.SetNodeModel(context.Background(), node1.ID, "shared-model", 0, "loaded", "", 0)
registry.SetNodeModel(context.Background(), node2.ID, "shared-model", 0, "loaded", "", 0)
var count atomic.Int32
sub1, _ := infra.NC.Subscribe(messaging.SubjectNodeBackendStop(node1.ID), func(data []byte) {

View File

@@ -162,7 +162,7 @@ var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
})
It("should track models loaded on a node", func() {
err := registry.SetNodeModel(context.Background(), nodeID, "llama3", "loaded", "", 0)
err := registry.SetNodeModel(context.Background(), nodeID, "llama3", 0, "loaded", "", 0)
Expect(err).ToNot(HaveOccurred())
models, err := registry.GetNodeModels(context.Background(), nodeID)
@@ -173,7 +173,7 @@ var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
})
It("should find nodes with a specific model", func() {
registry.SetNodeModel(context.Background(), nodeID, "llama3", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "llama3", 0, "loaded", "", 0)
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "llama3")
Expect(err).ToNot(HaveOccurred())
@@ -182,24 +182,24 @@ var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
})
It("should increment and decrement in-flight counters", func() {
registry.SetNodeModel(context.Background(), nodeID, "llama3", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "llama3", 0, "loaded", "", 0)
err := registry.IncrementInFlight(context.Background(), nodeID, "llama3")
err := registry.IncrementInFlight(context.Background(), nodeID, "llama3", 0)
Expect(err).ToNot(HaveOccurred())
err = registry.IncrementInFlight(context.Background(), nodeID, "llama3")
err = registry.IncrementInFlight(context.Background(), nodeID, "llama3", 0)
Expect(err).ToNot(HaveOccurred())
models, _ := registry.GetNodeModels(context.Background(), nodeID)
Expect(models[0].InFlight).To(Equal(2))
registry.DecrementInFlight(context.Background(), nodeID, "llama3")
registry.DecrementInFlight(context.Background(), nodeID, "llama3", 0)
models, _ = registry.GetNodeModels(context.Background(), nodeID)
Expect(models[0].InFlight).To(Equal(1))
})
It("should remove model association from node", func() {
registry.SetNodeModel(context.Background(), nodeID, "llama3", "loaded", "", 0)
err := registry.RemoveNodeModel(context.Background(), nodeID, "llama3")
registry.SetNodeModel(context.Background(), nodeID, "llama3", 0, "loaded", "", 0)
err := registry.RemoveNodeModel(context.Background(), nodeID, "llama3", 0)
Expect(err).ToNot(HaveOccurred())
models, _ := registry.GetNodeModels(context.Background(), nodeID)
@@ -208,9 +208,9 @@ var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
It("should find LRU model on a node", func() {
// Load two models
registry.SetNodeModel(context.Background(), nodeID, "old-model", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "old-model", 0, "loaded", "", 0)
time.Sleep(10 * time.Millisecond)
registry.SetNodeModel(context.Background(), nodeID, "new-model", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "new-model", 0, "loaded", "", 0)
// Update last_used to make old-model older
db.Model(&nodes.NodeModel{}).Where("node_id = ? AND model_name = ?", nodeID, "old-model").
@@ -222,8 +222,8 @@ var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
})
It("should clean up models when deregistering node", func() {
registry.SetNodeModel(context.Background(), nodeID, "llama3", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "whisper", "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "llama3", 0, "loaded", "", 0)
registry.SetNodeModel(context.Background(), nodeID, "whisper", 0, "loaded", "", 0)
err := registry.Deregister(context.Background(), nodeID)
Expect(err).ToNot(HaveOccurred())

View File

@@ -45,7 +45,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "healthy-node", Address: "h:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-a", 0, "loaded", "", 0)).To(Succeed())
models, err := registry.ListAllLoadedModels(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -58,7 +58,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "sick-node", Address: "s:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-on-sick", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-on-sick", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.MarkUnhealthy(context.Background(), node.ID)).To(Succeed())
models, err := registry.ListAllLoadedModels(context.Background())
@@ -71,8 +71,8 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "state-node", Address: "st:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "loading-model", "loading", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "idle-model", "idle", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "loading-model", 0, "loading", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "idle-model", 0, "idle", "", 0)).To(Succeed())
models, err := registry.ListAllLoadedModels(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -88,8 +88,8 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
}
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "model-x", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-y", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node1.ID, "model-x", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "model-y", 0, "loaded", "", 0)).To(Succeed())
models, err := registry.ListAllLoadedModels(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -110,7 +110,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "find-node", Address: "f:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "findable-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "findable-model", 0, "loaded", "", 0)).To(Succeed())
found, ok := registry.FindNodeForModel(context.Background(), "findable-model")
Expect(ok).To(BeTrue())
@@ -129,7 +129,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "unhealthy-find", Address: "uf:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "unhealthy-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "unhealthy-model", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.MarkUnhealthy(context.Background(), node.ID)).To(Succeed())
found, ok := registry.FindNodeForModel(context.Background(), "unhealthy-model")
@@ -144,8 +144,8 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "stale-clear-node", Address: "sc:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "stale-model-1", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "stale-model-2", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "stale-model-1", 0, "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "stale-model-2", 0, "loaded", "", 0)).To(Succeed())
// Verify models exist
models, err := registry.GetNodeModels(context.Background(), node.ID)
@@ -187,7 +187,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
}
Expect(registry.Register(context.Background(), busy, true)).To(Succeed())
Expect(registry.Register(context.Background(), idle, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), busy.ID, "some-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), busy.ID, "some-model", 0, "loaded", "", 0)).To(Succeed())
found, err := registry.FindIdleNode(context.Background())
Expect(err).ToNot(HaveOccurred())
@@ -200,7 +200,7 @@ var _ = Describe("NodeRegistry extra methods", Label("Distributed"), func() {
Name: "loaded-node", Address: "loaded:5000",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-x", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "model-x", 0, "loaded", "", 0)).To(Succeed())
_, err := registry.FindIdleNode(context.Background())
Expect(err).To(HaveOccurred())

View File

@@ -150,7 +150,7 @@ var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() {
}
// Manually increment in-flight (simulates what InFlightTrackingClient.track() does during inference)
Expect(registry.IncrementInFlight(context.Background(), nodeID, "release-model")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), nodeID, "release-model", 0)).To(Succeed())
// Check in-flight increased
models, err = registry.GetNodeModels(context.Background(), nodeID)
@@ -164,7 +164,7 @@ var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() {
Expect(inflight).To(Equal(baseline + 1))
// Decrement and check in-flight goes back to baseline
Expect(registry.DecrementInFlight(context.Background(), nodeID, "release-model")).To(Succeed())
Expect(registry.DecrementInFlight(context.Background(), nodeID, "release-model", 0)).To(Succeed())
models, err = registry.GetNodeModels(context.Background(), nodeID)
Expect(err).ToNot(HaveOccurred())