mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 21:53:01 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
985 lines
34 KiB
Go
985 lines
34 KiB
Go
package distributed_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/mudler/LocalAI/core/services/messaging"
|
|
"github.com/mudler/LocalAI/core/services/nodes"
|
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
|
|
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
"google.golang.org/grpc"
|
|
|
|
pgdriver "gorm.io/driver/postgres"
|
|
gormDB "gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// testLLM is a minimal AIModel implementation for testing.
|
|
// Override methods to write output to Dst so we can test the full
|
|
// FileStagingClient round-trip (upload inputs + download outputs).
|
|
type testLLM struct {
|
|
base.Base
|
|
loaded bool
|
|
lastModel string
|
|
// dstOutput is the content written to any Dst path by output-producing methods.
|
|
dstOutput []byte
|
|
// lastSrc records the last Src/input path seen (for verifying staging rewrote it).
|
|
lastSrc string
|
|
// lastAudioDst records the Dst field from AudioTranscription (it's an input, not output).
|
|
lastAudioDst string
|
|
// lastTTSModel records the Model field from TTS requests (for verifying path rewriting).
|
|
lastTTSModel string
|
|
}
|
|
|
|
func (t *testLLM) Load(opts *pb.ModelOptions) error {
|
|
t.loaded = true
|
|
t.lastModel = opts.ModelFile
|
|
return nil
|
|
}
|
|
|
|
func (t *testLLM) Predict(opts *pb.PredictOptions) (string, error) {
|
|
if !t.loaded {
|
|
return "", fmt.Errorf("model not loaded")
|
|
}
|
|
return "test response from remote node", nil
|
|
}
|
|
|
|
func (t *testLLM) GenerateImage(req *pb.GenerateImageRequest) error {
|
|
t.lastSrc = req.Src
|
|
if req.Dst != "" && len(t.dstOutput) > 0 {
|
|
return os.WriteFile(req.Dst, t.dstOutput, 0644)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *testLLM) GenerateVideo(req *pb.GenerateVideoRequest) error {
|
|
t.lastSrc = req.StartImage
|
|
if req.Dst != "" && len(t.dstOutput) > 0 {
|
|
return os.WriteFile(req.Dst, t.dstOutput, 0644)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *testLLM) TTS(req *pb.TTSRequest) error {
|
|
t.lastTTSModel = req.Model
|
|
if req.Dst != "" && len(t.dstOutput) > 0 {
|
|
return os.WriteFile(req.Dst, t.dstOutput, 0644)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
|
|
if req.Src != nil {
|
|
t.lastSrc = *req.Src
|
|
}
|
|
if req.Dst != "" && len(t.dstOutput) > 0 {
|
|
return os.WriteFile(req.Dst, t.dstOutput, 0644)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
|
t.lastAudioDst = req.Dst
|
|
return pb.TranscriptResult{Text: "transcribed text"}, nil
|
|
}
|
|
|
|
// startTestGRPCServer starts a real gRPC backend server on a free port
|
|
// and returns the address and cleanup function.
|
|
func startTestGRPCServer(llm grpcPkg.AIModel) (string, func(), error) {
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
addr := lis.Addr().String()
|
|
|
|
s := grpc.NewServer(
|
|
grpc.MaxRecvMsgSize(50*1024*1024),
|
|
grpc.MaxSendMsgSize(50*1024*1024),
|
|
)
|
|
pb.RegisterBackendServer(s, grpcPkg.NewBackendServer(llm))
|
|
|
|
go func() {
|
|
defer GinkgoRecover()
|
|
_ = s.Serve(lis)
|
|
}()
|
|
|
|
cleanup := func() {
|
|
s.GracefulStop()
|
|
}
|
|
return addr, cleanup, nil
|
|
}
|
|
|
|
// startTestHTTPFileServer starts a test HTTP file-transfer server (mirroring
// serve_backend_http.go) on a free loopback port. It returns the "host:port"
// address and a cleanup function that closes the server.
//
// Endpoints:
//   - PUT  /v1/files/<key>: store the request body as stagingDir/<base(key)>
//     and reply {"local_path": "<path>"}.
//   - GET  /v1/files/<key>: stream stagingDir/<base(key)>, falling back to
//     stagingDir/tmp/<base(key)>, where POST /v1/files/temp allocates files.
//   - POST /v1/files/temp: create an empty stagingDir/tmp/output-* file and
//     reply {"local_path": "<path>"}.
//
// Any other method is answered with 405 Method Not Allowed.
func startTestHTTPFileServer(stagingDir string) (string, func(), error) {
	if err := os.MkdirAll(stagingDir, 0750); err != nil {
		return "", nil, err
	}
	tmpDir := filepath.Join(stagingDir, "tmp")

	// replyLocalPath writes {"local_path": p}. It uses encoding/json rather
	// than fmt's %q, whose Go escape syntax (e.g. \x1f) is not always valid JSON.
	replyLocalPath := func(w http.ResponseWriter, p string) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]string{"local_path": p})
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
		key := r.URL.Path[len("/v1/files/"):]
		// filepath.Base drops directory components, keeping files inside stagingDir.
		safeName := filepath.Base(key)

		switch r.Method {
		case http.MethodPut:
			dstPath := filepath.Join(stagingDir, safeName)
			f, err := os.Create(dstPath)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			// Surface copy/close failures instead of acknowledging a
			// truncated upload as a success.
			_, err = io.Copy(f, r.Body)
			if closeErr := f.Close(); err == nil {
				err = closeErr
			}
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			replyLocalPath(w, dstPath)

		case http.MethodGet:
			srcPath := filepath.Join(stagingDir, safeName)
			if _, statErr := os.Stat(srcPath); os.IsNotExist(statErr) {
				// AllocRemoteTemp (POST /v1/files/temp) creates files under stagingDir/tmp/.
				srcPath = filepath.Join(tmpDir, safeName)
			}
			f, err := os.Open(srcPath)
			if err != nil {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			defer f.Close()
			w.Header().Set("Content-Type", "application/octet-stream")
			// Once streaming starts the status is already sent, so a
			// mid-copy error can only show up as a short body.
			_, _ = io.Copy(w, f)

		case http.MethodPost:
			if key != "temp" {
				http.Error(w, "not found", http.StatusNotFound)
				return
			}
			if err := os.MkdirAll(tmpDir, 0750); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			f, err := os.CreateTemp(tmpDir, "output-*")
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			localPath := f.Name()
			if err := f.Close(); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			replyLocalPath(w, localPath)

		default:
			w.Header().Set("Allow", "GET, PUT, POST")
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		}
	})

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", nil, err
	}
	srv := &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		// Serve returns http.ErrServerClosed once cleanup runs; nothing to report.
		_ = srv.Serve(lis)
	}()

	cleanup := func() {
		_ = srv.Close()
	}
	return lis.Addr().String(), cleanup, nil
}
|
|
|
|
var _ = Describe("Full Distributed Inference Flow", Label("Distributed"), func() {
|
|
var (
|
|
infra *TestInfra
|
|
cancel context.CancelFunc
|
|
ctx context.Context
|
|
db *gormDB.DB
|
|
registry *nodes.NodeRegistry
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
infra = SetupInfra("localai_fullflow_test")
|
|
ctx, cancel = context.WithTimeout(infra.Ctx, 2*time.Minute)
|
|
|
|
var err error
|
|
db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
registry, err = nodes.NewNodeRegistry(db)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
cancel()
|
|
})
|
|
|
|
// newTestSmartRouter creates a SmartRouter with NATS wired up and a mock
|
|
// backend.install handler that always replies success for all registered nodes.
|
|
newTestSmartRouter := func(reg *nodes.NodeRegistry, extraOpts ...nodes.SmartRouterOptions) *nodes.SmartRouter {
|
|
unloader := nodes.NewRemoteUnloaderAdapter(reg, infra.NC)
|
|
|
|
opts := nodes.SmartRouterOptions{
|
|
Unloader: unloader,
|
|
}
|
|
if len(extraOpts) > 0 {
|
|
o := extraOpts[0]
|
|
if o.FileStager != nil {
|
|
opts.FileStager = o.FileStager
|
|
}
|
|
if o.GalleriesJSON != "" {
|
|
opts.GalleriesJSON = o.GalleriesJSON
|
|
}
|
|
if o.AuthToken != "" {
|
|
opts.AuthToken = o.AuthToken
|
|
}
|
|
if o.DB != nil {
|
|
opts.DB = o.DB
|
|
}
|
|
}
|
|
|
|
router := nodes.NewSmartRouter(reg, opts)
|
|
|
|
// Subscribe a mock backend.install handler that replies success for any node.
|
|
// We use a wildcard-style approach: subscribe to all nodes' install subjects
|
|
// by registering after each node. In practice, we rely on the test registering
|
|
// nodes before calling Route, so we subscribe to a catch-all pattern.
|
|
infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
|
|
reply := messaging.BackendInstallReply{Success: true}
|
|
data, _ := json.Marshal(reply)
|
|
msg.Respond(data)
|
|
})
|
|
|
|
return router
|
|
}
|
|
// suppress unused warning in case some tests don't call it
|
|
_ = newTestSmartRouter
|
|
|
|
It("should route inference to a registered node with a real gRPC backend", func() {
|
|
// 1. Start a mock gRPC backend
|
|
llm := &testLLM{}
|
|
addr, cleanup, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanup()
|
|
|
|
// 2. Register it as a node
|
|
node := &nodes.BackendNode{
|
|
Name: "test-gpu-1",
|
|
Address: addr,
|
|
}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// 3. Create SmartRouter and route a request
|
|
router := newTestSmartRouter(registry)
|
|
|
|
// The model is not loaded yet, so Route will pick the node and call LoadModel
|
|
result, err := router.Route(ctx, "", "test-model", "llama-cpp", &pb.ModelOptions{
|
|
Model: "test-model",
|
|
}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.Name).To(Equal("test-gpu-1"))
|
|
|
|
// 4. Verify the model was loaded on the backend
|
|
Expect(llm.loaded).To(BeTrue())
|
|
|
|
// 5. Use the client to call Predict
|
|
reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
|
|
Prompt: "Hello world",
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(string(reply.Message)).To(Equal("test response from remote node"))
|
|
|
|
// 6. Release and verify in-flight decremented
|
|
result.Release()
|
|
models, err := registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(models).To(HaveLen(1))
|
|
Expect(models[0].InFlight).To(Equal(0))
|
|
|
|
// 7. Verify model recorded as "loaded" in registry
|
|
Expect(models[0].State).To(Equal("loaded"))
|
|
Expect(models[0].ModelName).To(Equal("test-model"))
|
|
})
|
|
|
|
It("should load-balance across multiple nodes with same model", func() {
|
|
// Start two mock gRPC backends
|
|
llm1 := &testLLM{}
|
|
addr1, cleanup1, err := startTestGRPCServer(llm1)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanup1()
|
|
|
|
llm2 := &testLLM{}
|
|
addr2, cleanup2, err := startTestGRPCServer(llm2)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanup2()
|
|
|
|
// Register both nodes
|
|
node1 := &nodes.BackendNode{Name: "node-heavy", Address: addr1}
|
|
node2 := &nodes.BackendNode{Name: "node-light", Address: addr2}
|
|
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
|
|
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
|
|
|
|
// Set both as having the model loaded
|
|
Expect(registry.SetNodeModel(context.Background(), node1.ID, "test-model", "loaded", "", 0)).To(Succeed())
|
|
Expect(registry.SetNodeModel(context.Background(), node2.ID, "test-model", "loaded", "", 0)).To(Succeed())
|
|
|
|
// Set node-1 with high in-flight (5), node-2 with low in-flight (1)
|
|
for range 5 {
|
|
Expect(registry.IncrementInFlight(context.Background(), node1.ID, "test-model")).To(Succeed())
|
|
}
|
|
Expect(registry.IncrementInFlight(context.Background(), node2.ID, "test-model")).To(Succeed())
|
|
|
|
// Route should pick node-2 (least loaded) thanks to ORDER BY in_flight ASC
|
|
router := newTestSmartRouter(registry)
|
|
result, err := router.Route(ctx, "", "test-model", "llama-cpp", nil, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.Name).To(Equal("node-light"))
|
|
result.Release()
|
|
})
|
|
|
|
It("should load model on empty node when no node has it", func() {
|
|
// Start a mock gRPC backend
|
|
llm := &testLLM{}
|
|
addr, cleanup, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanup()
|
|
|
|
// Register a node with NO models loaded
|
|
node := &nodes.BackendNode{Name: "empty-node", Address: addr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// Route should pick this node and call LoadModel on it
|
|
router := newTestSmartRouter(registry)
|
|
result, err := router.Route(ctx, "", "new-model", "llama-cpp", &pb.ModelOptions{
|
|
Model: "new-model",
|
|
}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result.Node.Name).To(Equal("empty-node"))
|
|
Expect(llm.loaded).To(BeTrue())
|
|
|
|
// Verify model is now recorded in registry
|
|
models, err := registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(models).To(HaveLen(1))
|
|
Expect(models[0].ModelName).To(Equal("new-model"))
|
|
Expect(models[0].State).To(Equal("loaded"))
|
|
|
|
result.Release()
|
|
})
|
|
|
|
It("should unload remote model via NATS", func() {
|
|
// Register a node with a loaded model
|
|
node := &nodes.BackendNode{Name: "gpu-unload", Address: "127.0.0.1:50099"}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
Expect(registry.SetNodeModel(context.Background(), node.ID, "old-model", "loaded", "", 0)).To(Succeed())
|
|
|
|
// Subscribe to NATS backend.stop for this node
|
|
stopSubject := messaging.SubjectNodeBackendStop(node.ID)
|
|
received := make(chan struct{}, 1)
|
|
rawConn, err := nats.Connect(infra.NatsURL)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer rawConn.Close()
|
|
|
|
_, err = rawConn.Subscribe(stopSubject, func(msg *nats.Msg) {
|
|
received <- struct{}{}
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Create RemoteUnloaderAdapter and unload model
|
|
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
|
err = unloader.UnloadRemoteModel("old-model")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Verify NATS event received
|
|
Eventually(received, 5*time.Second).Should(Receive())
|
|
|
|
// Verify model removed from registry
|
|
models, err := registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(models).To(BeEmpty())
|
|
})
|
|
|
|
It("should integrate ModelRouterAdapter with SmartRouter end-to-end", func() {
|
|
// Start a mock gRPC backend
|
|
llm := &testLLM{}
|
|
addr, cleanup, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanup()
|
|
|
|
// Register node
|
|
node := &nodes.BackendNode{Name: "adapter-node", Address: addr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// Create SmartRouter + ModelRouterAdapter
|
|
router := newTestSmartRouter(registry)
|
|
adapter := nodes.NewModelRouterAdapter(router)
|
|
|
|
// Call adapter.Route() (same signature ModelLoader uses)
|
|
m, err := adapter.Route(ctx, "llama-cpp", "test-model-id", "test-model", "",
|
|
&pb.ModelOptions{Model: "test-model"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(m).ToNot(BeNil())
|
|
|
|
// Verify returned Model has correct ID and nil process (remote)
|
|
Expect(m.ID).To(Equal("test-model-id"))
|
|
Expect(m.Process()).To(BeNil())
|
|
|
|
// Verify the model was loaded on the backend
|
|
Expect(llm.loaded).To(BeTrue())
|
|
|
|
// Use the Model's GRPC() method to get a client and verify inference works
|
|
client := m.GRPC(false, nil)
|
|
Expect(client).ToNot(BeNil())
|
|
reply, err := client.Predict(ctx, &pb.PredictOptions{Prompt: "test"})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(string(reply.Message)).To(Equal("test response from remote node"))
|
|
|
|
// Release the model via adapter
|
|
adapter.ReleaseModel("test-model-id")
|
|
})
|
|
|
|
It("should stage model files via HTTP when routing to a new node", func() {
|
|
// Create a real model file on disk
|
|
modelDir := GinkgoT().TempDir()
|
|
modelContent := []byte("fake GGUF model data — this is test content for file transfer verification")
|
|
modelPath := filepath.Join(modelDir, "model.gguf")
|
|
Expect(os.WriteFile(modelPath, modelContent, 0644)).To(Succeed())
|
|
|
|
mmprojContent := []byte("fake mmproj data for multimodal projection")
|
|
mmprojPath := filepath.Join(modelDir, "mmproj.bin")
|
|
Expect(os.WriteFile(mmprojPath, mmprojContent, 0644)).To(Succeed())
|
|
|
|
// Start a real gRPC backend server (for AI RPCs) and HTTP server (for file transfer)
|
|
llm := &testLLM{}
|
|
stagingDir := GinkgoT().TempDir()
|
|
addr, cleanupGRPC, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupGRPC()
|
|
|
|
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupHTTP()
|
|
|
|
// Register the node in PostgreSQL
|
|
node := &nodes.BackendNode{Name: "staging-node", Address: addr, HTTPAddress: httpAddr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// Create HTTPFileStager that resolves node IDs to HTTP addresses
|
|
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
|
n, err := registry.Get(context.Background(), nodeID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return n.HTTPAddress, nil
|
|
}, "")
|
|
|
|
// Create SmartRouter with the HTTPFileStager
|
|
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
|
|
|
|
// Route with ModelOptions that have file paths — SmartRouter should stage them
|
|
result, err := router.Route(ctx, "", "staged-model", "llama-cpp", &pb.ModelOptions{
|
|
Model: "staged-model",
|
|
ModelFile: modelPath,
|
|
MMProj: mmprojPath,
|
|
}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(result).ToNot(BeNil())
|
|
Expect(result.Node.Name).To(Equal("staging-node"))
|
|
|
|
// Verify the model file bytes were transferred to the backend's staging dir
|
|
stagedModelPath := filepath.Join(stagingDir, "model.gguf")
|
|
stagedModelData, err := os.ReadFile(stagedModelPath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedModelData).To(Equal(modelContent))
|
|
|
|
stagedMMProjPath := filepath.Join(stagingDir, "mmproj.bin")
|
|
stagedMMProjData, err := os.ReadFile(stagedMMProjPath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedMMProjData).To(Equal(mmprojContent))
|
|
|
|
// Verify LoadModel was called with the rewritten (remote) paths
|
|
Expect(llm.loaded).To(BeTrue())
|
|
Expect(llm.lastModel).To(Equal(stagedModelPath))
|
|
|
|
// Verify Predict still works through the FileStagingClient wrapper
|
|
reply, err := result.Client.Predict(ctx, &pb.PredictOptions{
|
|
Prompt: "test via staging client",
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(string(reply.Message)).To(Equal("test response from remote node"))
|
|
|
|
result.Release()
|
|
})
|
|
|
|
It("should stage multimodal input files via HTTP through FileStagingClient", func() {
|
|
// Create a real image file on disk
|
|
imageDir := GinkgoT().TempDir()
|
|
imageContent := []byte("fake JPEG image data for multimodal testing")
|
|
imagePath := filepath.Join(imageDir, "photo.jpg")
|
|
Expect(os.WriteFile(imagePath, imageContent, 0644)).To(Succeed())
|
|
|
|
// Start gRPC server (AI RPCs) and HTTP server (file transfer)
|
|
llm := &testLLM{}
|
|
stagingDir := GinkgoT().TempDir()
|
|
addr, cleanupGRPC, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupGRPC()
|
|
|
|
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupHTTP()
|
|
|
|
// Register node
|
|
node := &nodes.BackendNode{Name: "mm-node", Address: addr, HTTPAddress: httpAddr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// Create HTTPFileStager
|
|
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
|
n, err := registry.Get(context.Background(), nodeID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return n.HTTPAddress, nil
|
|
}, "")
|
|
|
|
// Create SmartRouter with FileStager
|
|
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
|
|
|
|
// Route with ModelOptions — triggers LoadModel on the node
|
|
modelDir := GinkgoT().TempDir()
|
|
modelPath := filepath.Join(modelDir, "vision.gguf")
|
|
Expect(os.WriteFile(modelPath, []byte("vision model data"), 0644)).To(Succeed())
|
|
|
|
result, err := router.Route(ctx, "", "vision-model", "llama-cpp", &pb.ModelOptions{
|
|
Model: "vision-model",
|
|
ModelFile: modelPath,
|
|
}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Verify LoadModel was called (model file was staged)
|
|
Expect(llm.loaded).To(BeTrue())
|
|
|
|
// Now call Predict with image file paths — FileStagingClient should stage them
|
|
_, err = result.Client.Predict(ctx, &pb.PredictOptions{
|
|
Prompt: "describe this image",
|
|
Images: []string{imagePath},
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Verify the image file was actually transferred to the backend staging dir
|
|
stagedImagePath := filepath.Join(stagingDir, "photo.jpg")
|
|
stagedImageData, err := os.ReadFile(stagedImagePath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedImageData).To(Equal(imageContent))
|
|
|
|
result.Release()
|
|
})
|
|
|
|
It("should transfer output files back via HTTP", func() {
|
|
// Start gRPC server (AI RPCs) and HTTP server (file transfer)
|
|
llm := &testLLM{}
|
|
stagingDir := GinkgoT().TempDir()
|
|
addr, cleanupGRPC, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupGRPC()
|
|
|
|
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer cleanupHTTP()
|
|
|
|
// Register node
|
|
node := &nodes.BackendNode{Name: "output-node", Address: addr, HTTPAddress: httpAddr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
// Create HTTPFileStager
|
|
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
|
n, err := registry.Get(context.Background(), nodeID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return n.HTTPAddress, nil
|
|
}, "")
|
|
|
|
// Test AllocRemoteTemp + FetchRemote directly (the output retrieval path)
|
|
remoteTmpPath, err := stager.AllocRemoteTemp(ctx, node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(remoteTmpPath).ToNot(BeEmpty())
|
|
|
|
// Simulate backend writing output to the temp path
|
|
outputContent := []byte("generated image output data from the backend")
|
|
Expect(os.WriteFile(remoteTmpPath, outputContent, 0644)).To(Succeed())
|
|
|
|
// FetchRemote pulls the file from the backend to a local path
|
|
localOutputDir := GinkgoT().TempDir()
|
|
localOutputPath := filepath.Join(localOutputDir, "output.png")
|
|
err = stager.FetchRemote(ctx, node.ID, remoteTmpPath, localOutputPath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Verify the output file was retrieved with correct content
|
|
retrievedData, err := os.ReadFile(localOutputPath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrievedData).To(Equal(outputContent))
|
|
})
|
|
|
|
// --- Full round-trip tests for every FileStagingClient src/dst path ---
|
|
|
|
// Helper: creates an HTTPFileStager + SmartRouter, registers a node,
|
|
// and routes to it. Returns the RouteResult (with FileStagingClient) and cleanup.
|
|
setupStagedRoute := func(llm *testLLM, backendType, modelName string) (
|
|
*nodes.RouteResult, string, func(),
|
|
) {
|
|
stagingDir := GinkgoT().TempDir()
|
|
addr, cleanupGRPC, err := startTestGRPCServer(llm)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
httpAddr, cleanupHTTP, err := startTestHTTPFileServer(stagingDir)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
node := &nodes.BackendNode{Name: modelName + "-node", Address: addr, HTTPAddress: httpAddr}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
|
|
stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
|
|
n, err := registry.Get(context.Background(), nodeID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return n.HTTPAddress, nil
|
|
}, "")
|
|
|
|
router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})
|
|
|
|
result, err := router.Route(ctx, "", modelName, backendType, &pb.ModelOptions{
|
|
Model: modelName,
|
|
}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
cleanup := func() {
|
|
cleanupGRPC()
|
|
cleanupHTTP()
|
|
}
|
|
return result, stagingDir, cleanup
|
|
}
|
|
|
|
It("should round-trip output via FileStagingClient.GenerateImage (Src + Dst)", func() {
|
|
outputData := []byte("PNG image generated by the backend - 1024x1024 pixels")
|
|
|
|
llm := &testLLM{dstOutput: outputData}
|
|
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "sd-model")
|
|
defer cleanup()
|
|
defer result.Release()
|
|
|
|
// Create a source image to test input staging (img2img)
|
|
srcDir := GinkgoT().TempDir()
|
|
srcContent := []byte("source image for img2img")
|
|
srcPath := filepath.Join(srcDir, "src.png")
|
|
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
|
|
|
|
localOutputDir := GinkgoT().TempDir()
|
|
frontendDst := filepath.Join(localOutputDir, "generated.png")
|
|
|
|
genResult, err := result.Client.GenerateImage(ctx, &pb.GenerateImageRequest{
|
|
PositivePrompt: "a cat",
|
|
Src: srcPath,
|
|
Dst: frontendDst,
|
|
Height: 1024,
|
|
Width: 1024,
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(genResult.Success).To(BeTrue())
|
|
|
|
// Verify input: Src was staged to backend — testLLM.lastSrc should be a staging dir path
|
|
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
|
|
|
|
// Verify the staged input file has correct content
|
|
stagedSrcData, err := os.ReadFile(llm.lastSrc)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedSrcData).To(Equal(srcContent))
|
|
|
|
// Verify output: the generated file was pulled back to the frontend
|
|
retrievedData, err := os.ReadFile(frontendDst)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrievedData).To(Equal(outputData))
|
|
})
|
|
|
|
It("should round-trip output via FileStagingClient.GenerateVideo (StartImage + Dst)", func() {
|
|
outputData := []byte("MP4 video generated by the backend")
|
|
|
|
llm := &testLLM{dstOutput: outputData}
|
|
result, stagingDir, cleanup := setupStagedRoute(llm, "diffusers", "vid-model")
|
|
defer cleanup()
|
|
defer result.Release()
|
|
|
|
// Create a start image to test input staging
|
|
imgDir := GinkgoT().TempDir()
|
|
startImageContent := []byte("start frame image data")
|
|
startImagePath := filepath.Join(imgDir, "start.png")
|
|
Expect(os.WriteFile(startImagePath, startImageContent, 0644)).To(Succeed())
|
|
|
|
localOutputDir := GinkgoT().TempDir()
|
|
frontendDst := filepath.Join(localOutputDir, "generated.mp4")
|
|
|
|
genResult, err := result.Client.GenerateVideo(ctx, &pb.GenerateVideoRequest{
|
|
Prompt: "a flying cat",
|
|
StartImage: startImagePath,
|
|
Dst: frontendDst,
|
|
NumFrames: 16,
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(genResult.Success).To(BeTrue())
|
|
|
|
// Verify input: StartImage was staged
|
|
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
|
|
stagedStartData, err := os.ReadFile(llm.lastSrc)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedStartData).To(Equal(startImageContent))
|
|
|
|
// Verify output: video was pulled back
|
|
retrievedData, err := os.ReadFile(frontendDst)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrievedData).To(Equal(outputData))
|
|
})
|
|
|
|
It("should round-trip output via FileStagingClient.TTS (Dst only)", func() {
|
|
outputData := []byte("WAV audio generated by TTS backend")
|
|
|
|
llm := &testLLM{dstOutput: outputData}
|
|
result, _, cleanup := setupStagedRoute(llm, "piper", "tts-model")
|
|
defer cleanup()
|
|
defer result.Release()
|
|
|
|
localOutputDir := GinkgoT().TempDir()
|
|
frontendDst := filepath.Join(localOutputDir, "speech.wav")
|
|
|
|
ttsResult, err := result.Client.TTS(ctx, &pb.TTSRequest{
|
|
Text: "Hello world",
|
|
Model: "tts-model",
|
|
Dst: frontendDst,
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(ttsResult.Success).To(BeTrue())
|
|
|
|
// Verify output: audio was pulled back
|
|
retrievedData, err := os.ReadFile(frontendDst)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrievedData).To(Equal(outputData))
|
|
})
|
|
|
|
It("should round-trip via FileStagingClient.SoundGeneration (Src + Dst)", func() {
|
|
outputData := []byte("generated sound effect audio data")
|
|
|
|
llm := &testLLM{dstOutput: outputData}
|
|
result, stagingDir, cleanup := setupStagedRoute(llm, "bark", "soundgen-model")
|
|
defer cleanup()
|
|
defer result.Release()
|
|
|
|
// Create input audio source
|
|
srcDir := GinkgoT().TempDir()
|
|
srcContent := []byte("input audio for sound generation")
|
|
srcPath := filepath.Join(srcDir, "input.wav")
|
|
Expect(os.WriteFile(srcPath, srcContent, 0644)).To(Succeed())
|
|
|
|
localOutputDir := GinkgoT().TempDir()
|
|
frontendDst := filepath.Join(localOutputDir, "output.wav")
|
|
|
|
sgResult, err := result.Client.SoundGeneration(ctx, &pb.SoundGenerationRequest{
|
|
Text: "explosion sound",
|
|
Src: &srcPath,
|
|
Dst: frontendDst,
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(sgResult.Success).To(BeTrue())
|
|
|
|
// Verify input: Src was staged
|
|
Expect(llm.lastSrc).To(ContainSubstring(stagingDir))
|
|
stagedSrcData, err := os.ReadFile(llm.lastSrc)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(stagedSrcData).To(Equal(srcContent))
|
|
|
|
// Verify output: audio was pulled back
|
|
retrievedData, err := os.ReadFile(frontendDst)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrievedData).To(Equal(outputData))
|
|
})
|
|
|
|
// Checks that AudioTranscription ships the local input audio to the backend's
// staging area. The request's Dst field carries the *input* path here.
It("should stage input audio via FileStagingClient.AudioTranscription (Dst is input)", func() {
	backend := &testLLM{}
	routed, stageRoot, teardown := setupStagedRoute(backend, "whisper", "whisper-model")
	defer teardown()
	defer routed.Release()

	// Local recording that the frontend wants transcribed.
	wavBytes := []byte("WAV audio data for transcription")
	wavPath := filepath.Join(GinkgoT().TempDir(), "recording.wav")
	Expect(os.WriteFile(wavPath, wavBytes, 0o644)).To(Succeed())

	// Despite its name, Dst is the input audio path for transcription requests.
	req := &pb.TranscriptRequest{Dst: wavPath}
	resp, err := routed.Client.AudioTranscription(ctx, req)
	Expect(err).NotTo(HaveOccurred())
	Expect(resp.Text).To(Equal("transcribed text"))

	// The backend must have seen a path inside the staging dir holding an exact
	// copy of the original audio.
	stagedPath := backend.lastAudioDst
	Expect(stagedPath).To(ContainSubstring(stageRoot))
	Expect(os.ReadFile(stagedPath)).To(Equal(wavBytes))
})
|
|
|
|
// Checks that the router rewrites a frontend-side TTS Model path to the path
// the model was staged to on the worker. Also checks that the companion
// .onnx.json config sits next to it and that the generated audio is pulled
// back to the frontend Dst.
It("should translate TTS Model path to remote worker path", func() {
	wantAudio := []byte("WAV audio generated by TTS backend")
	backend := &testLLM{dstOutput: wantAudio}

	// A real file transfer server is used so model staging keeps the directory
	// layout intact on the "worker" side.
	workerModels := GinkgoT().TempDir()
	workerStaging := GinkgoT().TempDir()
	workerData := GinkgoT().TempDir()

	grpcAddr, stopGRPC, err := startTestGRPCServer(backend)
	Expect(err).NotTo(HaveOccurred())
	defer stopGRPC()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	Expect(err).NotTo(HaveOccurred())
	transferAddr := listener.Addr().String()

	transferSrv, err := nodes.StartFileTransferServerWithListener(listener, workerStaging, workerModels, workerData, "", 0)
	Expect(err).NotTo(HaveOccurred())
	defer nodes.ShutdownFileTransferServer(transferSrv)

	Expect(registry.Register(context.Background(), &nodes.BackendNode{
		Name:        "tts-node",
		Address:     grpcAddr,
		HTTPAddress: transferAddr,
	}, true)).To(Succeed())

	// The stager resolves a node ID to the HTTP address of its transfer server.
	lookupHTTPAddr := func(nodeID string) (string, error) {
		n, lookupErr := registry.Get(context.Background(), nodeID)
		if lookupErr != nil {
			return "", lookupErr
		}
		return n.HTTPAddress, nil
	}
	stager := nodes.NewHTTPFileStager(lookupHTTPAddr, "")

	// Model plus companion config, as they would sit in the frontend's models dir.
	hostModels := GinkgoT().TempDir()
	onnxBytes := []byte("fake onnx model data")
	cfgBytes := []byte(`{"audio":{"sample_rate":22050}}`)
	onnxPath := filepath.Join(hostModels, "it-paola-medium.onnx")
	cfgPath := filepath.Join(hostModels, "it-paola-medium.onnx.json")
	for _, f := range []struct {
		path string
		data []byte
	}{
		{onnxPath, onnxBytes},
		{cfgPath, cfgBytes},
	} {
		Expect(os.WriteFile(f.path, f.data, 0o644)).To(Succeed())
	}

	router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

	// Pointing ModelFile at the .onnx file is what triggers model staging.
	opts := &pb.ModelOptions{
		Model:     "it-paola-medium.onnx",
		ModelFile: onnxPath,
	}
	routed, err := router.Route(ctx, "voice-it-paola-medium", "it-paola-medium.onnx", "piper", opts, false)
	Expect(err).NotTo(HaveOccurred())
	defer routed.Release()

	speechDst := filepath.Join(GinkgoT().TempDir(), "speech.wav")

	// Mirror core/backend/tts.go, which builds Model from the frontend ModelPath.
	hostModelPath := filepath.Join(hostModels, "it-paola-medium.onnx")

	ttsReq := &pb.TTSRequest{
		Text:  "Hello world",
		Model: hostModelPath, // frontend absolute path; must be rewritten to the worker path
		Dst:   speechDst,
	}
	ttsResp, err := routed.Client.TTS(ctx, ttsReq)
	Expect(err).NotTo(HaveOccurred())
	Expect(ttsResp.Success).To(BeTrue())

	// The backend must have received the worker-side path, keyed by the
	// tracking key, and never the frontend path.
	remoteModel := backend.lastTTSModel
	Expect(remoteModel).To(And(
		Not(Equal(hostModelPath)),
		ContainSubstring("voice-it-paola-medium"),
		HaveSuffix("it-paola-medium.onnx"),
	))

	// Model and companion config were staged during LoadModel, so both must be
	// readable at the translated location.
	Expect(os.ReadFile(remoteModel)).To(Equal(onnxBytes))
	Expect(os.ReadFile(remoteModel + ".json")).To(Equal(cfgBytes))

	// The generated audio was pulled back to the frontend destination.
	Expect(os.ReadFile(speechDst)).To(Equal(wantAudio))
})
|
|
|
|
// Checks that routing a piper model whose ModelFile is an .onnx file stages
// both the model and its companion .onnx.json config. Both should land in the
// worker's models dir under the tracking key passed to Route.
It("should stage companion .onnx.json files alongside .onnx model files", func() {
	llm := &testLLM{}
	modelsDir := GinkgoT().TempDir()
	stagingDir := GinkgoT().TempDir()
	dataDir := GinkgoT().TempDir()
	addr, cleanupGRPC, err := startTestGRPCServer(llm)
	Expect(err).ToNot(HaveOccurred())
	defer cleanupGRPC()

	// Use the real file transfer server (preserves directory structure)
	httpLis, err := net.Listen("tcp", "127.0.0.1:0")
	Expect(err).ToNot(HaveOccurred())
	httpAddr := httpLis.Addr().String()
	httpServer, err := nodes.StartFileTransferServerWithListener(httpLis, stagingDir, modelsDir, dataDir, "", 0)
	Expect(err).ToNot(HaveOccurred())
	defer nodes.ShutdownFileTransferServer(httpServer)

	node := &nodes.BackendNode{Name: "companion-node", Address: addr, HTTPAddress: httpAddr}
	Expect(registry.Register(context.Background(), node, true)).To(Succeed())

	// The stager resolves a node ID to the HTTP address of its transfer server.
	stager := nodes.NewHTTPFileStager(func(nodeID string) (string, error) {
		n, err := registry.Get(context.Background(), nodeID)
		if err != nil {
			return "", err
		}
		return n.HTTPAddress, nil
	}, "")

	// Create model files: .onnx and .onnx.json in a temp "models" dir
	frontendModelsDir := GinkgoT().TempDir()
	modelContent := []byte("fake onnx model")
	configContent := []byte(`{"audio":{"sample_rate":22050}}`)
	modelFile := filepath.Join(frontendModelsDir, "my-model.onnx")
	configFile := filepath.Join(frontendModelsDir, "my-model.onnx.json")
	Expect(os.WriteFile(modelFile, modelContent, 0644)).To(Succeed())
	Expect(os.WriteFile(configFile, configContent, 0644)).To(Succeed())

	router := newTestSmartRouter(registry, nodes.SmartRouterOptions{FileStager: stager})

	// Route with ModelFile pointing to the .onnx file
	result, err := router.Route(ctx, "piper-companion-test", "my-model.onnx", "piper", &pb.ModelOptions{
		Model:     "my-model.onnx",
		ModelFile: modelFile,
	}, false)
	Expect(err).ToNot(HaveOccurred())
	defer result.Release()

	// Verify: both .onnx and .onnx.json were staged to the worker's models dir
	stagedOnnx := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx")
	stagedConfig := filepath.Join(modelsDir, "piper-companion-test", "my-model.onnx.json")

	// The .onnx file is the primary model, not the companion; the failure
	// message names it accordingly.
	stagedOnnxData, err := os.ReadFile(stagedOnnx)
	Expect(err).ToNot(HaveOccurred(), "primary .onnx model should be staged")
	Expect(stagedOnnxData).To(Equal(modelContent))

	stagedConfigData, err := os.ReadFile(stagedConfig)
	Expect(err).ToNot(HaveOccurred(), "companion .onnx.json config should be staged alongside model")
	Expect(stagedConfigData).To(Equal(configContent))
})
|
|
})
|