mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-02 06:04:09 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
124 lines
4.3 KiB
Go
124 lines
4.3 KiB
Go
package distributed_test
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/services/nodes"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
pgdriver "gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
var _ = Describe("Model Routing", Label("Distributed"), func() {
|
|
var (
|
|
infra *TestInfra
|
|
db *gorm.DB
|
|
registry *nodes.NodeRegistry
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
infra = SetupInfra("localai_routing_test")
|
|
|
|
var err error
|
|
db, err = gorm.Open(pgdriver.Open(infra.PGURL), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
registry, err = nodes.NewNodeRegistry(db)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
Context("ModelRouterAdapter from SmartRouter", func() {
|
|
It("should create ModelRouterAdapter from SmartRouter", func() {
|
|
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{})
|
|
Expect(router).ToNot(BeNil())
|
|
|
|
adapter := nodes.NewModelRouterAdapter(router)
|
|
Expect(adapter).ToNot(BeNil())
|
|
|
|
// The adapter should provide a ModelRouter callback
|
|
routerFunc := adapter.AsModelRouter()
|
|
Expect(routerFunc).ToNot(BeNil())
|
|
})
|
|
|
|
It("should release in-flight counter on model unload", func() {
|
|
// Register a node with a loaded model
|
|
node := &nodes.BackendNode{
|
|
Name: "gpu-1", Address: "h1:50051",
|
|
}
|
|
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
|
|
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama3", "loaded", "", 0)).To(Succeed())
|
|
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
|
|
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
|
|
|
|
// Verify in-flight count
|
|
models, err := registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(models[0].InFlight).To(Equal(2))
|
|
|
|
// FindAndLockNodeWithModel should return this node and atomically increment in-flight
|
|
foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(foundNode.ID).To(Equal(node.ID))
|
|
Expect(foundModel.ModelName).To(Equal("llama3"))
|
|
Expect(foundModel.InFlight).To(Equal(2), "InFlight returned is the pre-increment snapshot from the query")
|
|
|
|
// Verify the DB now has in_flight = 3 (2 manual + 1 from FindAndLock)
|
|
models, err = registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(models[0].InFlight).To(Equal(3))
|
|
|
|
// Simulate decrement (what Release does)
|
|
Expect(registry.DecrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
|
|
models, _ = registry.GetNodeModels(context.Background(), node.ID)
|
|
Expect(models[0].InFlight).To(Equal(2))
|
|
|
|
// The ModelRouterAdapter.ReleaseModel calls the stored Release function
|
|
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{})
|
|
adapter := nodes.NewModelRouterAdapter(router)
|
|
|
|
// ReleaseModel on an unknown model should be a no-op (no panic)
|
|
Expect(func() { adapter.ReleaseModel("nonexistent-model") }).ToNot(Panic())
|
|
})
|
|
|
|
It("should use SmartRouter to find nodes with a model", func() {
|
|
// Register multiple nodes
|
|
node1 := &nodes.BackendNode{
|
|
Name: "node-a", Address: "h1:50051",
|
|
}
|
|
node2 := &nodes.BackendNode{
|
|
Name: "node-b", Address: "h2:50051",
|
|
}
|
|
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
|
|
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
|
|
|
|
// Load model on node1
|
|
Expect(registry.SetNodeModel(context.Background(), node1.ID, "llama3", "loaded", "", 0)).To(Succeed())
|
|
|
|
// Verify routing can find the model
|
|
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "llama3")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(nodesWithModel).To(HaveLen(1))
|
|
Expect(nodesWithModel[0].ID).To(Equal(node1.ID))
|
|
})
|
|
})
|
|
|
|
Context("Without --distributed", func() {
|
|
It("should fall through to local loading without --distributed", func() {
|
|
appCfg := config.NewApplicationConfig()
|
|
Expect(appCfg.Distributed.Enabled).To(BeFalse())
|
|
|
|
// Without distributed mode, no SmartRouter is created.
|
|
// The ModelLoader uses its local process management.
|
|
// This test documents the design decision.
|
|
Expect(appCfg.Distributed.NatsURL).To(BeEmpty())
|
|
})
|
|
})
|
|
})
|