Files
LocalAI/tests/e2e/distributed/model_routing_test.go
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

124 lines
4.3 KiB
Go

package distributed_test
import (
"context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/nodes"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pgdriver "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var _ = Describe("Model Routing", Label("Distributed"), func() {
var (
infra *TestInfra
db *gorm.DB
registry *nodes.NodeRegistry
)
BeforeEach(func() {
infra = SetupInfra("localai_routing_test")
var err error
db, err = gorm.Open(pgdriver.Open(infra.PGURL), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
Expect(err).ToNot(HaveOccurred())
registry, err = nodes.NewNodeRegistry(db)
Expect(err).ToNot(HaveOccurred())
})
Context("ModelRouterAdapter from SmartRouter", func() {
It("should create ModelRouterAdapter from SmartRouter", func() {
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{})
Expect(router).ToNot(BeNil())
adapter := nodes.NewModelRouterAdapter(router)
Expect(adapter).ToNot(BeNil())
// The adapter should provide a ModelRouter callback
routerFunc := adapter.AsModelRouter()
Expect(routerFunc).ToNot(BeNil())
})
It("should release in-flight counter on model unload", func() {
// Register a node with a loaded model
node := &nodes.BackendNode{
Name: "gpu-1", Address: "h1:50051",
}
Expect(registry.Register(context.Background(), node, true)).To(Succeed())
Expect(registry.SetNodeModel(context.Background(), node.ID, "llama3", "loaded", "", 0)).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
Expect(registry.IncrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
// Verify in-flight count
models, err := registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models[0].InFlight).To(Equal(2))
// FindAndLockNodeWithModel should return this node and atomically increment in-flight
foundNode, foundModel, err := registry.FindAndLockNodeWithModel(context.Background(), "llama3")
Expect(err).ToNot(HaveOccurred())
Expect(foundNode.ID).To(Equal(node.ID))
Expect(foundModel.ModelName).To(Equal("llama3"))
Expect(foundModel.InFlight).To(Equal(2), "InFlight returned is the pre-increment snapshot from the query")
// Verify the DB now has in_flight = 3 (2 manual + 1 from FindAndLock)
models, err = registry.GetNodeModels(context.Background(), node.ID)
Expect(err).ToNot(HaveOccurred())
Expect(models[0].InFlight).To(Equal(3))
// Simulate decrement (what Release does)
Expect(registry.DecrementInFlight(context.Background(), node.ID, "llama3")).To(Succeed())
models, _ = registry.GetNodeModels(context.Background(), node.ID)
Expect(models[0].InFlight).To(Equal(2))
// The ModelRouterAdapter.ReleaseModel calls the stored Release function
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{})
adapter := nodes.NewModelRouterAdapter(router)
// ReleaseModel on an unknown model should be a no-op (no panic)
Expect(func() { adapter.ReleaseModel("nonexistent-model") }).ToNot(Panic())
})
It("should use SmartRouter to find nodes with a model", func() {
// Register multiple nodes
node1 := &nodes.BackendNode{
Name: "node-a", Address: "h1:50051",
}
node2 := &nodes.BackendNode{
Name: "node-b", Address: "h2:50051",
}
Expect(registry.Register(context.Background(), node1, true)).To(Succeed())
Expect(registry.Register(context.Background(), node2, true)).To(Succeed())
// Load model on node1
Expect(registry.SetNodeModel(context.Background(), node1.ID, "llama3", "loaded", "", 0)).To(Succeed())
// Verify routing can find the model
nodesWithModel, err := registry.FindNodesWithModel(context.Background(), "llama3")
Expect(err).ToNot(HaveOccurred())
Expect(nodesWithModel).To(HaveLen(1))
Expect(nodesWithModel[0].ID).To(Equal(node1.ID))
})
})
Context("Without --distributed", func() {
It("should fall through to local loading without --distributed", func() {
appCfg := config.NewApplicationConfig()
Expect(appCfg.Distributed.Enabled).To(BeFalse())
// Without distributed mode, no SmartRouter is created.
// The ModelLoader uses its local process management.
// This test documents the design decision.
Expect(appCfg.Distributed.NatsURL).To(BeEmpty())
})
})
})