LocalAI/tests/e2e/distributed/distributed_full_flow_test.go
Ettore Di Giacinto 6b63b47f61 feat(distributed): support multiple replicas of one model on the same node (#9583)
* feat(distributed): support multiple replicas of one model on the same node

The distributed scheduler implicitly assumed `(node_id, model_name)` was
unique, but the schema didn't enforce it and the worker keyed all gRPC
processes by model name alone. With `MinReplicas=2` against a single
worker, the reconciler "scaled up" every 30s but the registry never
advanced past 1 row — the worker re-loaded the model in-place every tick
until VRAM fragmented and the gRPC process died.

This change introduces multi-replica-per-node as a first-class concept,
with capacity-aware scheduling, a circuit breaker, and VRAM
soft-reservation. Operators can declare per-node capacity via the worker
flag `--max-replicas-per-model` (mirrored as auto-label
`node.replica-slots=N`) or override per-node from the UI.

* Schema: BackendNode gains MaxReplicasPerModel (default 1) and
  ReservedVRAM. NodeModel gains ReplicaIndex (composite with node_id +
  model_name). ModelSchedulingConfig gains UnsatisfiableUntil/Ticks for
  the reconciler circuit breaker.

* Registry: replica_index threaded through SetNodeModel, RemoveNodeModel,
  IncrementInFlight, DecrementInFlight, TouchNodeModel, GetNodeModel,
  SetNodeModelLoadInfo and the InFlightTrackingClient. New helpers:
  CountReplicasOnNode, NextFreeReplicaIndex (with ErrNoFreeSlot),
  RemoveAllNodeModelReplicas, FindNodesWithFreeSlot,
  ClusterCapacityForModel, ReserveVRAM/ReleaseVRAM (atomic UPDATE with
  ErrInsufficientVRAM), and the unsatisfiable-flag CRUD.

* Worker: processKey now `<modelID>#<replicaIndex>` so concurrent loads
  of the same model land on distinct ports. Adds CLI flag
  --max-replicas-per-model (env LOCALAI_MAX_REPLICAS_PER_MODEL, default 1)
  and emits the auto-label.

* Router: scheduleNewModel filters candidates by free slot, allocates the
  replica index, and soft-reserves VRAM before installing the backend.
  evictLRUAndFreeNode now deletes the targeted row by ID instead of all
  replicas of the model on the node — fixes a latent bug where evicting
  one replica orphaned its siblings.

* Reconciler: caps scale-up at ClusterCapacityForModel so a misconfig
  (MinReplicas > capacity) doesn't loop forever. After 3 consecutive
  ticks of capacity==0 it sets UnsatisfiableUntil for a 5m cooldown and
  emits a warning. ClearAllUnsatisfiable fires from Register,
  ApproveNode, SetNodeLabel(s), RemoveNodeLabel and
  UpdateMaxReplicasPerModel so a new node joining or label changes wake
  the reconciler immediately. scaleDownIdle removes highest-replica-index
  first to keep slots compact.

* Heartbeat resets reserved_vram to 0 — worker is the source of truth
  for actual free VRAM; the reservation is only for the in-tick race
  window between two scheduling decisions.

* Probe path (reconciler.probeLoadedModels and health.doCheckAll) now
  pass the row's replica_index to RemoveNodeModel so an unreachable
  replica doesn't orphan healthy siblings.

* Admin override: PUT /api/nodes/:id/max-replicas-per-model sets a
  sticky override (preserved across worker re-registration). DELETE
  clears the override so the worker's flag applies again on next
  register. Required because Kong defaults the worker flag to 1, so
  every worker restart would have silently reverted the UI value.

* React UI: always-visible slot badge on the node row (muted at default
  1, accented when >1); inline editor in the expanded drawer with
  pencil-to-edit, Save/Cancel, Esc/Enter, "(override)" indicator when
  the value is admin-set, and a "Reset" button to hand control back to
  the worker. Soft confirm when shrinking the cap below the count of
  loaded replicas. Scheduling rules table gets an "Unsatisfiable until
  HH:MM" status badge surfacing the cooldown.

* node.replica-slots filtered out of the labels strip on the row to
  avoid duplicating the slot badge.
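The Router, Reconciler and Probe path bullets above all depend on addressing one replica row instead of every row for `(node_id, model_name)`, and on scale-down removing the highest replica index first. A small sketch of both rules, where `replicaRow`, `removeReplica` and `scaleDownOrder` are hypothetical stand-ins for the registry's `NodeModel` rows, `RemoveNodeModel` and the `scaleDownIdle` ordering:

```go
package main

import (
	"fmt"
	"sort"
)

// replicaRow is a hypothetical stand-in for one NodeModel row.
type replicaRow struct {
	NodeID       string
	ModelName    string
	ReplicaIndex int
}

// removeReplica deletes exactly one (node, model, replica) row and keeps
// every sibling replica of the same model on the same node.
func removeReplica(rows []replicaRow, nodeID, model string, idx int) []replicaRow {
	out := make([]replicaRow, 0, len(rows))
	for _, r := range rows {
		if r.NodeID == nodeID && r.ModelName == model && r.ReplicaIndex == idx {
			continue
		}
		out = append(out, r)
	}
	return out
}

// scaleDownOrder returns a copy sorted highest replica index first, so
// removing from the front leaves the surviving slots packed at the bottom.
func scaleDownOrder(rows []replicaRow) []replicaRow {
	sorted := append([]replicaRow(nil), rows...)
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].ReplicaIndex > sorted[j].ReplicaIndex
	})
	return sorted
}

func main() {
	rows := []replicaRow{{"n1", "m", 0}, {"n1", "m", 1}, {"n1", "m", 2}}

	// Drop replica 1 only, e.g. after a failed health probe.
	left := removeReplica(rows, "n1", "m", 1)
	fmt.Println(len(left)) // 2

	victim := scaleDownOrder(left)[0]
	fmt.Println(victim.ReplicaIndex) // 2
}
```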
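The Admin override bullet above (PUT sets a sticky value, DELETE hands control back on the next register) reduces to a small state machine. This is a sketch under assumptions: the `Overridden` flag, `nodeCapacity` and these method names are hypothetical, and how LocalAI actually persists the override is not described in this commit message.

```go
package main

import "fmt"

// nodeCapacity is a hypothetical view of the capacity fields on one node.
type nodeCapacity struct {
	MaxReplicasPerModel int  // value the scheduler reads
	Overridden          bool // set by an admin PUT, cleared by DELETE
}

// Register models worker (re-)registration: the worker's flag is applied
// only when no admin override is in place, so restarts cannot clobber it.
func (c *nodeCapacity) Register(workerFlag int) {
	if !c.Overridden {
		c.MaxReplicasPerModel = workerFlag
	}
}

// SetOverride models PUT /api/nodes/:id/max-replicas-per-model.
func (c *nodeCapacity) SetOverride(v int) {
	c.MaxReplicasPerModel = v
	c.Overridden = true
}

// ResetOverride models the DELETE: the current value stays until the
// worker registers again and its flag takes over.
func (c *nodeCapacity) ResetOverride() {
	c.Overridden = false
}

func main() {
	c := &nodeCapacity{}
	c.Register(1)
	c.SetOverride(4)

	// Worker restart with the default flag value of 1.
	c.Register(1)
	fmt.Println(c.MaxReplicasPerModel) // 4

	c.ResetOverride()
	c.Register(1)
	fmt.Println(c.MaxReplicasPerModel) // 1
}
```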

23 new Ginkgo specs (registry, reconciler, inflight, health) cover:
multi-replica row independence, RemoveNodeModel of one replica
preserving siblings, NextFreeReplicaIndex slot allocation including
ErrNoFreeSlot, capacity-gated scale-up with circuit breaker tripping
and recovery on Register, scaleDownIdle ordering, ClusterCapacity
math, ReserveVRAM admission gating, Heartbeat reset, override survival
across worker re-registration, and ResetMaxReplicasPerModel handing
control back. Plus 8 stdlib tests for the worker processKey / CLI /
auto-label.
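The worker keys each gRPC process by `<modelID>#<replicaIndex>` and emits the `node.replica-slots=N` auto-label. A sketch of building, parsing and labeling, where `processKey`, `parseProcessKey` and `replicaSlotsLabel` are hypothetical names and splitting on the last `#` is this sketch's own choice:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// processKey builds a map key that tracks one gRPC process per
// (model, replica) pair, so two replicas of the same model never share
// an entry and therefore never share a port.
func processKey(modelID string, replicaIndex int) string {
	return fmt.Sprintf("%s#%d", modelID, replicaIndex)
}

// parseProcessKey splits a key back into model ID and replica index.
// It splits on the last '#', so model IDs containing '#' still round-trip.
func parseProcessKey(key string) (string, int, error) {
	i := strings.LastIndex(key, "#")
	if i < 0 {
		return "", 0, fmt.Errorf("process key %q has no replica suffix", key)
	}
	idx, err := strconv.Atoi(key[i+1:])
	if err != nil {
		return "", 0, fmt.Errorf("process key %q: %w", key, err)
	}
	return key[:i], idx, nil
}

// replicaSlotsLabel renders the auto-label mirrored from --max-replicas-per-model.
func replicaSlotsLabel(maxReplicas int) (string, string) {
	return "node.replica-slots", strconv.Itoa(maxReplicas)
}

func main() {
	k := processKey("qwen3", 1)
	m, i, _ := parseProcessKey(k)
	fmt.Println(k, m, i) // qwen3#1 qwen3 1

	key, val := replicaSlotsLabel(2)
	fmt.Printf("%s=%s\n", key, val) // node.replica-slots=2
}
```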

Closes the flap reproduced on Qwen3.6-35B against the nvidia-thor
worker (single 128 GiB node, MinReplicas=2): the reconciler now caps
the scale-up at the cluster's actual capacity instead of looping.
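The fix for this flap rests on simple slot arithmetic. A sketch, assuming capacity is counted as unused replica slots summed across nodes and ignoring VRAM; `nodeSlots` and `freeCapacity` are hypothetical and may not match how `ClusterCapacityForModel` is computed.

```go
package main

import "fmt"

// nodeSlots is a hypothetical per-node summary for one model.
type nodeSlots struct {
	MaxReplicasPerModel int
	Loaded              int // replicas of the model already on the node
}

// freeCapacity sums the unused replica slots across nodes; a node that is
// over its cap (e.g. after the cap was lowered) contributes nothing.
func freeCapacity(nodes []nodeSlots) int {
	total := 0
	for _, n := range nodes {
		if free := n.MaxReplicasPerModel - n.Loaded; free > 0 {
			total += free
		}
	}
	return total
}

func main() {
	// One worker at the default cap of 1 with the model already loaded,
	// while MinReplicas=2: there is nowhere to put the second replica.
	fmt.Println(freeCapacity([]nodeSlots{{MaxReplicasPerModel: 1, Loaded: 1}})) // 0

	// Raising the cap to 2 opens exactly one slot.
	fmt.Println(freeCapacity([]nodeSlots{{MaxReplicasPerModel: 2, Loaded: 1}})) // 1
}
```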

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Read] [Edit] [Bash] [Skill:critique] [Skill:audit] [Skill:polish] [Skill:golang-testing]

* refactor(react-ui/nodes): tighten capacity editor copy + adopt ActionMenu for row actions

* Capacity editor hint trimmed from operator-doc-style ("Sourced from
  the worker's `--max-replicas-per-model` flag. Changing it here makes it
  a sticky admin override that survives worker restarts." → "Saved
  values stick across worker restarts.") and the override-state copy
  similarly compressed. The full mechanic is no longer needed in the UI
  — the override pill carries the meaning and the docs cover the rest.

* Node row actions migrated from an inline cluster of icon buttons
  (Drain / Resume / Trash) to the kebab ActionMenu used by /manage for
  per-row model actions, so dense Nodes tables stay clean. Approve
  stays as a prominent primary button — it's a stateful admission gate,
  not a routine action, and elevating it matches how /manage surfaces
  install-time decisions outside the menu.

* The expanded drawer's Labels section now filters node.replica-slots
  out of the editable label list. The label is owned by the Capacity
  editor above; surfacing it again as an editable label invited
  confusion (the Capacity save would clobber any direct edit).

Both backend and agent workers benefit — they share the row rendering
path, so the action menu and label filter apply to both.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:critique] [Skill:audit] [Skill:polish]

* fix(react-ui/nodes): suppress slot badge on agent workers

Agent workers don't load models, so the per-node replica capacity is
inapplicable to them. Showing "1× slots" on agent rows was a tiny
inconsistency from the unified rendering path — gate the badge on
node_type !== 'agent' so it only appears on backend workers.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp]

* refactor(react-ui/nodes): distill expanded drawer + restyle scheduling form

The expanded node drawer used to stack five panels — slot badge,
filled capacity box, Loaded Models h4+empty-state, Installed Backends
h4+empty-state, Labels h4+chips+form — making routine inspections feel
like a control panel. The scheduling rule form wrapped its mode toggle
as two 50%-width filled buttons that competed visually with the actual
primary action.

* Drawer: collapse three rarely-touched config zones (Capacity,
  Backends, Labels) into one `<details>` "Manage" disclosure (closed by
  default) with small uppercase eyebrow labels for each zone instead of
  parallel h4 sub-headings. Loaded Models stays as the at-a-glance
  headline with a single-line empty hint instead of a boxed empty state.
  CapacityEditor renders flat (no filled background) — the Manage
  disclosure provides framing.

* Scheduling form: replace the chunky 50%-width button-tabs with the
  project's existing `.segmented` control (icon + label, sized to
  content). Mode hint becomes a single tied line below. Fields stack
  vertically with helper text under inputs and a hairline divider above
  the right-aligned Save / Cancel.

The empty drawer collapses from ~5 stacked sections (~280px tall) to
two lines (~80px). The scheduling form now reads as a designed dialog
instead of raw building blocks. Both surfaces now match the typographic
density and weight of the rest of the admin pages.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [chrome-devtools-mcp] [Skill:distill] [Skill:audit] [Skill:polish]

* feat(react-ui/nodes): replace scheduling form's model picker with searchable combobox

The native <select> made operators scroll through every gallery entry to
find a model name. The project already has SearchableModelSelect (used
in Studio/Talk/etc.) which combines free-text search with the gallery
list and accepts typed model names that aren't installed yet — useful
for pre-staging a scheduling rule before the node it'll run on has
finished bootstrapping.

Also drops the now-unused useModels import (the combobox manages the
gallery hook internally).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit]

* refactor(react-ui/nodes): consolidate key/value chip editor + add replica preset chips

The Nodes page was rendering the same key=value chip pattern in two
places with subtly different markup: the Labels editor in the expanded
drawer and (post-distill) the Node Selector input in the scheduling
form. The form's input was also a comma-separated string that operators
were getting wrong.

* Extract <KeyValueChips> as a fully controlled chip-builder. Parent
  owns the map and decides what onAdd/onRemove does — form state for the
  scheduling form, API calls for the live drawer Labels editor. Same
  visuals everywhere; one component to change when polish needs apply.

* Replace the comma-separated Node Selector text input with KeyValueChips.
  Operators were copying syntax from docs and missing commas; the chip
  vocabulary makes the key=value structure self-documenting.

* Add <ReplicaInput>: numeric input + quick-pick preset chips for Min/Max
  replicas. Picked over a slider because replica counts are exact specs
  derived from VRAM math (operator decision, not a fuzzy estimate). The
  chips give one-click access to common values (1/2/3/4 for Min,
  0=no-limit/2/4/8 for Max) without the slider's special-value problem
  (MaxReplicas=0 is categorical, not a position on a continuum).

* Drop the now-unused labelInputs state in the Nodes page (the inline
  label editor's per-node draft state lived there and is now owned by
  KeyValueChips).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Skill:distill]

* test: fix CI fallout from multi-replica refactor (e2e/distributed + playwright)

Two breakages caught by CI that didn't surface in the local run:

* tests/e2e/distributed/*.go — multiple files used the pre-PR2 registry
  signatures for SetNodeModel / IncrementInFlight / DecrementInFlight /
  RemoveNodeModel / TouchNodeModel / GetNodeModel / SetNodeModelLoadInfo
  and one stale adapter.InstallBackend call in node_lifecycle_test.go.
  All updated to pass replicaIndex=0. These tests don't exercise
  multi-replica behavior; they only need to compile against the new
  signatures. The new registry/reconciler specs in core/services/nodes/
  already cover the multi-replica logic.

* core/http/react-ui/e2e/nodes-per-node-backend-actions.spec.js — the
  drawer's distill refactor moved Backends inside a "Manage" <details>
  disclosure that's collapsed by default. The test helper expanded the
  node row but never opened Manage, so the per-node backend table was
  never in the DOM. Helper now clicks `.node-manage > summary` after
  expanding the row.

All 100 Playwright tests pass locally; tests/e2e/distributed compiles
cleanly.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Assisted-by: claude-code:opus-4-7 [Edit] [Bash]

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-04-27 21:20:05 +02:00


package distributed_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/core/services/nodes"
	grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/nats-io/nats.go"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"google.golang.org/grpc"
	pgdriver "gorm.io/driver/postgres"
	gormDB "gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// testLLM is a minimal AIModel implementation for testing.
// Override methods to write output to Dst so we can test the full
// FileStagingClient round-trip (upload inputs + download outputs).
type testLLM struct {
	base.Base
	loaded    bool
	lastModel string
	// dstOutput is the content written to any Dst path by output-producing methods.
	dstOutput []byte
	// lastSrc records the last Src/input path seen (for verifying staging rewrote it).
	lastSrc string
	// lastAudioDst records the Dst field from AudioTranscription (it's an input, not output).
	lastAudioDst string
	// lastTTSModel records the Model field from TTS requests (for verifying path rewriting).
	lastTTSModel string
}

func (t *testLLM) Load(opts *pb.ModelOptions) error {
	t.loaded = true
	t.lastModel = opts.ModelFile
	return nil
}

func (t *testLLM) Predict(opts *pb.PredictOptions) (string, error) {
	if !t.loaded {
		return "", fmt.Errorf("model not loaded")
	}
	return "test response from remote node", nil
}

func (t *testLLM) GenerateImage(req *pb.GenerateImageRequest) error {
	t.lastSrc = req.Src
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) GenerateVideo(req *pb.GenerateVideoRequest) error {
	t.lastSrc = req.StartImage
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) TTS(req *pb.TTSRequest) error {
	t.lastTTSModel = req.Model
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
	if req.Src != nil {
		t.lastSrc = *req.Src
	}
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
	t.lastAudioDst = req.Dst
	return pb.TranscriptResult{Text: "transcribed text"}, nil
}

// startTestGRPCServer starts a real gRPC backend server on a free port
// and returns the address and cleanup function.
func startTestGRPCServer(llm grpcPkg.AIModel) (string, func(), error) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", nil, err
	}
	addr := lis.Addr().String()
	s := grpc.NewServer(
		grpc.MaxRecvMsgSize(50*1024*1024),
		grpc.MaxSendMsgSize(50*1024*1024),
	)
	pb.RegisterBackendServer(s, grpcPkg.NewBackendServer(llm))
	go func() {
		defer GinkgoRecover()
		_ = s.Serve(lis)
	}()
	cleanup := func() {
		s.GracefulStop()
	}
	return addr, cleanup, nil
}

// startTestHTTPFileServer starts a test HTTP file-transfer server (mirroring serve_backend_http.go)
// on a free port and returns the address and cleanup function.
func startTestHTTPFileServer(stagingDir string) (string, func(), error) {
	if err := os.MkdirAll(stagingDir, 0750); err != nil {
		return "", nil, err
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
		key := r.URL.Path[len("/v1/files/"):]
		switch r.Method {
		case http.MethodPut:
			// Upload: keep only the base name so a crafted key cannot escape stagingDir.
			safeName := filepath.Base(key)
			dstPath := filepath.Join(stagingDir, safeName)
			f, err := os.Create(dstPath)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			_, copyErr := io.Copy(f, r.Body)
			closeErr := f.Close()
			if copyErr != nil || closeErr != nil {
				http.Error(w, "failed to write staged file", http.StatusInternalServerError)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprintf(w, `{"local_path":%q}`, dstPath)
		case http.MethodGet:
			safeName := filepath.Base(key)
			srcPath := filepath.Join(stagingDir, safeName)
			if _, statErr := os.Stat(srcPath); os.IsNotExist(statErr) {
				// AllocRemoteTemp creates files under stagingDir/tmp/
				srcPath = filepath.Join(stagingDir, "tmp", safeName)
			}
			f, err := os.Open(srcPath)
			if err != nil {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			defer f.Close()
			w.Header().Set("Content-Type", "application/octet-stream")
			io.Copy(w, f)
		case http.MethodPost:
			// POST /v1/files/temp allocates an empty remote temp file for outputs.
			if key == "temp" {
				tmpDir := filepath.Join(stagingDir, "tmp")
				if err := os.MkdirAll(tmpDir, 0750); err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				}
				f, err := os.CreateTemp(tmpDir, "output-*")
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				}
				localPath := f.Name()
				f.Close()
				w.Header().Set("Content-Type", "application/json")
				fmt.Fprintf(w, `{"local_path":%q}`, localPath)
			} else {
				http.Error(w, "not found", http.StatusNotFound)
			}
		default:
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		}
	})
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", nil, err
	}
	httpAddr := lis.Addr().String()
	srv := &http.Server{Handler: mux}
	go srv.Serve(lis)
	cleanup := func() {
		srv.Close()
	}
	return httpAddr, cleanup, nil
}

var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func() {
	var (
		infra    *TestInfra
		cancel   context.CancelFunc
		ctx      context.Context
		db       *gormDB.DB
		registry *nodes.NodeRegistry
	)

	BeforeEach(func() {
		infra = SetupInfra("localai_fullflow_test")
		ctx, cancel = context.WithTimeout(infra.Ctx, 2*time.Minute)
		var err error
		db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
			Logger: logger.Default.LogMode(logger.Silent),
		})
		Expect(err).ToNot(HaveOccurred())
		registry, err = nodes.NewNodeRegistry(db)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		cancel()
	})

	// newTestSmartRouter creates a SmartRouter with NATS wired up and a mock
	// backend.install handler that always replies success for all registered nodes.
	newTestSmartRouter := func(reg *nodes.NodeRegistry, extraOpts ...nodes.SmartRouterOptions) *nodes.SmartRouter {
		unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC)
		opts := nodes.SmartRouterOptions{
			Unloader: unloader,
		}
		if len(extraOpts) > 0 {
			o := extraOpts[0]
			if o.FileStager != nil {
				opts.FileStager = o.FileStager
			}
			if o.GalleriesJSON != "" {
				opts.GalleriesJSON = o.GalleriesJSON
			}
			if o.AuthToken != "" {
				opts.AuthToken = o.AuthToken
			}
			if o.DB != nil {
				opts.DB = o.DB
			}
		}
		router := nodes.NewSmartRouter(reg, opts)
		// Reply success to backend.install for every node. The "*" wildcard matches
		// any single subject token (the node ID), so one subscription covers nodes
		// registered before or after this call.
		_, err := infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
			reply := messaging.BackendInstallReply{Success: true}
			data, _ := json.Marshal(reply)
			msg.Respond(data)
		})
		Expect(err).ToNot(HaveOccurred())
		return router
	}
	// suppress unused warning in case some tests don't call it
	_ = newTestSmartRouter

	It("should route inference to a registered node with a real gRPC backend", func() {
		// 1. Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// 2. Register it as a node
		node := &nodes.BackendNode{
			Name:    "test-gpu-1",
			Address: addr,
		}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// 3. Create SmartRouter and route a request
		router := newTestSmartRouter(registry)
		// The model is not loaded yet, so Route will pick the node and call LoadModel
		result, err := router.Route(ctx, "", "test-model", "llama-cpp", &pb.ModelOptions{
			Model: "test-model",
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result).ToNot(BeNil())
		Expect(result.Node.Name).To(Equal("test-gpu-1"))

		// 4. Verify the model was loaded on the backend
		Expect(llm.loaded).To(BeTrue())

		// 5. Use the client to call Predict
		reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "Hello world",
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))

		// 6. Release and verify in-flight decremented
		result.Release()
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(HaveLen(1))
		Expect(models[0].InFlight).To(Equal(0))

		// 7. Verify model recorded as "loaded" in registry
		Expect(models[0].State).To(Equal("loaded"))
		Expect(models[0].ModelName).To(Equal("test-model"))
	})

	It("should load-balance across multiple nodes with same model", func() {
		// Start two mock gRPC backends
		llm1 := &testLLM{}
		addr1, cleanup1, err := startTestGRPCServer(llm1)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup1()

		llm2 := &testLLM{}
		addr2, cleanup2, err := startTestGRPCServer(llm2)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup2()

		// Register both nodes
		node1 := &nodes.BackendNode{Name: "node-heavy", Address: addr1}
		node2 := &nodes.BackendNode{Name: "node-light", Address: addr2}
		Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
		Expect(registry.Register(context.Background(), node2, true)).To(Succeed())

		// Mark the model as loaded on both nodes
		Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())

		// Give node-heavy a high in-flight count (5) and node-light a low one (1)
		for range 5 {
			Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model", 0)).To(Succeed())
		}
		Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model", 0)).To(Succeed())

		// Route should pick node-light (least loaded) thanks to ORDER BY in_flight ASC
		router := newTestSmartRouter(registry)
		result, err := router.Route(ctx, "", "test-model", "llama-cpp", nil, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Node.Name).To(Equal("node-light"))
		result.Release()
	})

	It("should load model on empty node when no node has it", func() {
		// Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// Register a node with NO models loaded
		node := &nodes.BackendNode{Name: "empty-node", Address: addr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Route should pick this node and call LoadModel on it
		router := newTestSmartRouter(registry)
		result, err := router.Route(ctx, "", "new-model", "llama-cpp", &pb.ModelOptions{
			Model: "new-model",
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Node.Name).To(Equal("empty-node"))
		Expect(llm.loaded).To(BeTrue())

		// Verify model is now recorded in registry
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(HaveLen(1))
		Expect(models[0].ModelName).To(Equal("new-model"))
		Expect(models[0].State).To(Equal("loaded"))
		result.Release()
	})

	It("should unload remote model via NATS", func() {
		// Register a node with a loaded model
		node := &nodes.BackendNode{Name: "gpu-unload", Address: "127.0.0.1:50099"}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", 0, "loaded", "", 0)).To(Succeed())

		// Subscribe to NATS backend.stop for this node
		stopSubject := messaging.SubjectNodeBackendStop(node.ID)
		received := make(chan struct{}, 1)
		rawConn, err := nats.Connect(infra.NatsURL)
		Expect(err).ToNot(HaveOccurred())
		defer rawConn.Close()
		_, err = rawConn.Subscribe(stopSubject, func(msg *nats.Msg) {
			select {
			case received <- struct{}{}:
			default: // already signalled; never block the NATS callback
			}
		})
		Expect(err).ToNot(HaveOccurred())
		// Flush so the server has registered the subscription before we publish.
		Expect(rawConn.Flush()).To(Succeed())

		// Create RemoteUnloaderAdapter and unload model
		unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
		err = unloader.UnloadRemoteModel("old-model")
		Expect(err).ToNot(HaveOccurred())

		// Verify NATS event received
		Eventually(received, 5*time.Second).Should(Receive())

		// Verify model removed from registry
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(BeEmpty())
	})

	It("should integrate ModelRouterAdapter with SmartRouter end-to-end", func() {
		// Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// Register node
		node := &nodes.BackendNode{Name: "adapter-node", Address: addr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create SmartRouter + ModelRouterAdapter
		router := newTestSmartRouter(registry)
		adapter := nodes.NewModelRouterAdapter(router)

		// Call adapter.Route() (same signature ModelLoader uses)
		m, err := adapter.Route(ctx, "llama-cpp", "test-model-id", "test-model", "",
			&pb.ModelOptions{Model: "test-model"}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(m).ToNot(BeNil())

		// Verify returned Model has correct ID and nil process (remote)
		Expect(m.ID).To(Equal("test-model-id"))
		Expect(m.Process()).To(BeNil())

		// Verify the model was loaded on the backend
		Expect(llm.loaded).To(BeTrue())

		// Use the Model's GRPC() method to get a client and verify inference works
		client := m.GRPC(false, nil)
		Expect(client).ToNot(BeNil())
		reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "test"})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))

		// Release the model via adapter
		adapter.ReleaseModel("test-model-id")
	})

	It("should stage model files via HTTP when routing to a new node", func() {
		// Create a real model file on disk
		modelDir := GinkgoT().TempDir()
		modelContent := []byte("fake GGUF model data — this is test content for file transfer verification")
		modelPath := filepath.Join(modelDir, "model.gguf")
		Expect(os.WriteFile(modelPath, modelContent, 0644)).To(Succeed())
		mmprojContent := []byte("fake mmproj data for multimodal projection")
		mmprojPath := filepath.Join(modelDir, "mmproj.bin")
		Expect(os.WriteFile(mmprojPath, mmprojContent, 0644)).To(Succeed())

		// Start a real gRPC backend server (for AI RPCs) and HTTP server (for file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()
		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register the node in PostgreSQL
		node := &nodes.BackendNode{Name: "staging-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager that resolves node IDs to HTTP addresses
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create SmartRouter with the HTTPFileStager
		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelOptions that have file paths; SmartRouter should stage them
		result, err := router.Route(ctx, "", "staged-model", "llama-cpp", &pb.ModelOptions{
			Model:     "staged-model",
			ModelFile: modelPath,
			MMProj:    mmprojPath,
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result).ToNot(BeNil())
		Expect(result.Node.Name).To(Equal("staging-node"))

		// Verify the model file bytes were transferred to the backend's staging dir
		stagedModelPath := filepath.Join(stagingDir, "model.gguf")
		stagedModelData, err := os.ReadFile(stagedModelPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedModelData).To(Equal(modelContent))
		stagedMMProjPath := filepath.Join(stagingDir, "mmproj.bin")
		stagedMMProjData, err := os.ReadFile(stagedMMProjPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedMMProjData).To(Equal(mmprojContent))

		// Verify LoadModel was called with the rewritten (remote) paths
		Expect(llm.loaded).To(BeTrue())
		Expect(llm.lastModel).To(Equal(stagedModelPath))

		// Verify Predict still works through the FileStagingClient wrapper
		reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "test via staging client",
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))
		result.Release()
	})

	It("should stage multimodal input files via HTTP through FileStagingClient", func() {
		// Create a real image file on disk
		imageDir := GinkgoT().TempDir()
		imageContent := []byte("fake JPEG image data for multimodal testing")
		imagePath := filepath.Join(imageDir, "photo.jpg")
		Expect(os.WriteFile(imagePath, imageContent, 0644)).To(Succeed())

		// Start gRPC server (AI RPCs) and HTTP server (file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()
		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register node
		node := &nodes.BackendNode{Name: "mm-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create SmartRouter with FileStager
		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelOptions; this triggers LoadModel on the node
		modelDir := GinkgoT().TempDir()
		modelPath := filepath.Join(modelDir, "vision.gguf")
		Expect(os.WriteFile(modelPath, []byte("vision model data"), 0644)).To(Succeed())
		result, err := router.Route(ctx, "", "vision-model", "llama-cpp", &pb.ModelOptions{
			Model:     "vision-model",
			ModelFile: modelPath,
		}, false)
		Expect(err).ToNot(HaveOccurred())

		// Verify LoadModel was called (model file was staged)
		Expect(llm.loaded).To(BeTrue())

		// Now call Predict with image file paths; FileStagingClient should stage them
		_, err = result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "describe this image",
			Images: []string{imagePath},
		})
		Expect(err).ToNot(HaveOccurred())

		// Verify the image file was actually transferred to the backend staging dir
		stagedImagePath := filepath.Join(stagingDir, "photo.jpg")
		stagedImageData, err := os.ReadFile(stagedImagePath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedImageData).To(Equal(imageContent))
		result.Release()
	})

	It("should transfer output files back via HTTP", func() {
		// Start gRPC server (AI RPCs) and HTTP server (file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()
		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register node
		node := &nodes.BackendNode{Name: "output-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Test AllocRemoteTemp + FetchRemote directly (the output retrieval path)
		remoteTmpPath, err := stager.AllocRemoteTemp(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(remoteTmpPath).ToNot(BeEmpty())

		// Simulate backend writing output to the temp path
		outputContent := []byte("generated image output data from the backend")
		Expect(os.WriteFile(remoteTmpPath, outputContent, 0644)).To(Succeed())

		// FetchRemote pulls the file from the backend to a local path
		localOutputDir := GinkgoT().TempDir()
		localOutputPath := filepath.Join(localOutputDir, "output.png")
		err = stager.FetchRemote(ctx, node.ID, remoteTmpPath, localOutputPath)
		Expect(err).ToNot(HaveOccurred())

		// Verify the output file was retrieved with correct content
		retrievedData, err := os.ReadFile(localOutputPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputContent))
	})
// --- Full round-trip tests for every FileStagingClient src/dst path ---
// Helper: creates an HTTPFileStager + SmartRouter, registers a node, and
// routes to it. Returns the RouteResult (whose Client is a
// FileStagingClient), the node's staging dir, and a cleanup func that
// stops both test servers.
setupStagedRoute := func(llm *testLLM, backendType, modelName string) (
*nodes.RouteResult, string, func(),
) {
stagingDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
if err != nil {
cleanupGRPC() // don't leak the gRPC server if the HTTP server fails to start
}
Expect(err).ToNot(HaveOccurred())
cleanup := func() {
cleanupGRPC()
cleanupHTTP()
}
// Callers only defer cleanup() after this helper returns, so stop both
// servers here if any assertion below fails (Gomega failures panic under
// Ginkgo, and deferred calls still run while the panic unwinds).
routed := false
defer func() {
if !routed {
cleanup()
}
}()
node := &nodes.BackendNode{Name: modelName + "-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
result, err := router.Route(ctx, "", modelName, backendType, &pb.ModelOptions{
Model: modelName,
}, false)
Expect(err).ToNot(HaveOccurred())
routed = true
return result, stagingDir, cleanup
}
It("should round-trip output via FileStagingClient.GenerateImage (Src + Dst)", func() {
outputData := []byte("PNG image generated by the backend - 1024x1024 pixels")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "sd-model")
defer cleanup()
defer result.Release()
// Create a source image to test input staging (img2img)
srcDir := GinkgoT().TempDir()
srcContent := []byte("source image for img2img")
srcPath := filepath.Join(srcDir, "src.png")
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "generated.png")
genResult, err := result.Client.GenerateImage(ctx, &pb.GenerateImageRequest{
PositivePrompt: "a cat",
Src: srcPath,
Dst: frontendDst,
Height: 1024,
Width: 1024,
})
Expect(err).ToNot(HaveOccurred())
Expect(genResult.Success).To(BeTrue())
// Verify input: Src was staged to the backend, so testLLM.lastSrc should point into the staging dir
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
// Verify the staged input file has correct content
stagedSrcData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedSrcData).To(Equal(srcContent))
// Verify output: the generated file was pulled back to the frontend
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip output via FileStagingClient.GenerateVideo (StartImage + Dst)", func() {
outputData := []byte("MP4 video generated by the backend")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "vid-model")
defer cleanup()
defer result.Release()
// Create a start image to test input staging
imgDir := GinkgoT().TempDir()
startImageContent := []byte("start frame image data")
startImagePath := filepath.Join(imgDir, "start.png")
Expect(os.WriteFile(startImagePath, startImageContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "generated.mp4")
genResult, err := result.Client.GenerateVideo(ctx, &pb.GenerateVideoRequest{
Prompt: "a flying cat",
StartImage: startImagePath,
Dst: frontendDst,
NumFrames: 16,
})
Expect(err).ToNot(HaveOccurred())
Expect(genResult.Success).To(BeTrue())
// Verify input: StartImage was staged (testLLM records it in lastSrc)
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
stagedStartData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedStartData).To(Equal(startImageContent))
// Verify output: video was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip output via FileStagingClient.TTS (Dst only)", func() {
outputData := []byte("WAV audio generated by TTS backend")
llm := &testLLM{dstOutput: outputData}
result, _, cleanup := setupStagedRoute(llm, "piper", "tts-model")
defer cleanup()
defer result.Release()
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "speech.wav")
ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
Text: "Hello world",
Model: "tts-model",
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(ttsResult.Success).To(BeTrue())
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip via FileStagingClient.SoundGeneration (Src + Dst)", func() {
outputData := []byte("generated sound effect audio data")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "bark", "soundgen-model")
defer cleanup()
defer result.Release()
// Create input audio source
srcDir := GinkgoT().TempDir()
srcContent := []byte("input audio for sound generation")
srcPath := filepath.Join(srcDir, "input.wav")
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "output.wav")
sgResult, err := result.Client.SoundGeneration(ctx, &pb.SoundGenerationRequest{
Text: "explosion sound",
Src: &srcPath,
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(sgResult.Success).To(BeTrue())
// Verify input: Src was staged
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
stagedSrcData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedSrcData).To(Equal(srcContent))
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should stage input audio via FileStagingClient.AudioTranscription (Dst is input)", func() {
llm := &testLLM{}
result, stagingDir, cleanup := setupStagedRoute(llm, "whisper", "whisper-model")
defer cleanup()
defer result.Release()
// Create input audio file
audioDir := GinkgoT().TempDir()
audioContent := []byte("WAV audio data for transcription")
audioPath := filepath.Join(audioDir, "recording.wav")
Expect(os.WriteFile(audioPath, audioContent, 0644)).To(Succeed())
// TranscriptRequest reuses Dst as the *input* audio path, despite the name
txResult, err := result.Client.AudioTranscription(ctx, &pb.TranscriptRequest{
Dst: audioPath,
})
Expect(err).ToNot(HaveOccurred())
Expect(txResult.Text).To(Equal("transcribed text"))
// Verify input: audio file was staged to the backend
Expect(llm.lastAudioDst).To(ContainSubstring(stagingDir))
stagedAudioData, err := os.ReadFile(llm.lastAudioDst)
Expect(err).ToNot(HaveOccurred())
Expect(stagedAudioData).To(Equal(audioContent))
})
It("should translate TTS Model path to remote worker path", func() {
outputData := []byte("WAV audio generated by TTS backend")
llm := &testLLM{dstOutput: outputData}
// Set up a real file transfer server so model staging preserves the directory structure
modelsDir := GinkgoT().TempDir()
stagingDir := GinkgoT().TempDir()
dataDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
httpLis, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
httpAddr := httpLis.Addr().String()
httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
Expect(err).ToNot(HaveOccurred())
defer nodes.ShutdownFileTransferServer(httpServer)
node := &nodes.BackendNode{Name: "tts-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create model files on the "frontend"
frontendModelsDir := GinkgoT().TempDir()
modelContent := []byte("fake onnx model data")
configContent := []byte(`{"audio":{"sample_rate":22050}}`)
modelFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")
configFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx.json")
Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelFile pointing to the .onnx file (triggers model staging)
result, err := router.Route(ctx, "voice-it-paola-medium", "it-paola-medium.onnx", "piper", &pb.ModelOptions{
Model: "it-paola-medium.onnx",
ModelFile: modelFile,
}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "speech.wav")
// Simulate what core/backend/tts.go does: build the Model path from the frontend's ModelPath
frontendModelPath := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")
ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
Text: "Hello world",
Model: frontendModelPath, // frontend absolute path; the client should translate it to the remote worker path
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(ttsResult.Success).To(BeTrue())
// Verify: the backend received the remote worker path, NOT the frontend path
Expect(llm.lastTTSModel).ToNot(Equal(frontendModelPath))
// The translated path should contain the tracking key (the worker stages models
// under <modelsDir>/<trackingKey>/) and keep the model filename
Expect(llm.lastTTSModel).To(ContainSubstring("voice-it-paola-medium"))
Expect(llm.lastTTSModel).To(HaveSuffix("it-paola-medium.onnx"))
// Verify the model file exists at the translated path (already staged during LoadModel)
stagedModelData, err := os.ReadFile(llm.lastTTSModel)
Expect(err).ToNot(HaveOccurred())
Expect(stagedModelData).To(Equal(modelContent))
// Verify the companion .onnx.json is next to it (staged during LoadModel)
companionPath := llm.lastTTSModel + ".json"
stagedConfigData, err := os.ReadFile(companionPath)
Expect(err).ToNot(HaveOccurred())
Expect(stagedConfigData).To(Equal(configContent))
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should stage companion .onnx.json files alongside .onnx model files", func() {
llm := &testLLM{}
modelsDir := GinkgoT().TempDir()
stagingDir := GinkgoT().TempDir()
dataDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
// Use the real file transfer server (preserves directory structure)
httpLis, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
httpAddr := httpLis.Addr().String()
httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
Expect(err).ToNot(HaveOccurred())
defer nodes.ShutdownFileTransferServer(httpServer)
node := &nodes.BackendNode{Name: "companion-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create model files: .onnx and .onnx.json in a temp "models" dir
frontendModelsDir := GinkgoT().TempDir()
modelContent := []byte("fake onnx model")
configContent := []byte(`{"audio":{"sample_rate":22050}}`)
modelFile := filepath.Join(frontendModelsDir, "my-model.onnx")
configFile := filepath.Join(frontendModelsDir, "my-model.onnx.json")
Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelFile pointing to the .onnx file
result, err := router.Route(ctx, "piper-companion-test", "my-model.onnx", "piper", &pb.ModelOptions{
Model: "my-model.onnx",
ModelFile: modelFile,
}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
// Verify: both .onnx and .onnx.json were staged to the worker's models dir
stagedOnnx := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx")
stagedConfig := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx.json")
stagedOnnxData, err := os.ReadFile(stagedOnnx)
Expect(err).ToNot(HaveOccurred(), "primary .onnx model should be staged")
Expect(stagedOnnxData).To(Equal(modelContent))
stagedConfigData, err := os.ReadFile(stagedConfig)
Expect(err).ToNot(HaveOccurred(), "companion .onnx.json config should be staged alongside model")
Expect(stagedConfigData).To(Equal(configContent))
})
})