Files
LocalAI/tests/e2e/distributed/node_registration_test.go
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood with health checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

237 lines
7.5 KiB
Go

package distributed_test
import (
"context"
"time"
"github.com/mudler/LocalAI/core/services/nodes"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pgdriver "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var _ = Describe("Phase 1: Node Registration", Label("Distributed"), func() {
var (
infra *TestInfra
db *gorm.DB
registry *nodes.NodeRegistry
)
// Per-spec setup: obtains test infrastructure via SetupInfra, opens a GORM
// connection to its PostgreSQL URL, and builds the NodeRegistry under test on
// top of that connection.
BeforeEach(func() {
	infra = SetupInfra("localai_nodes_test")

	var err error
	db, err = gorm.Open(pgdriver.Open(infra.PGURL), &gorm.Config{
		// Silence GORM's query logging so spec output stays readable.
		Logger: logger.Default.LogMode(logger.Silent),
	})
	Expect(err).ToNot(HaveOccurred())

	// Close the connection pool when the spec finishes; without this every
	// spec leaves its pooled connections open for the rest of the run.
	sqlDB, err := db.DB()
	Expect(err).ToNot(HaveOccurred())
	DeferCleanup(sqlDB.Close)

	registry, err = nodes.NewNodeRegistry(db)
	Expect(err).ToNot(HaveOccurred())
})
Context("Node Registration", func() {
// After Register succeeds, the node passed in must carry a non-empty ID and
// the "healthy" status.
It("should register a node", func() {
	ctx := context.Background()
	n := &nodes.BackendNode{Name: "test-node", Address: "localhost:50051"}

	Expect(registry.Register(ctx, n, true)).To(Succeed())

	Expect(n.ID).ToNot(BeEmpty())
	Expect(n.Status).To(Equal("healthy"))
})
// Registers two distinct nodes and expects List to return both.
It("should list registered nodes", func() {
	ctx := context.Background()
	seed := []*nodes.BackendNode{
		{Name: "node-1", Address: "host1:50051"},
		{Name: "node-2", Address: "host2:50051"},
	}
	for _, n := range seed {
		Expect(registry.Register(ctx, n, true)).To(Succeed())
	}

	all, err := registry.List(ctx)
	Expect(err).ToNot(HaveOccurred())
	Expect(all).To(HaveLen(2))
})
// Registers a single node, deregisters it by ID, and expects List to come
// back empty.
It("should deregister a node", func() {
	ctx := context.Background()
	n := &nodes.BackendNode{Name: "ephemeral", Address: "host3:50051"}

	Expect(registry.Register(ctx, n, true)).To(Succeed())
	Expect(registry.Deregister(ctx, n.ID)).To(Succeed())

	remaining, err := registry.List(ctx)
	Expect(err).ToNot(HaveOccurred())
	Expect(remaining).To(BeEmpty())
})
// Sends a heartbeat for a registered node and expects the stored
// LastHeartbeat to be later than the value held by the node struct that was
// passed to Register.
It("should receive heartbeats and update last_heartbeat", func() {
	ctx := context.Background()
	n := &nodes.BackendNode{Name: "heartbeat-node", Address: "host4:50051"}
	Expect(registry.Register(ctx, n, true)).To(Succeed())

	// Pause so the heartbeat timestamp is measurably later.
	time.Sleep(100 * time.Millisecond)
	Expect(registry.Heartbeat(ctx, n.ID, nil)).To(Succeed())

	refreshed, err := registry.Get(ctx, n.ID)
	Expect(err).ToNot(HaveOccurred())
	Expect(refreshed.LastHeartbeat).To(BeTemporally(">", n.LastHeartbeat))
})
// Calls MarkUnhealthy directly on a registered node (the spec does not wait
// for heartbeats to lapse) and expects the stored status to read "unhealthy".
It("should mark node unhealthy after missed heartbeats", func() {
	ctx := context.Background()
	n := &nodes.BackendNode{Name: "stale-node", Address: "host5:50051"}

	Expect(registry.Register(ctx, n, true)).To(Succeed())
	Expect(registry.MarkUnhealthy(ctx, n.ID)).To(Succeed())

	stored, err := registry.Get(ctx, n.ID)
	Expect(err).ToNot(HaveOccurred())
	Expect(stored.Status).To(Equal("unhealthy"))
})
// Backdates a registered node's last_heartbeat by five minutes directly in
// the database and expects FindStaleNodes with a one-minute threshold to
// return exactly that node.
It("should find stale nodes", func() {
	ctx := context.Background()
	node := &nodes.BackendNode{
		Name: "old-node", Address: "host6:50051",
	}
	Expect(registry.Register(ctx, node, true)).To(Succeed())

	// Set heartbeat to the past. The update's error is checked so a failed
	// write surfaces here rather than as a confusing length mismatch below.
	backdate := db.Model(&nodes.BackendNode{}).Where("id = ?", node.ID).
		Update("last_heartbeat", time.Now().Add(-5*time.Minute))
	Expect(backdate.Error).ToNot(HaveOccurred())

	stale, err := registry.FindStaleNodes(ctx, 1*time.Minute)
	Expect(err).ToNot(HaveOccurred())
	Expect(stale).To(HaveLen(1))
	Expect(stale[0].Name).To(Equal("old-node"))
})
// Registers the same node name twice with different addresses and expects a
// single listed node that keeps the first ID and carries the second address.
It("should update existing node on re-registration", func() {
	ctx := context.Background()

	original := &nodes.BackendNode{Name: "reregister-node", Address: "h1:50051"}
	Expect(registry.Register(ctx, original, true)).To(Succeed())
	originalID := original.ID

	// Same name, new address.
	replacement := &nodes.BackendNode{Name: "reregister-node", Address: "h1:50052"}
	Expect(registry.Register(ctx, replacement, true)).To(Succeed())

	// Expect an upsert keyed on the name rather than a second row.
	all, err := registry.List(ctx)
	Expect(err).ToNot(HaveOccurred())
	Expect(all).To(HaveLen(1))

	only := all[0]
	Expect(only.ID).To(Equal(originalID))
	Expect(only.Address).To(Equal("h1:50052"))
})
})
Context("Node Models", func() {
var nodeID string
// Registers the node that every "Node Models" spec attaches models to and
// records its ID in nodeID.
BeforeEach(func() {
	host := &nodes.BackendNode{Name: "model-node", Address: "mh:50051"}
	err := registry.Register(context.Background(), host, true)
	Expect(err).ToNot(HaveOccurred())
	nodeID = host.ID
})
// Attaches one model to the node and expects GetNodeModels to return exactly
// that model with the state it was stored with.
It("should track models loaded on a node", func() {
	ctx := context.Background()
	Expect(registry.SetNodeModel(ctx, nodeID, "llama3", "loaded", "", 0)).To(Succeed())

	tracked, err := registry.GetNodeModels(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	Expect(tracked).To(HaveLen(1))

	entry := tracked[0]
	Expect(entry.ModelName).To(Equal("llama3"))
	Expect(entry.State).To(Equal("loaded"))
})
// Attaches a model to the node and expects FindNodesWithModel to return that
// node, and only that node, for the model name.
It("should find nodes with a specific model", func() {
	ctx := context.Background()
	// Fail here if the fixture cannot be stored, instead of at the lookup.
	Expect(registry.SetNodeModel(ctx, nodeID, "llama3", "loaded", "", 0)).To(Succeed())

	nodesWithModel, err := registry.FindNodesWithModel(ctx, "llama3")
	Expect(err).ToNot(HaveOccurred())
	Expect(nodesWithModel).To(HaveLen(1))
	Expect(nodesWithModel[0].ID).To(Equal(nodeID))
})
// Increments the in-flight counter twice, decrements it once, and checks the
// stored counter after each phase (2, then 1).
It("should increment and decrement in-flight counters", func() {
	ctx := context.Background()
	Expect(registry.SetNodeModel(ctx, nodeID, "llama3", "loaded", "", 0)).To(Succeed())

	Expect(registry.IncrementInFlight(ctx, nodeID, "llama3")).To(Succeed())
	Expect(registry.IncrementInFlight(ctx, nodeID, "llama3")).To(Succeed())

	models, err := registry.GetNodeModels(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	// Assert the length before indexing so an empty result is reported as a
	// spec failure rather than an index-out-of-range panic.
	Expect(models).To(HaveLen(1))
	Expect(models[0].InFlight).To(Equal(2))

	// NOTE(review): DecrementInFlight's return value (if any) is not visible
	// from this file; if it returns an error like IncrementInFlight does,
	// assert on it here. The counter check below still verifies the effect.
	registry.DecrementInFlight(ctx, nodeID, "llama3")

	models, err = registry.GetNodeModels(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	Expect(models).To(HaveLen(1))
	Expect(models[0].InFlight).To(Equal(1))
})
// Attaches a model, removes it again, and expects the node to have no models
// left.
It("should remove model association from node", func() {
	ctx := context.Background()
	Expect(registry.SetNodeModel(ctx, nodeID, "llama3", "loaded", "", 0)).To(Succeed())

	Expect(registry.RemoveNodeModel(ctx, nodeID, "llama3")).To(Succeed())

	// Check the lookup error: with it discarded, a failed query that yields
	// no models would satisfy BeEmpty and the spec would pass for the wrong
	// reason.
	models, err := registry.GetNodeModels(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	Expect(models).To(BeEmpty())
})
// Attaches two models, backdates last_used on "old-model" by ten minutes
// directly in the database, and expects FindLRUModel to pick "old-model".
It("should find LRU model on a node", func() {
	ctx := context.Background()

	// Load two models
	Expect(registry.SetNodeModel(ctx, nodeID, "old-model", "loaded", "", 0)).To(Succeed())
	time.Sleep(10 * time.Millisecond)
	Expect(registry.SetNodeModel(ctx, nodeID, "new-model", "loaded", "", 0)).To(Succeed())

	// Update last_used to make old-model older. The update's error is
	// checked so a failed write is reported here, not as a wrong LRU pick.
	backdate := db.Model(&nodes.NodeModel{}).Where("node_id = ? AND model_name = ?", nodeID, "old-model").
		Update("last_used", time.Now().Add(-10*time.Minute))
	Expect(backdate.Error).ToNot(HaveOccurred())

	lru, err := registry.FindLRUModel(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	Expect(lru.ModelName).To(Equal("old-model"))
})
// Attaches two models to the node, deregisters the node, and expects no
// models to remain associated with its ID.
It("should clean up models when deregistering node", func() {
	ctx := context.Background()
	Expect(registry.SetNodeModel(ctx, nodeID, "llama3", "loaded", "", 0)).To(Succeed())
	Expect(registry.SetNodeModel(ctx, nodeID, "whisper", "loaded", "", 0)).To(Succeed())

	Expect(registry.Deregister(ctx, nodeID)).To(Succeed())

	// Models should be gone too. The lookup error is checked so that a failed
	// query cannot masquerade as an empty result.
	// NOTE(review): assumes GetNodeModels returns an empty list, not an
	// error, for a node ID that no longer exists — confirm.
	models, err := registry.GetNodeModels(ctx, nodeID)
	Expect(err).ToNot(HaveOccurred())
	Expect(models).To(BeEmpty())
})
})
})