Files
LocalAI/tests/e2e/distributed/distributed_full_flow_test.go
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

985 lines
34 KiB
Go

package distributed_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"time"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/nats-io/nats.go"
"google.golang.org/grpc"
pgdriver "gorm.io/driver/postgres"
gormDB "gorm.io/gorm"
"gorm.io/gorm/logger"
)
// testLLM is a minimal AIModel implementation for testing.
// Override methods to write output to Dst so we can test the full
// FileStagingClient round-trip (upload inputs + download outputs).
type testLLM struct {
base.Base
loaded bool
lastModel string
// dstOutput is the content written to any Dst path by output-producing methods.
dstOutput []byte
// lastSrc records the last Src/input path seen (for verifying staging rewrote it).
lastSrc string
// lastAudioDst records the Dst field from AudioTranscription (it's an input, not output).
lastAudioDst string
// lastTTSModel records the Model field from TTS requests (for verifying path rewriting).
lastTTSModel string
}
func (t *testLLM) Load(opts *pb.ModelOptions) error {
t.loaded = true
t.lastModel = opts.ModelFile
return nil
}
func (t *testLLM) Predict(opts *pb.PredictOptions) (string, error) {
if !t.loaded {
return "", fmt.Errorf("model not loaded")
}
return "test response from remote node", nil
}
func (t *testLLM) GenerateImage(req *pb.GenerateImageRequest) error {
t.lastSrc = req.Src
if req.Dst != "" && len(t.dstOutput) > 0 {
return os.WriteFile(req.Dst, t.dstOutput, 0644)
}
return nil
}
func (t *testLLM) GenerateVideo(req *pb.GenerateVideoRequest) error {
t.lastSrc = req.StartImage
if req.Dst != "" && len(t.dstOutput) > 0 {
return os.WriteFile(req.Dst, t.dstOutput, 0644)
}
return nil
}
func (t *testLLM) TTS(req *pb.TTSRequest) error {
t.lastTTSModel = req.Model
if req.Dst != "" && len(t.dstOutput) > 0 {
return os.WriteFile(req.Dst, t.dstOutput, 0644)
}
return nil
}
func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
if req.Src != nil {
t.lastSrc = *req.Src
}
if req.Dst != "" && len(t.dstOutput) > 0 {
return os.WriteFile(req.Dst, t.dstOutput, 0644)
}
return nil
}
func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
t.lastAudioDst = req.Dst
return pb.TranscriptResult{Text: "transcribed text"}, nil
}
// startTestGRPCServer starts a real gRPC backend server on a free port
// and returns the address and cleanup function.
func startTestGRPCServer(llm grpcPkg.AIModel) (string, func(), error) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", nil, err
}
addr := lis.Addr().String()
s := grpc.NewServer(
grpc.MaxRecvMsgSize(50*1024*1024),
grpc.MaxSendMsgSize(50*1024*1024),
)
pb.RegisterBackendServer(s, grpcPkg.NewBackendServer(llm))
go func() {
defer GinkgoRecover()
_ = s.Serve(lis)
}()
cleanup := func() {
s.GracefulStop()
}
return addr, cleanup, nil
}
// startTestHTTPFileServer starts a test HTTP file transfer server (mirroring serve_backend_http.go)
// on a free port and returns the address and cleanup function.
func startTestHTTPFileServer(stagingDir string) (string, func(), error) {
if err := os.MkdirAll(stagingDir, 0750); err != nil {
return "", nil, err
}
mux := http.NewServeMux()
mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
key := r.URL.Path[len("/v1/files/"):]
switch r.Method {
case http.MethodPut:
safeName := filepath.Base(key)
dstPath := filepath.Join(stagingDir, safeName)
f, err := os.Create(dstPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
io.Copy(f, r.Body)
f.Close()
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"local_path":%q}`, dstPath)
case http.MethodGet:
safeName := filepath.Base(key)
srcPath := filepath.Join(stagingDir, safeName)
if _, statErr := os.Stat(srcPath); os.IsNotExist(statErr) {
// AllocRemoteTemp creates files under stagingDir/tmp/
srcPath = filepath.Join(stagingDir, "tmp", safeName)
}
f, err := os.Open(srcPath)
if err != nil {
http.Error(w, "not found", http.StatusNotFound)
return
}
defer f.Close()
w.Header().Set("Content-Type", "application/octet-stream")
io.Copy(w, f)
case http.MethodPost:
if key == "temp" {
tmpDir := filepath.Join(stagingDir, "tmp")
os.MkdirAll(tmpDir, 0750)
f, err := os.CreateTemp(tmpDir, "output-*")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
localPath := f.Name()
f.Close()
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"local_path":%q}`, localPath)
} else {
http.Error(w, "not found", http.StatusNotFound)
}
}
})
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", nil, err
}
httpAddr := lis.Addr().String()
srv := &http.Server{Handler: mux}
go srv.Serve(lis)
cleanup := func() {
srv.Close()
}
return httpAddr, cleanup, nil
}
var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func() {
var (
infra *TestInfra
cancel context.CancelFunc
ctx context.Context
db *gormDB.DB
registry *nodes.NodeRegistry
)
BeforeEach(func() {
infra = SetupInfra("localai_fullflow_test")
ctx, cancel = context.WithTimeout(infra.Ctx, 2*time.Minute)
var err error
db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
Expect(err).ToNot(HaveOccurred())
registry, err = nodes.NewNodeRegistry(db)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
cancel()
})
// newTestSmartRouter creates a SmartRouter with NATS wired up and a mock
// backend.install handler that always replies success for all registered nodes.
newTestSmartRouter := func(reg *nodes.NodeRegistry, extraOpts ...nodes.SmartRouterOptions) *nodes.SmartRouter {
unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC)
opts := nodes.SmartRouterOptions{
Unloader: unloader,
}
if len(extraOpts) > 0 {
o := extraOpts[0]
if o.FileStager != nil {
opts.FileStager = o.FileStager
}
if o.GalleriesJSON != "" {
opts.GalleriesJSON = o.GalleriesJSON
}
if o.AuthToken != "" {
opts.AuthToken = o.AuthToken
}
if o.DB != nil {
opts.DB = o.DB
}
}
router := nodes.NewSmartRouter(reg, opts)
// Subscribe a mock backend.install handler that replies success for any node.
// We use a wildcard-style approach: subscribe to all nodes' install subjects
// by registering after each node. In practice, we rely on the test registering
// nodes before calling Route, so we subscribe to a catch-all pattern.
infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
reply := messaging.BackendInstallReply{Success: true}
data, _ := json.Marshal(reply)
msg.Respond(data)
})
return router
}
// suppress unused warning in case some tests don't call it
_ = newTestSmartRouter
It("should route inference to a registered node with a real gRPC backend", func() {
// 1. Start a mock gRPC backend
llm := &testLLM{}
addr, cleanup, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanup()
// 2. Register it as a node
node := &nodes.BackendNode{
Name: "test-gpu-1",
Address: addr,
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// 3. Create SmartRouter and route a request
router := newTestSmartRouter(registry)
// The model is not loaded yet, so Route will pick the node and call LoadModel
result, err := router.Route(ctx, "", "test-model", "llama-cpp", &pb.ModelOptions{
Model: "test-model",
}, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.Name).To(Equal("test-gpu-1"))
// 4. Verify the model was loaded on the backend
Expect(llm.loaded).To(BeTrue())
// 5. Use the client to call Predict
reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
Prompt: "Hello world",
})
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.Message)).To(Equal("test response from remote node"))
// 6. Release and verify in-flight decremented
result.Release()
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(1))
Expect(models[0].InFlight).To(Equal(0))
// 7. Verify model recorded as "loaded" in registry
Expect(models[0].State).To(Equal("loaded"))
Expect(models[0].ModelName).To(Equal("test-model"))
})
It("should load-balance across multiple nodes with same model", func() {
// Start two mock gRPC backends
llm1 := &testLLM{}
addr1, cleanup1, err := startTestGRPCServer(llm1)
Expect(err).ToNot(HaveOccurred())
defer cleanup1()
llm2 := &testLLM{}
addr2, cleanup2, err := startTestGRPCServer(llm2)
Expect(err).ToNot(HaveOccurred())
defer cleanup2()
// Register both nodes
node1 := &nodes.BackendNode{Name: "node-heavy", Address: addr1}
node2 := &nodes.BackendNode{Name: "node-light", Address: addr2}
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
// Set both as having the model loaded
Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", "loaded", "", 0)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", "loaded", "", 0)).To(Succeed())
// Set node-1 with high in-flight (5), node-2 with low in-flight (1)
for range 5 {
Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model")).To(Succeed())
}
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model")).To(Succeed())
// Route should pick node-2 (least loaded) thanks to ORDER BY in_flight ASC
router := newTestSmartRouter(registry)
result, err := router.Route(ctx, "", "test-model", "llama-cpp", nil, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.Name).To(Equal("node-light"))
result.Release()
})
It("should load model on empty node when no node has it", func() {
// Start a mock gRPC backend
llm := &testLLM{}
addr, cleanup, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanup()
// Register a node with NO models loaded
node := &nodes.BackendNode{Name: "empty-node", Address: addr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Route should pick this node and call LoadModel on it
router := newTestSmartRouter(registry)
result, err := router.Route(ctx, "", "new-model", "llama-cpp", &pb.ModelOptions{
Model: "new-model",
}, false)
Expect(err).ToNot(HaveOccurred())
Expect(result.Node.Name).To(Equal("empty-node"))
Expect(llm.loaded).To(BeTrue())
// Verify model is now recorded in registry
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(1))
Expect(models[0].ModelName).To(Equal("new-model"))
Expect(models[0].State).To(Equal("loaded"))
result.Release()
})
It("should unload remote model via NATS", func() {
// Register a node with a loaded model
node := &nodes.BackendNode{Name: "gpu-unload", Address: "127.0.0.1:50099"}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", "loaded", "", 0)).To(Succeed())
// Subscribe to NATS backend.stop for this node
stopSubject := messaging.SubjectNodeBackendStop(node.ID)
received := make(chan struct{}, 1)
rawConn, err := nats.Connect(infra.NatsURL)
Expect(err).ToNot(HaveOccurred())
defer rawConn.Close()
_, err = rawConn.Subscribe(stopSubject, func(msg *nats.Msg) {
received <- struct{}{}
})
Expect(err).ToNot(HaveOccurred())
// Create RemoteUnloaderAdapter and unload model
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
err = unloader.UnloadRemoteModel("old-model")
Expect(err).ToNot(HaveOccurred())
// Verify NATS event received
Eventually(received, 5*time.Second).Should(Receive())
// Verify model removed from registry
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(BeEmpty())
})
It("should integrate ModelRouterAdapter with SmartRouter end-to-end", func() {
// Start a mock gRPC backend
llm := &testLLM{}
addr, cleanup, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanup()
// Register node
node := &nodes.BackendNode{Name: "adapter-node", Address: addr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Create SmartRouter + ModelRouterAdapter
router := newTestSmartRouter(registry)
adapter := nodes.NewModelRouterAdapter(router)
// Call adapter.Route() (same signature ModelLoader uses)
m, err := adapter.Route(ctx, "llama-cpp", "test-model-id", "test-model", "",
&pb.ModelOptions{Model: "test-model"}, false)
Expect(err).ToNot(HaveOccurred())
Expect(m).ToNot(BeNil())
// Verify returned Model has correct ID and nil process (remote)
Expect(m.ID).To(Equal("test-model-id"))
Expect(m.Process()).To(BeNil())
// Verify the model was loaded on the backend
Expect(llm.loaded).To(BeTrue())
// Use the Model's GRPC() method to get a client and verify inference works
client := m.GRPC(false, nil)
Expect(client).ToNot(BeNil())
reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "test"})
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.Message)).To(Equal("test response from remote node"))
// Release the model via adapter
adapter.ReleaseModel("test-model-id")
})
It("should stage model files via HTTP when routing to a new node", func() {
// Create a real model file on disk
modelDir := GinkgoT().TempDir()
modelContent := []byte("fake GGUF model data — this is test content for file transfer verification")
modelPath := filepath.Join(modelDir, "model.gguf")
Expect(os.WriteFile(modelPath, modelContent, 0644)).To(Succeed())
mmprojContent := []byte("fake mmproj data for multimodal projection")
mmprojPath := filepath.Join(modelDir, "mmproj.bin")
Expect(os.WriteFile(mmprojPath, mmprojContent, 0644)).To(Succeed())
// Start a real gRPC backend server (for AI RPCs) and HTTP server (for file transfer)
llm := &testLLM{}
stagingDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
Expect(err).ToNot(HaveOccurred())
defer cleanupHTTP()
// Register the node in PostgreSQL
node := &nodes.BackendNode{Name: "staging-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Create HTTPFileStager that resolves node IDs to HTTP addresses
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create SmartRouter with the HTTPFileStager
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelOptions that have file paths — SmartRouter should stage them
result, err := router.Route(ctx, "", "staged-model", "llama-cpp", &pb.ModelOptions{
Model: "staged-model",
ModelFile: modelPath,
MMProj: mmprojPath,
}, false)
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
Expect(result.Node.Name).To(Equal("staging-node"))
// Verify the model file bytes were transferred to the backend's staging dir
stagedModelPath := filepath.Join(stagingDir, "model.gguf")
stagedModelData, err := os.ReadFile(stagedModelPath)
Expect(err).ToNot(HaveOccurred())
Expect(stagedModelData).To(Equal(modelContent))
stagedMMProjPath := filepath.Join(stagingDir, "mmproj.bin")
stagedMMProjData, err := os.ReadFile(stagedMMProjPath)
Expect(err).ToNot(HaveOccurred())
Expect(stagedMMProjData).To(Equal(mmprojContent))
// Verify LoadModel was called with the rewritten (remote) paths
Expect(llm.loaded).To(BeTrue())
Expect(llm.lastModel).To(Equal(stagedModelPath))
// Verify Predict still works through the FileStagingClient wrapper
reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
Prompt: "test via staging client",
})
Expect(err).ToNot(HaveOccurred())
Expect(string(reply.Message)).To(Equal("test response from remote node"))
result.Release()
})
It("should stage multimodal input files via HTTP through FileStagingClient", func() {
// Create a real image file on disk
imageDir := GinkgoT().TempDir()
imageContent := []byte("fake JPEG image data for multimodal testing")
imagePath := filepath.Join(imageDir, "photo.jpg")
Expect(os.WriteFile(imagePath, imageContent, 0644)).To(Succeed())
// Start gRPC server (AI RPCs) and HTTP server (file transfer)
llm := &testLLM{}
stagingDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
Expect(err).ToNot(HaveOccurred())
defer cleanupHTTP()
// Register node
node := &nodes.BackendNode{Name: "mm-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Create HTTPFileStager
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create SmartRouter with FileStager
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelOptions — triggers LoadModel on the node
modelDir := GinkgoT().TempDir()
modelPath := filepath.Join(modelDir, "vision.gguf")
Expect(os.WriteFile(modelPath, []byte("vision model data"), 0644)).To(Succeed())
result, err := router.Route(ctx, "", "vision-model", "llama-cpp", &pb.ModelOptions{
Model: "vision-model",
ModelFile: modelPath,
}, false)
Expect(err).ToNot(HaveOccurred())
// Verify LoadModel was called (model file was staged)
Expect(llm.loaded).To(BeTrue())
// Now call Predict with image file paths — FileStagingClient should stage them
_, err = result.Client.Predict(ctx, &pb.PredictOptions{
Prompt: "describe this image",
Images: []string{imagePath},
})
Expect(err).ToNot(HaveOccurred())
// Verify the image file was actually transferred to the backend staging dir
stagedImagePath := filepath.Join(stagingDir, "photo.jpg")
stagedImageData, err := os.ReadFile(stagedImagePath)
Expect(err).ToNot(HaveOccurred())
Expect(stagedImageData).To(Equal(imageContent))
result.Release()
})
It("should transfer output files back via HTTP", func() {
// Start gRPC server (AI RPCs) and HTTP server (file transfer)
llm := &testLLM{}
stagingDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
Expect(err).ToNot(HaveOccurred())
defer cleanupHTTP()
// Register node
node := &nodes.BackendNode{Name: "output-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
// Create HTTPFileStager
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Test AllocRemoteTemp + FetchRemote directly (the output retrieval path)
remoteTmpPath, err := stager.AllocRemoteTemp(ctx, node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(remoteTmpPath).ToNot(BeEmpty())
// Simulate backend writing output to the temp path
outputContent := []byte("generated image output data from the backend")
Expect(os.WriteFile(remoteTmpPath, outputContent, 0644)).To(Succeed())
// FetchRemote pulls the file from the backend to a local path
localOutputDir := GinkgoT().TempDir()
localOutputPath := filepath.Join(localOutputDir, "output.png")
err = stager.FetchRemote(ctx, node.ID, remoteTmpPath, localOutputPath)
Expect(err).ToNot(HaveOccurred())
// Verify the output file was retrieved with correct content
retrievedData, err := os.ReadFile(localOutputPath)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputContent))
})
// --- Full round-trip tests for every FileStagingClient src/dst path ---
// Helper: creates an HTTPFileStager + SmartRouter, registers a node,
// and routes to it. Returns the RouteResult (with FileStagingClient) and cleanup.
setupStagedRoute := func(llm *testLLM, backendType, modelName string) (
*nodes.RouteResult, string, func(),
) {
stagingDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
Expect(err).ToNot(HaveOccurred())
node := &nodes.BackendNode{Name: modelName + "-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
result, err := router.Route(ctx, "", modelName, backendType, &pb.ModelOptions{
Model: modelName,
}, false)
Expect(err).ToNot(HaveOccurred())
cleanup := func() {
cleanupGRPC()
cleanupHTTP()
}
return result, stagingDir, cleanup
}
It("should round-trip output via FileStagingClient.GenerateImage (Src + Dst)", func() {
outputData := []byte("PNG image generated by the backend - 1024x1024 pixels")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "sd-model")
defer cleanup()
defer result.Release()
// Create a source image to test input staging (img2img)
srcDir := GinkgoT().TempDir()
srcContent := []byte("source image for img2img")
srcPath := filepath.Join(srcDir, "src.png")
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "generated.png")
genResult, err := result.Client.GenerateImage(ctx, &pb.GenerateImageRequest{
PositivePrompt: "a cat",
Src: srcPath,
Dst: frontendDst,
Height: 1024,
Width: 1024,
})
Expect(err).ToNot(HaveOccurred())
Expect(genResult.Success).To(BeTrue())
// Verify input: Src was staged to backend — testLLM.lastSrc should be a staging dir path
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
// Verify the staged input file has correct content
stagedSrcData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedSrcData).To(Equal(srcContent))
// Verify output: the generated file was pulled back to the frontend
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip output via FileStagingClient.GenerateVideo (StartImage + Dst)", func() {
outputData := []byte("MP4 video generated by the backend")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "vid-model")
defer cleanup()
defer result.Release()
// Create a start image to test input staging
imgDir := GinkgoT().TempDir()
startImageContent := []byte("start frame image data")
startImagePath := filepath.Join(imgDir, "start.png")
Expect(os.WriteFile(startImagePath, startImageContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "generated.mp4")
genResult, err := result.Client.GenerateVideo(ctx, &pb.GenerateVideoRequest{
Prompt: "a flying cat",
StartImage: startImagePath,
Dst: frontendDst,
NumFrames: 16,
})
Expect(err).ToNot(HaveOccurred())
Expect(genResult.Success).To(BeTrue())
// Verify input: StartImage was staged
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
stagedStartData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedStartData).To(Equal(startImageContent))
// Verify output: video was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip output via FileStagingClient.TTS (Dst only)", func() {
outputData := []byte("WAV audio generated by TTS backend")
llm := &testLLM{dstOutput: outputData}
result, _, cleanup := setupStagedRoute(llm, "piper", "tts-model")
defer cleanup()
defer result.Release()
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "speech.wav")
ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
Text: "Hello world",
Model: "tts-model",
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(ttsResult.Success).To(BeTrue())
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should round-trip via FileStagingClient.SoundGeneration (Src + Dst)", func() {
outputData := []byte("generated sound effect audio data")
llm := &testLLM{dstOutput: outputData}
result, stagingDir, cleanup := setupStagedRoute(llm, "bark", "soundgen-model")
defer cleanup()
defer result.Release()
// Create input audio source
srcDir := GinkgoT().TempDir()
srcContent := []byte("input audio for sound generation")
srcPath := filepath.Join(srcDir, "input.wav")
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "output.wav")
sgResult, err := result.Client.SoundGeneration(ctx, &pb.SoundGenerationRequest{
Text: "explosion sound",
Src: &srcPath,
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(sgResult.Success).To(BeTrue())
// Verify input: Src was staged
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
stagedSrcData, err := os.ReadFile(llm.lastSrc)
Expect(err).ToNot(HaveOccurred())
Expect(stagedSrcData).To(Equal(srcContent))
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should stage input audio via FileStagingClient.AudioTranscription (Dst is input)", func() {
llm := &testLLM{}
result, stagingDir, cleanup := setupStagedRoute(llm, "whisper", "whisper-model")
defer cleanup()
defer result.Release()
// Create input audio file
audioDir := GinkgoT().TempDir()
audioContent := []byte("WAV audio data for transcription")
audioPath := filepath.Join(audioDir, "recording.wav")
Expect(os.WriteFile(audioPath, audioContent, 0644)).To(Succeed())
// AudioTranscription uses Dst as the input audio path (confusing naming)
txResult, err := result.Client.AudioTranscription(ctx, &pb.TranscriptRequest{
Dst: audioPath,
})
Expect(err).ToNot(HaveOccurred())
Expect(txResult.Text).To(Equal("transcribed text"))
// Verify input: audio file was staged to the backend
Expect(llm.lastAudioDst).To(ContainSubstring(stagingDir))
stagedAudioData, err := os.ReadFile(llm.lastAudioDst)
Expect(err).ToNot(HaveOccurred())
Expect(stagedAudioData).To(Equal(audioContent))
})
It("should translate TTS Model path to remote worker path", func() {
outputData := []byte("WAV audio generated by TTS backend")
llm := &testLLM{dstOutput: outputData}
// Set up real file transfer server so model staging preserves directory structure
modelsDir := GinkgoT().TempDir()
stagingDir := GinkgoT().TempDir()
dataDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
httpLis, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
httpAddr := httpLis.Addr().String()
httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
Expect(err).ToNot(HaveOccurred())
defer nodes.ShutdownFileTransferServer(httpServer)
node := &nodes.BackendNode{Name: "tts-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create model files on the "frontend"
frontendModelsDir := GinkgoT().TempDir()
modelContent := []byte("fake onnx model data")
configContent := []byte(`{"audio":{"sample_rate":22050}}`)
modelFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")
configFile := filepath.Join(frontendModelsDir, "it-paola-medium.onnx.json")
Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelFile pointing to the .onnx file (triggers model staging)
result, err := router.Route(ctx, "voice-it-paola-medium", "it-paola-medium.onnx", "piper", &pb.ModelOptions{
Model: "it-paola-medium.onnx",
ModelFile: modelFile,
}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
localOutputDir := GinkgoT().TempDir()
frontendDst := filepath.Join(localOutputDir, "speech.wav")
// Simulate what core/backend/tts.go does: construct Model path using frontend ModelPath
frontendModelPath := filepath.Join(frontendModelsDir, "it-paola-medium.onnx")
ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
Text: "Hello world",
Model: frontendModelPath, // frontend absolute path — should be translated to remote
Dst: frontendDst,
})
Expect(err).ToNot(HaveOccurred())
Expect(ttsResult.Success).To(BeTrue())
// Verify: the backend received the remote worker path, NOT the frontend path
Expect(llm.lastTTSModel).ToNot(Equal(frontendModelPath))
// The remote path should be under the worker's models dir with the tracking key
Expect(llm.lastTTSModel).To(ContainSubstring("voice-it-paola-medium"))
Expect(llm.lastTTSModel).To(HaveSuffix("it-paola-medium.onnx"))
// Verify the model file exists at the translated path (already staged during LoadModel)
stagedModelData, err := os.ReadFile(llm.lastTTSModel)
Expect(err).ToNot(HaveOccurred())
Expect(stagedModelData).To(Equal(modelContent))
// Verify the companion .onnx.json is next to it (staged during LoadModel)
companionPath := llm.lastTTSModel + ".json"
stagedConfigData, err := os.ReadFile(companionPath)
Expect(err).ToNot(HaveOccurred())
Expect(stagedConfigData).To(Equal(configContent))
// Verify output: audio was pulled back
retrievedData, err := os.ReadFile(frontendDst)
Expect(err).ToNot(HaveOccurred())
Expect(retrievedData).To(Equal(outputData))
})
It("should stage companion .onnx.json files alongside .onnx model files", func() {
llm := &testLLM{}
modelsDir := GinkgoT().TempDir()
stagingDir := GinkgoT().TempDir()
dataDir := GinkgoT().TempDir()
addr, cleanupGRPC, err := startTestGRPCServer(llm)
Expect(err).ToNot(HaveOccurred())
defer cleanupGRPC()
// Use the real file transfer server (preserves directory structure)
httpLis, err := net.Listen("tcp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
httpAddr := httpLis.Addr().String()
httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
Expect(err).ToNot(HaveOccurred())
defer nodes.ShutdownFileTransferServer(httpServer)
node := &nodes.BackendNode{Name: "companion-node", Address: addr, HTTPAddress: httpAddr}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
n, err := registry.Get(context.Background(), nodeID)
if err != nil {
return "", err
}
return n.HTTPAddress, nil
}, "")
// Create model files: .onnx and .onnx.json in a temp "models" dir
frontendModelsDir := GinkgoT().TempDir()
modelContent := []byte("fake onnx model")
configContent := []byte(`{"audio":{"sample_rate":22050}}`)
modelFile := filepath.Join(frontendModelsDir, "my-model.onnx")
configFile := filepath.Join(frontendModelsDir, "my-model.onnx.json")
Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
// Route with ModelFile pointing to the .onnx file
result, err := router.Route(ctx, "piper-companion-test", "my-model.onnx", "piper", &pb.ModelOptions{
Model: "my-model.onnx",
ModelFile: modelFile,
}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
// Verify: both .onnx and .onnx.json were staged to the worker's models dir
stagedOnnx := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx")
stagedConfig := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx.json")
stagedOnnxData, err := os.ReadFile(stagedOnnx)
Expect(err).ToNot(HaveOccurred(), "companion .onnx model should be staged")
Expect(stagedOnnxData).To(Equal(modelContent))
stagedConfigData, err := os.ReadFile(stagedConfig)
Expect(err).ToNot(HaveOccurred(), "companion .onnx.json config should be staged alongside model")
Expect(stagedConfigData).To(Equal(configContent))
})
})