mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-02 06:04:09 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
142 lines
4.4 KiB
Go
142 lines
4.4 KiB
Go
package distributed_test
|
|
|
|
import (
|
|
"sync/atomic"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/services/distributed"
|
|
"github.com/mudler/LocalAI/core/services/messaging"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
pgdriver "gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
var _ = Describe("Fine-Tune Distributed", Label("Distributed"), func() {
|
|
var (
|
|
infra *TestInfra
|
|
db *gorm.DB
|
|
ftStore *distributed.FineTuneStore
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
infra = SetupInfra("localai_finetune_dist_test")
|
|
|
|
var err error
|
|
db, err = gorm.Open(pgdriver.Open(infra.PGURL), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
ftStore, err = distributed.NewFineTuneStore(db)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
Context("PostgreSQL persistence", func() {
|
|
It("should persist fine-tune jobs in PostgreSQL when store is set", func() {
|
|
job := &distributed.FineTuneJobRecord{
|
|
UserID: "u1",
|
|
Model: "llama3-8b",
|
|
Backend: "transformers",
|
|
TrainingType: "lora",
|
|
TrainingMethod: "sft",
|
|
Status: "queued",
|
|
}
|
|
Expect(ftStore.Create(job)).To(Succeed())
|
|
Expect(job.ID).ToNot(BeEmpty())
|
|
|
|
retrieved, err := ftStore.Get(job.ID)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(retrieved.Model).To(Equal("llama3-8b"))
|
|
Expect(retrieved.Status).To(Equal("queued"))
|
|
|
|
// Update status through training lifecycle
|
|
Expect(ftStore.UpdateStatus(job.ID, "loading_model", "Loading base model")).To(Succeed())
|
|
loading, _ := ftStore.Get(job.ID)
|
|
Expect(loading.Status).To(Equal("loading_model"))
|
|
|
|
Expect(ftStore.UpdateStatus(job.ID, "training", "Epoch 1/3")).To(Succeed())
|
|
training, _ := ftStore.Get(job.ID)
|
|
Expect(training.Status).To(Equal("training"))
|
|
Expect(training.Message).To(Equal("Epoch 1/3"))
|
|
|
|
Expect(ftStore.UpdateStatus(job.ID, "saving", "Saving adapter")).To(Succeed())
|
|
Expect(ftStore.UpdateStatus(job.ID, "completed", "Training finished")).To(Succeed())
|
|
completed, _ := ftStore.Get(job.ID)
|
|
Expect(completed.Status).To(Equal("completed"))
|
|
|
|
// Export status
|
|
Expect(ftStore.UpdateExportStatus(job.ID, "completed", "Export done", "llama3-lora-v1")).To(Succeed())
|
|
exported, _ := ftStore.Get(job.ID)
|
|
Expect(exported.ExportStatus).To(Equal("completed"))
|
|
Expect(exported.ExportModelName).To(Equal("llama3-lora-v1"))
|
|
|
|
// List jobs
|
|
allJobs, _ := ftStore.List("")
|
|
Expect(allJobs).To(HaveLen(1))
|
|
|
|
u1Jobs, _ := ftStore.List("u1")
|
|
Expect(u1Jobs).To(HaveLen(1))
|
|
})
|
|
})
|
|
|
|
Context("NATS progress publishing", func() {
|
|
It("should publish fine-tune progress via NATS", func() {
|
|
job := &distributed.FineTuneJobRecord{
|
|
UserID: "u1", Model: "m1", Backend: "b1",
|
|
TrainingType: "lora", TrainingMethod: "sft", Status: "queued",
|
|
}
|
|
Expect(ftStore.Create(job)).To(Succeed())
|
|
|
|
// Subscribe to fine-tune progress
|
|
var received atomic.Int32
|
|
sub, err := infra.NC.Subscribe(messaging.SubjectFineTuneProgress(job.ID), func(data []byte) {
|
|
received.Add(1)
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
defer sub.Unsubscribe()
|
|
|
|
FlushNATS(infra.NC)
|
|
|
|
// Publish progress events simulating training steps
|
|
Expect(infra.NC.Publish(messaging.SubjectFineTuneProgress(job.ID), map[string]any{
|
|
"job_id": job.ID,
|
|
"status": "training",
|
|
"message": "Epoch 1/3, loss=2.5",
|
|
})).To(Succeed())
|
|
|
|
Expect(infra.NC.Publish(messaging.SubjectFineTuneProgress(job.ID), map[string]any{
|
|
"job_id": job.ID,
|
|
"status": "training",
|
|
"message": "Epoch 2/3, loss=1.8",
|
|
})).To(Succeed())
|
|
|
|
Expect(infra.NC.Publish(messaging.SubjectFineTuneProgress(job.ID), map[string]any{
|
|
"job_id": job.ID,
|
|
"status": "completed",
|
|
"message": "Training finished",
|
|
})).To(Succeed())
|
|
|
|
Eventually(func() int32 { return received.Load() }, "5s").Should(Equal(int32(3)))
|
|
|
|
// Verify cancel subject is correctly formed
|
|
cancelSubj := messaging.SubjectFineTuneCancel(job.ID)
|
|
Expect(cancelSubj).To(ContainSubstring(".cancel"))
|
|
})
|
|
})
|
|
|
|
Context("Without --distributed", func() {
|
|
It("should use in-memory state without --distributed", func() {
|
|
appCfg := config.NewApplicationConfig()
|
|
Expect(appCfg.Distributed.Enabled).To(BeFalse())
|
|
|
|
// Without distributed mode, fine-tune jobs use local in-memory
|
|
// state tracking. No PostgreSQL or NATS needed.
|
|
Expect(appCfg.Distributed.NatsURL).To(BeEmpty())
|
|
})
|
|
})
|
|
})
|