mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-01 04:28:59 -04:00
* feat(distributed): add configurable NATS backend install/upgrade timeouts

Adds BackendInstallTimeout and BackendUpgradeTimeout to DistributedConfig with 15m defaults, following the existing MCPToolTimeout / WorkerWaitTimeout pattern. These will replace the hardcoded literals in RemoteUnloaderAdapter so admin-driven backend installs across the cluster survive long OCI image pulls that previously timed out at 3m.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* style(distributed): gofmt alignment after timeout fields

Re-aligns the Validate() negative-duration map and the Default* const block so the new BackendInstall/UpgradeTimeout entries do not leave the surrounding columns mis-padded.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(cli): surface LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT and _UPGRADE_TIMEOUT

Parses the two new env vars on the run CLI and threads them through the existing AppOption builder so DistributedConfig picks them up. Invalid duration strings now fail loudly at startup rather than silently falling back to the default.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): inject NATS install/upgrade timeouts into RemoteUnloaderAdapter

Removes the hardcoded 3m / 15m literals from RemoteUnloaderAdapter and threads in DistributedConfig.BackendInstallTimeoutOrDefault() and BackendUpgradeTimeoutOrDefault() at construction. Install now defaults to 15m (was 3m); cold OCI image pulls on Jetson Wi-Fi routinely blew past the old ceiling. Scripted messaging client captures the timeout so tests can assert the configured value actually reaches the NATS request.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): introduce galleryop.ErrWorkerStillInstalling sentinel

When the NATS request-reply for backend.install (or .upgrade) times out the worker is almost always still pulling the OCI image. Wrap the timeout in a typed sentinel so the manager above can distinguish "worker hung" from "worker still working" and leave the pending_backend_ops row in place for the reconciler to confirm via backend.list.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): treat NATS install timeout as in-progress, not failure

When a worker times out replying to backend.install but the install is still running on the worker, enqueueAndDrainBackendOp now reports a running_on_worker status and pushes NextRetryAt out by the install timeout so the reconciler does not immediately re-fire another install while the worker is still pulling the image. The pending_backend_ops row stays in place for the next reconciler pass to confirm via backend.list.

InstallBackend wraps the result in galleryop.ErrWorkerStillInstalling so callers can branch (galleryop renders yellow in-progress instead of red error). UpgradeBackend uses the same wrap.

Adds RemoteUnloaderAdapter.InstallTimeout() so the manager can push NextRetryAt by the configured timeout without reaching into a private field, and NodeRegistry.RecordPendingBackendOpInFlight as the soft cousin of RecordPendingBackendOpFailure.

Also includes incidental gofmt-driven struct-field alignment in registry.go on lines unrelated to the change (touched files are re-formatted to canonical form per project policy).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(distributed): don't increment Attempts on in-flight install timeout

An in-flight timeout (worker still pulling the OCI image) is not a failed attempt, it's a delayed one. Incrementing Attempts let genuinely-progressing slow installs (e.g. 30 GB CUDA images on Wi-Fi) trip the reconciler's maxPendingBackendOpAttempts cap and dead-letter the queue row while the worker was still legitimately working. RecordPendingBackendOpInFlight now only updates LastError and NextRetryAt.

Also documents "running_on_worker" in the NodeOpStatus.Status enum comment so Task 6 implementers see the full surface.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(galleryop): surface ErrWorkerStillInstalling as non-error OpStatus

When the distributed backend manager returns an error that wraps ErrWorkerStillInstalling, backendHandler now completes the op with a "still installing in background" message rather than marking it as a red failure. Admin UI sees a yellow in-progress state; reconciler confirms completion on its next pass.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(distributed): end-to-end install-timeout-then-reconcile

Wires Task 1-6 end-to-end so any seam mismatch surfaces in CI rather than during a real cluster install. NATS times out, the queue row stays alive with running_on_worker status, the worker eventually reports the backend installed via backend.list, the manager surfaces it via ListBackends.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): document LOCALAI_NATS_BACKEND_INSTALL_TIMEOUT / _UPGRADE_TIMEOUT

Add the two new operator-tunable env vars to the Frontend Configuration table in the distributed-mode docs. Explains the 15m default, when to raise it (slow links pulling multi-GB OCI images), and the new "still installing in background" admin-UI state when the round-trip times out but the worker is still working.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): clear pending install rows when backend.list confirms

DistributedBackendManager.ListBackends now proactively clears pending_backend_ops install rows whose (nodeID, backend) is reported installed by backend.list. Operator UI updates immediately instead of waiting up to installTimeout (default 15m) for the next reconciler tick after NextRetryAt.

Only install rows are cleared; upgrade and delete intents are not satisfied by presence in backend.list and continue to drain through their normal reconciler paths.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(messaging): add BackendInstallProgressEvent wire type and subject

New NATS subject nodes.<nodeID>.backend.install.<opID>.progress lets the worker publish transient progress events (file, current/total bytes, percentage, phase) while a long-running install pulls its OCI image. BackendInstallRequest gains an optional OpID field so the worker knows which subject to publish on.

Transient pub/sub (not JetStream): the install reply remains ground truth for success/failure; dropped progress events are tolerable.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* style(messaging): drop em-dash from BackendInstallProgress test comment

Per project convention (no em-dashes anywhere). Comment substance is unchanged.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): worker publishes debounced install progress over NATS

When BackendInstallRequest.OpID is set, the worker's backend.install handler wires a debounced publisher (250ms window) into the gallery download callback. Each tick becomes a BackendInstallProgressEvent on nodes.<nodeID>.backend.install.<opID>.progress; the publisher always emits a final event on Flush so the UI sees the terminal percentage.

Old masters that do not set OpID continue to run silent installs: no behavior change for them.

Lock ordering: the publisher releases its mutex before calling messaging.Publish so a slow network never stalls the install loop.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): RemoteUnloaderAdapter subscribes to install progress

InstallBackend gains opID + onProgress parameters. When both are set, the adapter subscribes to nodes.<nodeID>.backend.install.<opID>.progress BEFORE publishing the install request, decodes each message into the caller's onProgress callback in a goroutine (so a slow callback never stalls the NATS reader thread), and unsubscribes after RequestJSON returns.

When onProgress is nil OR opID is empty (the reconciler retry path), subscription is skipped entirely - silent installs cost nothing extra.

Subscribe failure is logged at Warn and the install proceeds without progress streaming; the NATS round-trip still owns terminal status.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): forward backend install progress into galleryop OpStatus

DistributedBackendManager.InstallBackend now passes the gallery op ID and a progress bridge into the adapter call. Each BackendInstallProgressEvent from the worker becomes a galleryop.ProgressCallback tick - which the existing backendHandler already turns into OpStatus.UpdateStatus, so the admin UI/SSE polling sees per-byte progress for distributed installs without any UI-side change.

UpgradeBackend is intentionally left silent for now: its wire request (BackendUpgradeRequest) does not carry OpID, and rolling-update fallback is the rarer path. Will be picked up in a follow-up if the worker upgrade path also gets a progress channel.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(distributed): InstallBackend tolerates silent (pre-Phase-2) workers

A worker on pre-Phase-2 code never publishes progress events. The new master subscribes optimistically; this spec pins that a silent worker still produces a green install with no progressCb ticks. The install reply is the source of truth for terminal state; the progress stream is a best-effort UX enrichment.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): document install progress streaming

Note the new nodes.<nodeID>.backend.install.<opID>.progress subject and the silent-worker compatibility behavior so operators know to expect real-time progress and what happens on a mixed-version cluster.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): note progress-event ordering trade-off in InstallBackend

Document near the goroutine dispatch why ordering at the consumer is best-effort, why it rarely matters in practice (worker debounce >> goroutine jitter), and what a future hardening pass would look like (Seq field + stale-by-seq drop). Stops the next reader from accidentally "fixing" the goroutine pool away.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(galleryop): add NodeProgress + OpStatus.Nodes for per-node breakdown

Adds the data model the UI needs to render an expandable per-node breakdown of a fanned-out backend install. NodeProgress carries node identity (ID + name), per-node status (queued / running_on_worker / success / error / downloading), the current file + bytes + percentage from the Phase 2 progress stream, and any per-node error. OpStatus.Nodes is the slice the /api/operations handler will surface in a follow-up.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(galleryop): UpdateNodeProgress merges per-node ticks by NodeID

GalleryService.UpdateNodeProgress(opID, nodeID, np) merges a NodeProgress into OpStatus.Nodes (keyed by NodeID, no duplicates) and mirrors the latest tick into the aggregate Progress / FileName / DownloadedFileSize / TotalFileSize fields so the legacy single-bar OperationsBar view keeps working unchanged alongside the new per-node breakdown. Concurrent-safe via the existing g.Mutex.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(distributed): write per-node OpStatus entries during install fan-out

DistributedBackendManager now accepts a nodeProgressSink and feeds it two streams:

1. enqueueAndDrainBackendOp emits a per-node terminal entry on each status it appends to BackendOpResult (queued, success, error, running_on_worker). The opID is threaded through the function so the sink gets the right gallery op identity.
2. The install apply closure fans each BackendInstallProgressEvent into the sink as a downloading entry, alongside the legacy progressCb path so the aggregate single-bar view stays correct.

Production wiring passes the GalleryService (which implements UpdateNodeProgress via Task 2) as the sink. Single-node tests pass nil. DeleteBackend and UpgradeBackend pass an empty opID so the sink path no-ops for ops that aren't gallery-tracked the same way as Install.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(operations): expose per-node breakdown on /api/operations

When an operation's OpStatus has Nodes entries (populated by the Phase 4 progress sink wiring), surface them as a "nodes" array on the /api/operations response, sorted by node_name for stable rendering.

Backward compatible: legacy clients ignore the field; ops without any node entries (single-node mode, model installs) omit the array entirely thanks to the empty-slice guard.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(ui): per-node breakdown in OperationsBar

When an install op fans out to more than one worker, the operations bar now shows a "N nodes" chevron that expands into a per-node list. Each row carries the node's status (color-coded pill), the current file being downloaded, byte counts, percentage, and a thin per-node progress bar. Yellow "Worker busy" pill marks running_on_worker status with a tooltip explaining the NATS round-trip timed out but the worker is still installing in the background.

Backward compatible: ops without a nodes field (legacy or single-node mode) render as before. State for expand/collapse is local to the component, keyed by jobID/id - reload starts collapsed.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): document per-node breakdown in the operations bar

Adds a short subsection covering the expandable "N nodes" chevron in the OperationsBar admin UI, the meaning of each status pill, and how it relates to the /api/operations nodes array.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(galleryop): UpdateStatus preserves Nodes when caller sends none

Real-world bug surfaced by the Phase 4 multi-worker smoke test: the nodes[] array in /api/operations flickered between a single node at a time on a 2-worker install.

Root cause: the Phase 2 progress bridge also calls the legacy progressCb -> UpdateStatus(&OpStatus{...}) on every tick. UpdateStatus then overwrote the entire status pointer, wiping the Nodes slice that UpdateNodeProgress had just merged in.

Fix: in UpdateStatus, if the incoming op has an empty Nodes slice, carry forward the previous status's Nodes before storing. Callers that explicitly populate Nodes still win (their slice replaces the prior one, no merge across the two code paths).

Two regression specs added pinning both directions of the contract.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* docs(distributed): strip implementation details from user-facing docs

Trim the new install/upgrade timeout rows and the install-progress sections to focus on what the operator sees and tunes. Drops:

- the NATS subject names and pub/sub mechanics
- "round-trip" / reconciler / backend.list jargon
- /api/operations polling cadence
- "pre-2026-05-22" version references

Reframes the breakdown text around the admin UI (Operations Bar, chevron, status pills, "Worker busy" tooltip). Implementation context lives in the agent notes and code comments.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(config): move DistributedConfig.Validate flag names to constants

The negative-duration check map was a wall of literal kebab-case strings that had to stay in sync with the kong-derived CLI flag names manually. Move them to a Flag* const block alongside the existing Default* block so a rename of either the Go field or the CLI naming convention forces a compile error rather than silent drift.

Sole consumer today is Validate; the constants are exported so future operator-facing surfaces (e.g. error messages on other validation paths) can reference them by name instead of repeating the literals.

Tests pin both the literal values (so a future "let's just rename this" doesn't accidentally regress the CLI flag) and the negative-duration error message for the new BackendInstall / BackendUpgrade fields.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(distributed): extract NodeStatus and Phase enums to constants

Sweep for the same literal-string-as-identifier pattern called out on the Validate flag names: the per-node install status enum ("queued" | "downloading" | "running_on_worker" | "success" | "error") appeared as raw literals across managers_distributed.go (10+ sites, including 3 separate `n.Status == "running_on_worker"` checks), operation.go, and the test suite. Same shape for the Phase enum ("resolving" | "downloading" | "extracting" | "starting") in the worker-side progress publisher.

Promote both to exported const blocks:

- galleryop.NodeStatus{Queued,Downloading,RunningOnWorker,Success,Error} shared between galleryop.NodeProgress.Status (the wire field) and nodes.NodeOpStatus.Status (the in-process per-node summary)
- messaging.Phase{Resolving,Downloading,Extracting,Starting} shared between the worker publisher and any future consumer that needs to switch on phase

Tests pin both the literal values (so a future "let's just rename" doesn't silently change the JSON wire) and use the constants in setup (so the producer side stays drift-protected). Wire-format assertions on the /api/operations JSON output keep their literals deliberately, so the constant value can never silently diverge from what the UI receives.

Out of scope for this PR (separate cleanup): the finetune and quantization job-status enums have the same anti-pattern with 14+ literal sites each, but predate this PR's work.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
985 lines
34 KiB
Go
package distributed_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/services/messaging"
	"github.com/mudler/LocalAI/core/services/nodes"
	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"

	grpcPkg "github.com/mudler/LocalAI/pkg/grpc"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/nats-io/nats.go"
	"google.golang.org/grpc"

	pgdriver "gorm.io/driver/postgres"
	gormDB "gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// testLLM is a minimal AIModel implementation for testing.
// Override methods to write output to Dst so we can test the full
// FileStagingClient round-trip (upload inputs + download outputs).
type testLLM struct {
	base.Base
	loaded    bool
	lastModel string
	// dstOutput is the content written to any Dst path by output-producing methods.
	dstOutput []byte
	// lastSrc records the last Src/input path seen (for verifying staging rewrote it).
	lastSrc string
	// lastAudioDst records the Dst field from AudioTranscription (it's an input, not output).
	lastAudioDst string
	// lastTTSModel records the Model field from TTS requests (for verifying path rewriting).
	lastTTSModel string
}

func (t *testLLM) Load(opts *pb.ModelOptions) error {
	t.loaded = true
	t.lastModel = opts.ModelFile
	return nil
}

func (t *testLLM) Predict(opts *pb.PredictOptions) (string, error) {
	if !t.loaded {
		return "", fmt.Errorf("model not loaded")
	}
	return "test response from remote node", nil
}

func (t *testLLM) GenerateImage(req *pb.GenerateImageRequest) error {
	t.lastSrc = req.Src
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) GenerateVideo(req *pb.GenerateVideoRequest) error {
	t.lastSrc = req.StartImage
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) TTS(req *pb.TTSRequest) error {
	t.lastTTSModel = req.Model
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
	if req.Src != nil {
		t.lastSrc = *req.Src
	}
	if req.Dst != "" && len(t.dstOutput) > 0 {
		return os.WriteFile(req.Dst, t.dstOutput, 0644)
	}
	return nil
}

func (t *testLLM) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
	t.lastAudioDst = req.Dst
	return pb.TranscriptResult{Text: "transcribed text"}, nil
}

// startTestGRPCServer starts a real gRPC backend server on a free port
// and returns the address and cleanup function.
func startTestGRPCServer(llm grpcPkg.AIModel) (string, func(), error) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", nil, err
	}
	addr := lis.Addr().String()

	s := grpc.NewServer(
		grpc.MaxRecvMsgSize(50*1024*1024),
		grpc.MaxSendMsgSize(50*1024*1024),
	)
	pb.RegisterBackendServer(s, grpcPkg.NewBackendServer(llm))

	go func() {
		defer GinkgoRecover()
		_ = s.Serve(lis)
	}()

	cleanup := func() {
		s.GracefulStop()
	}
	return addr, cleanup, nil
}

// startTestHTTPFileServer starts a test HTTP file transfer server (mirroring serve_backend_http.go)
// on a free port and returns the address and cleanup function.
func startTestHTTPFileServer(stagingDir string) (string, func(), error) {
	if err := os.MkdirAll(stagingDir, 0750); err != nil {
		return "", nil, err
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
		key := r.URL.Path[len("/v1/files/"):]
		switch r.Method {
		case http.MethodPut:
			safeName := filepath.Base(key)
			dstPath := filepath.Join(stagingDir, safeName)
			f, err := os.Create(dstPath)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			io.Copy(f, r.Body)
			f.Close()
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprintf(w, `{"local_path":%q}`, dstPath)
		case http.MethodGet:
			safeName := filepath.Base(key)
			srcPath := filepath.Join(stagingDir, safeName)
			if _, statErr := os.Stat(srcPath); os.IsNotExist(statErr) {
				// AllocRemoteTemp creates files under stagingDir/tmp/
				srcPath = filepath.Join(stagingDir, "tmp", safeName)
			}
			f, err := os.Open(srcPath)
			if err != nil {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			defer f.Close()
			w.Header().Set("Content-Type", "application/octet-stream")
			io.Copy(w, f)
		case http.MethodPost:
			if key == "temp" {
				tmpDir := filepath.Join(stagingDir, "tmp")
				os.MkdirAll(tmpDir, 0750)
				f, err := os.CreateTemp(tmpDir, "output-*")
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				}
				localPath := f.Name()
				f.Close()
				w.Header().Set("Content-Type", "application/json")
				fmt.Fprintf(w, `{"local_path":%q}`, localPath)
			} else {
				http.Error(w, "not found", http.StatusNotFound)
			}
		}
	})

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", nil, err
	}
	httpAddr := lis.Addr().String()
	srv := &http.Server{Handler: mux}
	go srv.Serve(lis)

	cleanup := func() {
		srv.Close()
	}
	return httpAddr, cleanup, nil
}

var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func() {
	var (
		infra    *TestInfra
		cancel   context.CancelFunc
		ctx      context.Context
		db       *gormDB.DB
		registry *nodes.NodeRegistry
	)

	BeforeEach(func() {
		infra = SetupInfra("localai_fullflow_test")
		ctx, cancel = context.WithTimeout(infra.Ctx, 2*time.Minute)

		var err error
		db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
			Logger: logger.Default.LogMode(logger.Silent),
		})
		Expect(err).ToNot(HaveOccurred())

		registry, err = nodes.NewNodeRegistry(db)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		cancel()
	})

	// newTestSmartRouter creates a SmartRouter with NATS wired up and a mock
	// backend.install handler that always replies success for all registered nodes.
	newTestSmartRouter := func(reg *nodes.NodeRegistry, extraOpts ...nodes.SmartRouterOptions) *nodes.SmartRouter {
		unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC, 3*time.Minute, 15*time.Minute)

		opts := nodes.SmartRouterOptions{
			Unloader: unloader,
		}
		if len(extraOpts) > 0 {
			o := extraOpts[0]
			if o.FileStager != nil {
				opts.FileStager = o.FileStager
			}
			if o.GalleriesJSON != "" {
				opts.GalleriesJSON = o.GalleriesJSON
			}
			if o.AuthToken != "" {
				opts.AuthToken = o.AuthToken
			}
			if o.DB != nil {
				opts.DB = o.DB
			}
		}

		router := nodes.NewSmartRouter(reg, opts)

		// Subscribe a mock backend.install handler that replies success for any node.
		// The "*" wildcard matches every node ID, so the handler covers whichever
		// nodes a test registers before calling Route.
		infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
			reply := messaging.BackendInstallReply{Success: true}
			data, _ := json.Marshal(reply)
			msg.Respond(data)
		})

		return router
	}
	// suppress unused warning in case some tests don't call it
	_ = newTestSmartRouter

	It("should route inference to a registered node with a real gRPC backend", func() {
		// 1. Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// 2. Register it as a node
		node := &nodes.BackendNode{
			Name:    "test-gpu-1",
			Address: addr,
		}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// 3. Create SmartRouter and route a request
		router := newTestSmartRouter(registry)

		// The model is not loaded yet, so Route will pick the node and call LoadModel
		result, err := router.Route(ctx, "", "test-model", "llama-cpp", &pb.ModelOptions{
			Model: "test-model",
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result).ToNot(BeNil())
		Expect(result.Node.Name).To(Equal("test-gpu-1"))

		// 4. Verify the model was loaded on the backend
		Expect(llm.loaded).To(BeTrue())

		// 5. Use the client to call Predict
		reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "Hello world",
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))

		// 6. Release and verify in-flight decremented
		result.Release()
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(HaveLen(1))
		Expect(models[0].InFlight).To(Equal(0))

		// 7. Verify model recorded as "loaded" in registry
		Expect(models[0].State).To(Equal("loaded"))
		Expect(models[0].ModelName).To(Equal("test-model"))
	})

	It("should load-balance across multiple nodes with same model", func() {
		// Start two mock gRPC backends
		llm1 := &testLLM{}
		addr1, cleanup1, err := startTestGRPCServer(llm1)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup1()

		llm2 := &testLLM{}
		addr2, cleanup2, err := startTestGRPCServer(llm2)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup2()

		// Register both nodes
		node1 := &nodes.BackendNode{Name: "node-heavy", Address: addr1}
		node2 := &nodes.BackendNode{Name: "node-light", Address: addr2}
		Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
		Expect(registry.Register(context.Background(), node2, true)).To(Succeed())

		// Set both as having the model loaded
		Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", 0, "loaded", "", 0)).To(Succeed())

		// Give node1 (node-heavy) five in-flight requests and node2 (node-light) one
		for range 5 {
			Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model", 0)).To(Succeed())
		}
		Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model", 0)).To(Succeed())

		// Route should pick node2 (node-light), the least loaded, thanks to ORDER BY in_flight ASC
		router := newTestSmartRouter(registry)
		result, err := router.Route(ctx, "", "test-model", "llama-cpp", nil, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Node.Name).To(Equal("node-light"))
		result.Release()
	})

	It("should load model on empty node when no node has it", func() {
		// Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// Register a node with NO models loaded
		node := &nodes.BackendNode{Name: "empty-node", Address: addr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Route should pick this node and call LoadModel on it
		router := newTestSmartRouter(registry)
		result, err := router.Route(ctx, "", "new-model", "llama-cpp", &pb.ModelOptions{
			Model: "new-model",
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Node.Name).To(Equal("empty-node"))
		Expect(llm.loaded).To(BeTrue())

		// Verify model is now recorded in registry
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(HaveLen(1))
		Expect(models[0].ModelName).To(Equal("new-model"))
		Expect(models[0].State).To(Equal("loaded"))

		result.Release()
	})

	It("should unload remote model via NATS", func() {
		// Register a node with a loaded model
		node := &nodes.BackendNode{Name: "gpu-unload", Address: "127.0.0.1:50099"}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())
		Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", 0, "loaded", "", 0)).To(Succeed())

		// Subscribe to NATS backend.stop for this node
		stopSubject := messaging.SubjectNodeBackendStop(node.ID)
		received := make(chan struct{}, 1)
		rawConn, err := nats.Connect(infra.NatsURL)
		Expect(err).ToNot(HaveOccurred())
		defer rawConn.Close()

		_, err = rawConn.Subscribe(stopSubject, func(msg *nats.Msg) {
			received <- struct{}{}
		})
		Expect(err).ToNot(HaveOccurred())

		// Create RemoteUnloaderAdapter and unload model
		unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC, 3*time.Minute, 15*time.Minute)
		err = unloader.UnloadRemoteModel("old-model")
		Expect(err).ToNot(HaveOccurred())

		// Verify NATS event received
		Eventually(received, 5*time.Second).Should(Receive())

		// Verify model removed from registry
		models, err := registry.GetNodeModels(context.Background(), node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(models).To(BeEmpty())
	})

	It("should integrate ModelRouterAdapter with SmartRouter end-to-end", func() {
		// Start a mock gRPC backend
		llm := &testLLM{}
		addr, cleanup, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanup()

		// Register node
		node := &nodes.BackendNode{Name: "adapter-node", Address: addr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create SmartRouter + ModelRouterAdapter
		router := newTestSmartRouter(registry)
		adapter := nodes.NewModelRouterAdapter(router)

		// Call adapter.Route() (same signature ModelLoader uses)
		m, err := adapter.Route(ctx, "llama-cpp", "test-model-id", "test-model", "",
			&pb.ModelOptions{Model: "test-model"}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(m).ToNot(BeNil())

		// Verify returned Model has correct ID and nil process (remote)
		Expect(m.ID).To(Equal("test-model-id"))
		Expect(m.Process()).To(BeNil())

		// Verify the model was loaded on the backend
		Expect(llm.loaded).To(BeTrue())

		// Use the Model's GRPC() method to get a client and verify inference works
		client := m.GRPC(false, nil)
		Expect(client).ToNot(BeNil())
		reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "test"})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))

		// Release the model via adapter
		adapter.ReleaseModel("test-model-id")
	})

	It("should stage model files via HTTP when routing to a new node", func() {
		// Create a real model file on disk
		modelDir := GinkgoT().TempDir()
		modelContent := []byte("fake GGUF model data — this is test content for file transfer verification")
		modelPath := filepath.Join(modelDir, "model.gguf")
		Expect(os.WriteFile(modelPath, modelContent, 0644)).To(Succeed())

		mmprojContent := []byte("fake mmproj data for multimodal projection")
		mmprojPath := filepath.Join(modelDir, "mmproj.bin")
		Expect(os.WriteFile(mmprojPath, mmprojContent, 0644)).To(Succeed())

		// Start a real gRPC backend server (for AI RPCs) and HTTP server (for file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()

		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register the node in PostgreSQL
		node := &nodes.BackendNode{Name: "staging-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager that resolves node IDs to HTTP addresses
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create SmartRouter with the HTTPFileStager
		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelOptions that have file paths — SmartRouter should stage them
		result, err := router.Route(ctx, "", "staged-model", "llama-cpp", &pb.ModelOptions{
			Model:     "staged-model",
			ModelFile: modelPath,
			MMProj:    mmprojPath,
		}, false)
		Expect(err).ToNot(HaveOccurred())
		Expect(result).ToNot(BeNil())
		Expect(result.Node.Name).To(Equal("staging-node"))

		// Verify the model file bytes were transferred to the backend's staging dir
		stagedModelPath := filepath.Join(stagingDir, "model.gguf")
		stagedModelData, err := os.ReadFile(stagedModelPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedModelData).To(Equal(modelContent))

		stagedMMProjPath := filepath.Join(stagingDir, "mmproj.bin")
		stagedMMProjData, err := os.ReadFile(stagedMMProjPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedMMProjData).To(Equal(mmprojContent))

		// Verify LoadModel was called with the rewritten (remote) paths
		Expect(llm.loaded).To(BeTrue())
		Expect(llm.lastModel).To(Equal(stagedModelPath))

		// Verify Predict still works through the FileStagingClient wrapper
		reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "test via staging client",
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(string(reply.Message)).To(Equal("test response from remote node"))

		result.Release()
	})

	It("should stage multimodal input files via HTTP through FileStagingClient", func() {
		// Create a real image file on disk
		imageDir := GinkgoT().TempDir()
		imageContent := []byte("fake JPEG image data for multimodal testing")
		imagePath := filepath.Join(imageDir, "photo.jpg")
		Expect(os.WriteFile(imagePath, imageContent, 0644)).To(Succeed())

		// Start gRPC server (AI RPCs) and HTTP server (file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()

		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register node
		node := &nodes.BackendNode{Name: "mm-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create SmartRouter with FileStager
		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelOptions — triggers LoadModel on the node
		modelDir := GinkgoT().TempDir()
		modelPath := filepath.Join(modelDir, "vision.gguf")
		Expect(os.WriteFile(modelPath, []byte("vision model data"), 0644)).To(Succeed())

		result, err := router.Route(ctx, "", "vision-model", "llama-cpp", &pb.ModelOptions{
			Model:     "vision-model",
			ModelFile: modelPath,
		}, false)
		Expect(err).ToNot(HaveOccurred())

		// Verify LoadModel was called (model file was staged)
		Expect(llm.loaded).To(BeTrue())

		// Now call Predict with image file paths — FileStagingClient should stage them
		_, err = result.Client.Predict(ctx, &pb.PredictOptions{
			Prompt: "describe this image",
			Images: []string{imagePath},
		})
		Expect(err).ToNot(HaveOccurred())

		// Verify the image file was actually transferred to the backend staging dir
		stagedImagePath := filepath.Join(stagingDir, "photo.jpg")
		stagedImageData, err := os.ReadFile(stagedImagePath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedImageData).To(Equal(imageContent))

		result.Release()
	})

	It("should transfer output files back via HTTP", func() {
		// Start gRPC server (AI RPCs) and HTTP server (file transfer)
		llm := &testLLM{}
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()

		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupHTTP()

		// Register node
		node := &nodes.BackendNode{Name: "output-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		// Create HTTPFileStager
		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Test AllocRemoteTemp + FetchRemote directly (the output retrieval path)
		remoteTmpPath, err := stager.AllocRemoteTemp(ctx, node.ID)
		Expect(err).ToNot(HaveOccurred())
		Expect(remoteTmpPath).ToNot(BeEmpty())

		// Simulate backend writing output to the temp path
		outputContent := []byte("generated image output data from the backend")
		Expect(os.WriteFile(remoteTmpPath, outputContent, 0644)).To(Succeed())

		// FetchRemote pulls the file from the backend to a local path
		localOutputDir := GinkgoT().TempDir()
		localOutputPath := filepath.Join(localOutputDir, "output.png")
		err = stager.FetchRemote(ctx, node.ID, remoteTmpPath, localOutputPath)
		Expect(err).ToNot(HaveOccurred())

		// Verify the output file was retrieved with correct content
		retrievedData, err := os.ReadFile(localOutputPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputContent))
	})

	// --- Full round-trip tests for every FileStagingClient src/dst path ---

	// Helper: creates an HTTPFileStager + SmartRouter, registers a node,
	// and routes to it. Returns the RouteResult (with FileStagingClient) and cleanup.
	setupStagedRoute := func(llm *testLLM, backendType, modelName string) (
		*nodes.RouteResult, string, func(),
	) {
		stagingDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())

		httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
		Expect(err).ToNot(HaveOccurred())

		node := &nodes.BackendNode{Name: modelName + "-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		result, err := router.Route(ctx, "", modelName, backendType, &pb.ModelOptions{
			Model: modelName,
		}, false)
		Expect(err).ToNot(HaveOccurred())

		cleanup := func() {
			cleanupGRPC()
			cleanupHTTP()
		}
		return result, stagingDir, cleanup
	}

	It("should round-trip output via FileStagingClient.GenerateImage (Src + Dst)", func() {
		outputData := []byte("PNG image generated by the backend - 1024x1024 pixels")

		llm := &testLLM{dstOutput: outputData}
		result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "sd-model")
		defer cleanup()
		defer result.Release()

		// Create a source image to test input staging (img2img)
		srcDir := GinkgoT().TempDir()
		srcContent := []byte("source image for img2img")
		srcPath := filepath.Join(srcDir, "src.png")
		Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())

		localOutputDir := GinkgoT().TempDir()
		frontendDst := filepath.Join(localOutputDir, "generated.png")

		genResult, err := result.Client.GenerateImage(ctx, &pb.GenerateImageRequest{
			PositivePrompt: "a cat",
			Src:            srcPath,
			Dst:            frontendDst,
			Height:         1024,
			Width:          1024,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(genResult.Success).To(BeTrue())

		// Verify input: Src was staged to backend — testLLM.lastSrc should be a staging dir path
		Expect(llm.lastSrc).To(ContainSubstring(stagingDir))

		// Verify the staged input file has correct content
		stagedSrcData, err := os.ReadFile(llm.lastSrc)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedSrcData).To(Equal(srcContent))

		// Verify output: the generated file was pulled back to the frontend
		retrievedData, err := os.ReadFile(frontendDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputData))
	})

	It("should round-trip output via FileStagingClient.GenerateVideo (StartImage + Dst)", func() {
		outputData := []byte("MP4 video generated by the backend")

		llm := &testLLM{dstOutput: outputData}
		result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "vid-model")
		defer cleanup()
		defer result.Release()

		// Create a start image to test input staging
		imgDir := GinkgoT().TempDir()
		startImageContent := []byte("start frame image data")
		startImagePath := filepath.Join(imgDir, "start.png")
		Expect(os.WriteFile(startImagePath, startImageContent, 0644)).To(Succeed())

		localOutputDir := GinkgoT().TempDir()
		frontendDst := filepath.Join(localOutputDir, "generated.mp4")

		genResult, err := result.Client.GenerateVideo(ctx, &pb.GenerateVideoRequest{
			Prompt:     "a flying cat",
			StartImage: startImagePath,
			Dst:        frontendDst,
			NumFrames:  16,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(genResult.Success).To(BeTrue())

		// Verify input: StartImage was staged
		Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
		stagedStartData, err := os.ReadFile(llm.lastSrc)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedStartData).To(Equal(startImageContent))

		// Verify output: video was pulled back
		retrievedData, err := os.ReadFile(frontendDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputData))
	})

	It("should round-trip output via FileStagingClient.TTS (Dst only)", func() {
		outputData := []byte("WAV audio generated by TTS backend")

		llm := &testLLM{dstOutput: outputData}
		result, _, cleanup := setupStagedRoute(llm, "piper", "tts-model")
		defer cleanup()
		defer result.Release()

		localOutputDir := GinkgoT().TempDir()
		frontendDst := filepath.Join(localOutputDir, "speech.wav")

		ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
			Text:  "Hello world",
			Model: "tts-model",
			Dst:   frontendDst,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(ttsResult.Success).To(BeTrue())

		// Verify output: audio was pulled back
		retrievedData, err := os.ReadFile(frontendDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputData))
	})

	It("should round-trip via FileStagingClient.SoundGeneration (Src + Dst)", func() {
		outputData := []byte("generated sound effect audio data")

		llm := &testLLM{dstOutput: outputData}
		result, stagingDir, cleanup := setupStagedRoute(llm, "bark", "soundgen-model")
		defer cleanup()
		defer result.Release()

		// Create input audio source
		srcDir := GinkgoT().TempDir()
		srcContent := []byte("input audio for sound generation")
		srcPath := filepath.Join(srcDir, "input.wav")
		Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())

		localOutputDir := GinkgoT().TempDir()
		frontendDst := filepath.Join(localOutputDir, "output.wav")

		sgResult, err := result.Client.SoundGeneration(ctx, &pb.SoundGenerationRequest{
			Text: "explosion sound",
			Src:  &srcPath,
			Dst:  frontendDst,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(sgResult.Success).To(BeTrue())

		// Verify input: Src was staged
		Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
		stagedSrcData, err := os.ReadFile(llm.lastSrc)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedSrcData).To(Equal(srcContent))

		// Verify output: audio was pulled back
		retrievedData, err := os.ReadFile(frontendDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputData))
	})

	It("should stage input audio via FileStagingClient.AudioTranscription (Dst is input)", func() {
		llm := &testLLM{}
		result, stagingDir, cleanup := setupStagedRoute(llm, "whisper", "whisper-model")
		defer cleanup()
		defer result.Release()

		// Create input audio file
		audioDir := GinkgoT().TempDir()
		audioContent := []byte("WAV audio data for transcription")
		audioPath := filepath.Join(audioDir, "recording.wav")
		Expect(os.WriteFile(audioPath, audioContent, 0644)).To(Succeed())

		// AudioTranscription uses Dst as the input audio path (confusing naming)
		txResult, err := result.Client.AudioTranscription(ctx, &pb.TranscriptRequest{
			Dst: audioPath,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(txResult.Text).To(Equal("transcribed text"))

		// Verify input: audio file was staged to the backend
		Expect(llm.lastAudioDst).To(ContainSubstring(stagingDir))
		stagedAudioData, err := os.ReadFile(llm.lastAudioDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedAudioData).To(Equal(audioContent))
	})

	It("should translate TTS Model path to remote worker path", func() {
		outputData := []byte("WAV audio generated by TTS backend")
		llm := &testLLM{dstOutput: outputData}

		// Set up real file transfer server so model staging preserves directory structure
		modelsDir := GinkgoT().TempDir()
		stagingDir := GinkgoT().TempDir()
		dataDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()

		httpLis, err := net.Listen("tcp", "127.0.0.1:0")
		Expect(err).ToNot(HaveOccurred())
		httpAddr := httpLis.Addr().String()
		httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
		Expect(err).ToNot(HaveOccurred())
		defer nodes.ShutdownFileTransferServer(httpServer)

		node := &nodes.BackendNode{Name: "tts-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create model files on the "frontend"
		frontendModelsDir := GinkgoT().TempDir()
		modelContent := []byte("fake onnx model data")
		configContent := []byte(`{"audio":{"sample_rate":22050}}`)
		modelFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")
		configFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx.json")
		Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
		Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())

		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelFile pointing to the .onnx file (triggers model staging)
		result, err := router.Route(ctx, "voice-it-paola-medium", "it-paola-medium.onnx", "piper", &pb.ModelOptions{
			Model:     "it-paola-medium.onnx",
			ModelFile: modelFile,
		}, false)
		Expect(err).ToNot(HaveOccurred())
		defer result.Release()

		localOutputDir := GinkgoT().TempDir()
		frontendDst := filepath.Join(localOutputDir, "speech.wav")

		// Simulate what core/backend/tts.go does: construct Model path using frontend ModelPath
		frontendModelPath := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")

		ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
			Text:  "Hello world",
			Model: frontendModelPath, // frontend absolute path — should be translated to remote
			Dst:   frontendDst,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(ttsResult.Success).To(BeTrue())

		// Verify: the backend received the remote worker path, NOT the frontend path
		Expect(llm.lastTTSModel).ToNot(Equal(frontendModelPath))
		// The remote path should be under the worker's models dir with the tracking key
		Expect(llm.lastTTSModel).To(ContainSubstring("voice-it-paola-medium"))
		Expect(llm.lastTTSModel).To(HaveSuffix("it-paola-medium.onnx"))

		// Verify the model file exists at the translated path (already staged during LoadModel)
		stagedModelData, err := os.ReadFile(llm.lastTTSModel)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedModelData).To(Equal(modelContent))

		// Verify the companion .onnx.json is next to it (staged during LoadModel)
		companionPath := llm.lastTTSModel + ".json"
		stagedConfigData, err := os.ReadFile(companionPath)
		Expect(err).ToNot(HaveOccurred())
		Expect(stagedConfigData).To(Equal(configContent))

		// Verify output: audio was pulled back
		retrievedData, err := os.ReadFile(frontendDst)
		Expect(err).ToNot(HaveOccurred())
		Expect(retrievedData).To(Equal(outputData))
	})

	It("should stage companion .onnx.json files alongside .onnx model files", func() {
		llm := &testLLM{}
		modelsDir := GinkgoT().TempDir()
		stagingDir := GinkgoT().TempDir()
		dataDir := GinkgoT().TempDir()
		addr, cleanupGRPC, err := startTestGRPCServer(llm)
		Expect(err).ToNot(HaveOccurred())
		defer cleanupGRPC()

		// Use the real file transfer server (preserves directory structure)
		httpLis, err := net.Listen("tcp", "127.0.0.1:0")
		Expect(err).ToNot(HaveOccurred())
		httpAddr := httpLis.Addr().String()
		httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
		Expect(err).ToNot(HaveOccurred())
		defer nodes.ShutdownFileTransferServer(httpServer)

		node := &nodes.BackendNode{Name: "companion-node", Address: addr, HTTPAddress: httpAddr}
		Expect(registry.Register(context.Background(), node, true)).To(Succeed())

		stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
			n, err := registry.Get(context.Background(), nodeID)
			if err != nil {
				return "", err
			}
			return n.HTTPAddress, nil
		}, "")

		// Create model files: .onnx and .onnx.json in a temp "models" dir
		frontendModelsDir := GinkgoT().TempDir()
		modelContent := []byte("fake onnx model")
		configContent := []byte(`{"audio":{"sample_rate":22050}}`)
		modelFile := filepath.Join(frontendModelsDir, "my-model.onnx")
		configFile := filepath.Join(frontendModelsDir, "my-model.onnx.json")
		Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
		Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())

		router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

		// Route with ModelFile pointing to the .onnx file
		result, err := router.Route(ctx, "piper-companion-test", "my-model.onnx", "piper", &pb.ModelOptions{
			Model:     "my-model.onnx",
			ModelFile: modelFile,
		}, false)
		Expect(err).ToNot(HaveOccurred())
		defer result.Release()

		// Verify: both .onnx and .onnx.json were staged to the worker's models dir
		stagedOnnx := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx")
		stagedConfig := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx.json")

		stagedOnnxData, err := os.ReadFile(stagedOnnx)
		Expect(err).ToNot(HaveOccurred(), "companion .onnx model should be staged")
		Expect(stagedOnnxData).To(Equal(modelContent))

		stagedConfigData, err := os.ReadFile(stagedConfig)
		Expect(err).ToNot(HaveOccurred(), "companion .onnx.json config should be staged alongside model")
		Expect(stagedConfigData).To(Equal(configContent))
	})
})