mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-02 06:04:09 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
208 lines
6.7 KiB
Go
208 lines
6.7 KiB
Go
package distributed_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
|
|
"github.com/mudler/LocalAI/core/services/messaging"
|
|
"github.com/mudler/LocalAI/core/services/nodes"
|
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
|
|
grpcPkg "github.com/mudler/LocalAI/pkg/grpc"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
|
|
pgdriver "gorm.io/driver/postgres"
|
|
gormDB "gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// trackingTestLLM is a minimal gRPC backend for router tracking tests.
|
|
type trackingTestLLM struct {
|
|
base.Base
|
|
loaded bool
|
|
}
|
|
|
|
func (t *trackingTestLLM) Load(opts *pb.ModelOptions) error {
|
|
t.loaded = true
|
|
return nil
|
|
}
|
|
|
|
func (t *trackingTestLLM) Predict(opts *pb.PredictOptions) (string, error) {
|
|
return "ok", nil
|
|
}
|
|
|
|
var _ = Describe("SmartRouter trackingKey", Label("Distributed"), func() {
|
|
var (
|
|
infra *TestInfra
|
|
db *gormDB.DB
|
|
registry *nodes.NodeRegistry
|
|
router *nodes.SmartRouter
|
|
grpcCleanup func()
|
|
grpcAddr string
|
|
nodeID string
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
infra = SetupInfra("localai_tracking_test")
|
|
|
|
var err error
|
|
db, err = gormDB.Open(pgdriver.Open(infra.PGURL), &gormDB.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
registry, err = nodes.NewNodeRegistry(db)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Mock backend.install handler — always replies success
|
|
infra.NC.Conn().Subscribe("nodes.*.backend.install", func(msg *nats.Msg) {
|
|
reply := messaging.BackendInstallReply{Success: true}
|
|
data, _ := json.Marshal(reply)
|
|
msg.Respond(data)
|
|
})
|
|
|
|
// Start a mock gRPC backend using the same helper as full flow tests
|
|
llm := &trackingTestLLM{}
|
|
grpcAddr, grpcCleanup, err = startTestGRPCServer(grpcPkg.AIModel(llm))
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Register a node pointing to the mock backend
|
|
node := &nodes.BackendNode{
|
|
Name: "tracking-node", Address: grpcAddr,
|
|
}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
nodeID = node.ID
|
|
|
|
unloader := nodes.NewRemoteUnloaderAdapter(registry, infra.NC)
|
|
router = nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
|
|
Unloader: unloader,
|
|
})
|
|
})
|
|
|
|
AfterEach(func() {
|
|
if grpcCleanup != nil {
|
|
grpcCleanup()
|
|
}
|
|
})
|
|
|
|
It("records model under modelID when modelID is provided", func() {
|
|
result, err := router.Route(infra.Ctx, "my-model-id", "path/to/model.gguf", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "path/to/model.gguf"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer result.Release()
|
|
|
|
// The DB should have the model tracked under "my-model-id"
|
|
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "my-model-id")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(nodesWithModel).To(HaveLen(1))
|
|
Expect(nodesWithModel[0].ID).To(Equal(nodeID))
|
|
})
|
|
|
|
It("records model under modelName when modelID is empty (backward compat)", func() {
|
|
result, err := router.Route(infra.Ctx, "", "legacy/model.bin", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "legacy/model.bin"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer result.Release()
|
|
|
|
// The DB should have the model tracked under the modelName
|
|
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "legacy/model.bin")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(nodesWithModel).To(HaveLen(1))
|
|
})
|
|
|
|
It("FindNodesWithModel(modelID) finds node; FindNodesWithModel(modelName) does not", func() {
|
|
result, err := router.Route(infra.Ctx, "distinct-id", "distinct/path.gguf", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "distinct/path.gguf"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer result.Release()
|
|
|
|
// Should find by modelID
|
|
found, err := registry.FindNodesWithModel(context.Background(), "distinct-id")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(found).To(HaveLen(1))
|
|
|
|
// Should NOT find by modelName (different from modelID)
|
|
notFound, err := registry.FindNodesWithModel(context.Background(), "distinct/path.gguf")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(notFound).To(BeEmpty())
|
|
})
|
|
|
|
It("InFlight tracking increments and decrements via registry", func() {
|
|
// Route to establish model record
|
|
result, err := router.Route(infra.Ctx, "release-model", "release/path.gguf", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "release/path.gguf"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer result.Release()
|
|
|
|
// Read the baseline in-flight count (Route sets initialInFlight=1)
|
|
models, err := registry.GetNodeModels(context.Background(), nodeID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
var baseline int
|
|
for _, m := range models {
|
|
if m.ModelName == "release-model" {
|
|
baseline = m.InFlight
|
|
}
|
|
}
|
|
|
|
// Manually increment in-flight (simulates what InFlightTrackingClient.track() does during inference)
|
|
Expect(registry.IncrementInFlight(context.Background(), nodeID, "release-model")).To(Succeed())
|
|
|
|
// Check in-flight increased
|
|
models, err = registry.GetNodeModels(context.Background(), nodeID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
var inflight int
|
|
for _, m := range models {
|
|
if m.ModelName == "release-model" {
|
|
inflight = m.InFlight
|
|
}
|
|
}
|
|
Expect(inflight).To(Equal(baseline + 1))
|
|
|
|
// Decrement and check in-flight goes back to baseline
|
|
Expect(registry.DecrementInFlight(context.Background(), nodeID, "release-model")).To(Succeed())
|
|
|
|
models, err = registry.GetNodeModels(context.Background(), nodeID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
for _, m := range models {
|
|
if m.ModelName == "release-model" {
|
|
Expect(m.InFlight).To(Equal(baseline))
|
|
}
|
|
}
|
|
})
|
|
|
|
It("clears stale model record when node is unreachable", func() {
|
|
// First route to establish the model record
|
|
result, err := router.Route(infra.Ctx, "stale-check", "stale/path.gguf", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "stale/path.gguf"}, false)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
result.Release()
|
|
|
|
// Model should be in DB
|
|
found, err := registry.FindNodesWithModel(context.Background(), "stale-check")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(found).To(HaveLen(1))
|
|
|
|
// Stop the gRPC server to make the node unreachable
|
|
grpcCleanup()
|
|
grpcCleanup = nil
|
|
|
|
// Route again — should detect unreachable node and clear stale record
|
|
// (it will fall through to FindLeastLoadedNode + backend.install which succeeds,
|
|
// but the LoadModel gRPC call will fail since the server is down)
|
|
_, err = router.Route(infra.Ctx, "stale-check", "stale/path.gguf", "llama-cpp",
|
|
&pb.ModelOptions{ModelFile: "stale/path.gguf"}, false)
|
|
// Expect an error since the only node is down (LoadModel fails)
|
|
Expect(err).To(HaveOccurred())
|
|
|
|
// The stale model record should have been cleared
|
|
found, err = registry.FindNodesWithModel(context.Background(), "stale-check")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(found).To(BeEmpty())
|
|
})
|
|
})
|