Files
LocalAI/tests/e2e/distributed/router_tracking_test.go
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

208 lines
6.7 KiB
Go

package distributed_test
import (
"context"
"encoding/json"
"github.com/mudler/LocalAI/core/services/messaging"
"github.com/mudler/LocalAI/core/services/nodes"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/nats-io/nats.go"
pgdriver "gorm.io/driver/postgres"
gormDB "gorm.io/gorm"
"gorm.io/gorm/logger"
)
// trackingTestLLM is the stub gRPC backend served to the router tracking
// specs. It embeds base.Base and defines Load and Predict itself.
type trackingTestLLM struct {
	base.Base

	// loaded is set by Load; nothing in this file reads it back.
	loaded bool
}

// Load flags the stub as loaded and always reports success; opts is not
// inspected.
func (m *trackingTestLLM) Load(opts *pb.ModelOptions) error {
	m.loaded = true
	return nil
}

// Predict returns the fixed reply "ok" with a nil error; opts is not
// inspected.
func (m *trackingTestLLM) Predict(opts *pb.PredictOptions) (string, error) {
	return "ok", nil
}
var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() {
// Fixtures shared by the specs in this container. BeforeEach assigns all of
// them before every spec.
var (
	// infra is the environment returned by SetupInfra.
	infra *TestInfra
	// db is the GORM handle opened against infra.PGURL.
	db *gormDB.DB
	// registry is the node registry created from db.
	registry *nodes.NodeRegistry
	// router is the SmartRouter under test.
	router *nodes.SmartRouter

	// grpcAddr is the address of the mock gRPC backend; nodeID is the ID of
	// the node registered with that address.
	grpcAddr, nodeID string
	// grpcCleanup shuts the mock gRPC backend down. A spec that stops the
	// server itself sets it to nil so AfterEach skips it.
	grpcCleanup func()
)
// BeforeEach builds the fixture for every spec: the test infrastructure, a
// node registry on top of the database, a mock backend.install responder, a
// mock gRPC backend, one registered node pointing at that backend, and the
// router under test.
BeforeEach(func() {
	infra = SetupInfra("localai_tracking_test")

	var err error
	db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	Expect(err).ToNot(HaveOccurred())

	registry, err = nodes.NewNodeRegistry(db)
	Expect(err).ToNot(HaveOccurred())

	// Mock backend.install handler — always replies success. The reply never
	// changes, so it is encoded once here, where an encoding error can be
	// asserted, instead of on every incoming message.
	installOK, err := json.Marshal(messaging.BackendInstallReply{Success: true})
	Expect(err).ToNot(HaveOccurred())

	sub, err := infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
		// Best-effort reply: the error is ignored on purpose, this is only a
		// mock responder.
		_ = msg.Respond(installOK)
	})
	// Fail right away if the responder could not be installed rather than
	// letting a later Route call fail for a non-obvious reason.
	Expect(err).ToNot(HaveOccurred())
	DeferCleanup(func() {
		// Best-effort: the connection may already be closed when this runs,
		// in which case Unsubscribe reports an error that is irrelevant here.
		_ = sub.Unsubscribe()
	})

	// Start a mock gRPC backend using the same helper as full flow tests
	llm := &trackingTestLLM{}
	grpcAddr, grpcCleanup, err = startTestGRPCServer(grpcPkg.AIModel(llm))
	Expect(err).ToNot(HaveOccurred())

	// Register a node pointing to the mock backend
	node := &nodes.BackendNode{
		Name: "tracking-node", Address: grpcAddr,
	}
	Expect(registry.Register(context.Background(), node, true)).To(Succeed())
	nodeID = node.ID

	unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
	router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
		Unloader: unloader,
	})
})
// AfterEach stops the mock gRPC backend unless the spec already did so (a
// spec that stops the server itself sets grpcCleanup to nil).
AfterEach(func() {
	if grpcCleanup == nil {
		return
	}
	grpcCleanup()
	// Drop the handle once used: if the next BeforeEach fails before it
	// starts a new server, this AfterEach must not run the old cleanup again.
	grpcCleanup = nil
})
It("records model under modelID when modelID is provided", func() {
result, err := router.Route(infra.Ctx, "my-model-id", "path/to/model.gguf", "llama-cpp",
&pb.ModelOptions{ModelFile: "path/to/model.gguf"}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
// The DB should have the model tracked under "my-model-id"
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "my-model-id")
Expect(err).ToNot(HaveOccurred())
Expect(nodesWithModel).To(HaveLen(1))
Expect(nodesWithModel[0].ID).To(Equal(nodeID))
})
// With an empty modelID the router must fall back to tracking the model under
// its modelName (backward compatibility).
It("records model under modelName when modelID is empty (backward compat)", func() {
	result, err := router.Route(infra.Ctx, "", "legacy/model.bin", "llama-cpp",
		&pb.ModelOptions{ModelFile: "legacy/model.bin"}, false)
	Expect(err).ToNot(HaveOccurred())
	defer result.Release()

	// The DB should have the model tracked under the modelName
	nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "legacy/model.bin")
	Expect(err).ToNot(HaveOccurred())
	Expect(nodesWithModel).To(HaveLen(1))
	// Also pin which node holds the record, as the modelID spec does.
	Expect(nodesWithModel[0].ID).To(Equal(nodeID))
})
It("FindNodesWithModel(modelID) finds node; FindNodesWithModel(modelName) does not", func() {
result, err := router.Route(infra.Ctx, "distinct-id", "distinct/path.gguf", "llama-cpp",
&pb.ModelOptions{ModelFile: "distinct/path.gguf"}, false)
Expect(err).ToNot(HaveOccurred())
defer result.Release()
// Should find by modelID
found, err := registry.FindNodesWithModel(context.Background(), "distinct-id")
Expect(err).ToNot(HaveOccurred())
Expect(found).To(HaveLen(1))
// Should NOT find by modelName (different from modelID)
notFound, err := registry.FindNodesWithModel(context.Background(), "distinct/path.gguf")
Expect(err).ToNot(HaveOccurred())
Expect(notFound).To(BeEmpty())
})
// IncrementInFlight and DecrementInFlight on the registry must move the
// model's in-flight counter up by one and back to its starting value.
It("InFlight tracking increments and decrements via registry", func() {
	const modelID = "release-model"
	ctx := context.Background()

	// inFlight reads the current in-flight count of modelID on the test node.
	// It fails the spec when the record is missing, so no assertion below can
	// pass vacuously because the model was never found.
	inFlight := func() int {
		models, err := registry.GetNodeModels(ctx, nodeID)
		Expect(err).ToNot(HaveOccurred())
		for _, m := range models {
			if m.ModelName == modelID {
				return m.InFlight
			}
		}
		Fail("no model record for " + modelID + " on node " + nodeID)
		return 0 // not reached: Fail aborts the spec
	}

	// Route to establish model record
	result, err := router.Route(infra.Ctx, modelID, "release/path.gguf", "llama-cpp",
		&pb.ModelOptions{ModelFile: "release/path.gguf"}, false)
	Expect(err).ToNot(HaveOccurred())
	defer result.Release()

	// Read the baseline in-flight count (Route sets initialInFlight=1)
	baseline := inFlight()

	// Manually increment in-flight (simulates what InFlightTrackingClient.track() does during inference)
	Expect(registry.IncrementInFlight(ctx, nodeID, modelID)).To(Succeed())
	Expect(inFlight()).To(Equal(baseline + 1))

	// Decrement and check in-flight goes back to baseline
	Expect(registry.DecrementInFlight(ctx, nodeID, modelID)).To(Succeed())
	Expect(inFlight()).To(Equal(baseline))
})
It("clears stale model record when node is unreachable", func() {
// First route to establish the model record
result, err := router.Route(infra.Ctx, "stale-check", "stale/path.gguf", "llama-cpp",
&pb.ModelOptions{ModelFile: "stale/path.gguf"}, false)
Expect(err).ToNot(HaveOccurred())
result.Release()
// Model should be in DB
found, err := registry.FindNodesWithModel(context.Background(), "stale-check")
Expect(err).ToNot(HaveOccurred())
Expect(found).To(HaveLen(1))
// Stop the gRPC server to make the node unreachable
grpcCleanup()
grpcCleanup = nil
// Route again — should detect unreachable node and clear stale record
// (it will fall through to FindLeastLoadedNode + backend.install which succeeds,
// but the LoadModel gRPC call will fail since the server is down)
_, err = router.Route(infra.Ctx, "stale-check", "stale/path.gguf", "llama-cpp",
&pb.ModelOptions{ModelFile: "stale/path.gguf"}, false)
// Expect an error since the only node is down (LoadModel fails)
Expect(err).To(HaveOccurred())
// The stale model record should have been cleared
found, err = registry.FindNodesWithModel(context.Background(), "stale-check")
Expect(err).ToNot(HaveOccurred())
Expect(found).To(BeEmpty())
})
})