mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-27 01:47:18 -04:00
refactor(agentpool): back agent tasks with SyncedMap for cross-replica consistency
AgentJobService.ListTasks read the process-local tasks map only, while ListJobs already read through the DB persister + dispatcher NATS - so in distributed mode a task created on one replica was invisible to the others. Back tasks with a syncstate.SyncedMap keyed by task ID (value schema.Task, the exact REST shape); jobs are left untouched. - Store adapter (task_syncstore.go) over the existing JobPersister (LoadTasks/SaveTask/DeleteTask); reads svc.persister/userID live so a persister swap needs no rebuild. No new persister methods required. - Task reads -> SyncedMap.List/Get; create/update -> Set (write-through + broadcast); delete -> Delete. The file persister now owns its own task set so the write-through path does not re-enter the SyncedMap lock (deadlock guard). - The distributed NATS client is not available at construction (start() precedes initDistributed), so it is injected via SetTaskSyncNATS, which rebuilds the still-empty map before Start/hydrate. Wired at the main, restart, and per-user (UserServicesManager) distributed sites. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Assisted-by: Claude:claude-opus-4-8 [Claude Code]
This commit is contained in:
@@ -37,6 +37,8 @@ func (a *Application) RestartAgentJobService() error {
|
||||
if d.JobStore != nil {
|
||||
agentJobService.SetDistributedJobStore(d.JobStore)
|
||||
}
|
||||
// Keep agent tasks consistent across replicas (same client the dispatcher uses).
|
||||
agentJobService.SetTaskSyncNATS(d.Nats)
|
||||
}
|
||||
|
||||
// Start the service
|
||||
|
||||
@@ -604,6 +604,10 @@ func (a *Application) StartAgentPool() {
|
||||
usm.SetJobDBStore(s)
|
||||
}
|
||||
}
|
||||
// Keep per-user agent tasks consistent across replicas (nil in standalone).
|
||||
if d := a.Distributed(); d != nil {
|
||||
usm.SetJobSyncNATS(d.Nats)
|
||||
}
|
||||
aps.SetUserServicesManager(usm)
|
||||
|
||||
a.agentPoolService.Store(aps)
|
||||
|
||||
@@ -280,6 +280,9 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
if application.agentJobService != nil {
|
||||
application.agentJobService.SetDistributedBackends(distSvc.Dispatcher)
|
||||
application.agentJobService.SetDistributedJobStore(distSvc.JobStore)
|
||||
// Keep agent tasks consistent across replicas (jobs already sync via the
|
||||
// dispatcher + DB read-through). Same NATS client the dispatcher uses.
|
||||
application.agentJobService.SetTaskSyncNATS(distSvc.Nats)
|
||||
}
|
||||
// Wire skill store into AgentPoolService (wired at pool start time via closure)
|
||||
// The actual wiring happens in StartAgentPool since the pool doesn't exist yet.
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -43,8 +45,18 @@ type AgentJobService struct {
|
||||
configLoader *config.ModelConfigLoader
|
||||
evaluator *templates.Evaluator
|
||||
|
||||
// tasks is the cross-replica task store: an in-memory map kept consistent
|
||||
// across replicas via NATS, with read-through to the configured persister
|
||||
// (file in standalone, PostgreSQL in distributed). Unlike jobs - which already
|
||||
// converge via the dispatcher + DB read-through - tasks previously read
|
||||
// in-memory only, so ListTasks went stale on non-originating replicas.
|
||||
tasks *syncstate.SyncedMap[string, schema.Task]
|
||||
// taskNats is the distributed NATS client backing the tasks SyncedMap. It is
|
||||
// not available at construction time, so it is injected via SetTaskSyncNATS
|
||||
// during distributed wiring; nil keeps tasks in-memory-only (standalone).
|
||||
taskNats messaging.MessagingClient
|
||||
|
||||
// Storage (in-memory primary, persister for secondary persistence)
|
||||
tasks *xsync.SyncedMap[string, schema.Task]
|
||||
jobs *xsync.SyncedMap[string, schema.Job]
|
||||
persister JobPersister
|
||||
userID string // Scoping: empty for global (main service), set for per-user instances
|
||||
@@ -96,6 +108,31 @@ func (s *AgentJobService) SetDistributedJobStore(store *jobs.JobStore) {
|
||||
s.persister = &dbJobPersister{store: store}
|
||||
}
|
||||
|
||||
// SetTaskSyncNATS wires the distributed NATS client used to keep agent *tasks*
|
||||
// consistent across replicas (jobs already converge via the dispatcher + DB
|
||||
// read-through, so they are left untouched). The client is not available when the
|
||||
// service is constructed, so it is injected here during distributed wiring and the
|
||||
// tasks SyncedMap is rebuilt to pick it up. It is always called before Start /
|
||||
// hydrate, while the map is still empty, so rebuilding loses no state. Passing nil
|
||||
// (standalone) keeps the map in-memory-only with no broadcast.
|
||||
func (s *AgentJobService) SetTaskSyncNATS(nats messaging.MessagingClient) {
|
||||
s.taskNats = nats
|
||||
s.buildTasksMap()
|
||||
}
|
||||
|
||||
// buildTasksMap (re)constructs the cross-replica tasks SyncedMap from the current
|
||||
// taskNats. The Store adapter reads s.persister/s.userID live, so a persister swap
|
||||
// (SetDistributedJobStore) needs no rebuild; only the NATS client, fixed at
|
||||
// New-time, forces one - hence SetTaskSyncNATS calls this.
|
||||
func (s *AgentJobService) buildTasksMap() {
|
||||
s.tasks = syncstate.New(syncstate.Config[string, schema.Task]{
|
||||
Name: "agent.tasks",
|
||||
Key: func(t schema.Task) string { return t.ID },
|
||||
Nats: s.taskNats,
|
||||
Store: &taskStoreAdapter{svc: s},
|
||||
})
|
||||
}
|
||||
|
||||
// Dispatcher returns the distributed dispatcher (nil if not in distributed mode).
|
||||
func (s *AgentJobService) Dispatcher() DistributedDispatcher {
|
||||
return s.dispatcher
|
||||
@@ -106,13 +143,6 @@ func (s *AgentJobService) DBStore() *jobs.JobStore {
|
||||
return s.rawDBStore
|
||||
}
|
||||
|
||||
// saveTasks persists tasks via the configured persister (file or DB).
|
||||
func (s *AgentJobService) saveTasks(task schema.Task) {
|
||||
if err := s.persister.SaveTask(s.userID, task); err != nil {
|
||||
xlog.Warn("Failed to persist task", "error", err, "task_id", task.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// saveJobs persists jobs via the configured persister (file or DB).
|
||||
func (s *AgentJobService) saveJobs(job schema.Job) {
|
||||
if err := s.persister.SaveJob(s.userID, job); err != nil {
|
||||
@@ -129,18 +159,8 @@ func (s *AgentJobService) LoadFromDB() {
|
||||
|
||||
// loadFromPersister loads tasks and jobs from the configured persister into memory.
|
||||
func (s *AgentJobService) loadFromPersister() {
|
||||
if tasks, err := s.persister.LoadTasks(s.userID); err != nil {
|
||||
if err := s.hydrateTasks(s.appConfig.Context); err != nil {
|
||||
xlog.Warn("Failed to load tasks from persister", "error", err)
|
||||
} else {
|
||||
for _, task := range tasks {
|
||||
s.tasks.Set(task.ID, task)
|
||||
if task.Enabled && task.Cron != "" {
|
||||
if err := s.ScheduleCronTask(task); err != nil {
|
||||
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("Loaded tasks from persister", "count", len(tasks))
|
||||
}
|
||||
|
||||
if loadedJobs, err := s.persister.LoadJobs(s.userID); err != nil {
|
||||
@@ -153,6 +173,27 @@ func (s *AgentJobService) loadFromPersister() {
|
||||
}
|
||||
}
|
||||
|
||||
// hydrateTasks loads tasks into the cross-replica SyncedMap and (re)schedules
|
||||
// cron entries for enabled tasks. Hydration goes through the SyncedMap's Store
|
||||
// read-through (Start), not Set, so it neither re-persists nor re-broadcasts the
|
||||
// loaded tasks. Each service instance hydrates exactly once: the main service via
|
||||
// Start -> loadFromPersister, per-user services via LoadFromDB or LoadTasksFromFile.
|
||||
func (s *AgentJobService) hydrateTasks(ctx context.Context) error {
|
||||
if err := s.tasks.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
tasks := s.tasks.List()
|
||||
for _, task := range tasks {
|
||||
if task.Enabled && task.Cron != "" {
|
||||
if err := s.ScheduleCronTask(task); err != nil {
|
||||
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("Loaded tasks from persister", "count", len(tasks))
|
||||
return nil
|
||||
}
|
||||
|
||||
// JobExecution represents a job to be executed
|
||||
type JobExecution struct {
|
||||
Job schema.Job
|
||||
@@ -200,21 +241,19 @@ func NewAgentJobServiceWithPaths(
|
||||
) *AgentJobService {
|
||||
retentionDays := cmp.Or(appConfig.AgentJobRetentionDays, 30)
|
||||
|
||||
tasks := xsync.NewSyncedMap[string, schema.Task]()
|
||||
jobsMap := xsync.NewSyncedMap[string, schema.Job]()
|
||||
|
||||
return &AgentJobService{
|
||||
s := &AgentJobService{
|
||||
appConfig: appConfig,
|
||||
modelLoader: modelLoader,
|
||||
configLoader: configLoader,
|
||||
evaluator: evaluator,
|
||||
tasks: tasks,
|
||||
jobs: jobsMap,
|
||||
persister: &fileJobPersister{
|
||||
tasks: tasks,
|
||||
jobs: jobsMap,
|
||||
tasksFile: tasksFile,
|
||||
jobsFile: jobsFile,
|
||||
taskSet: make(map[string]schema.Task),
|
||||
},
|
||||
jobQueue: make(chan JobExecution, 100), // Buffer for 100 jobs
|
||||
cancellations: xsync.NewSyncedMap[string, context.CancelFunc](),
|
||||
@@ -222,25 +261,17 @@ func NewAgentJobServiceWithPaths(
|
||||
cronEntries: xsync.NewSyncedMap[string, cron.EntryID](),
|
||||
retentionDays: retentionDays,
|
||||
}
|
||||
// Build the cross-replica tasks map standalone (nil NATS); SetTaskSyncNATS
|
||||
// rebuilds it with the distributed client once that is available, before Start.
|
||||
s.buildTasksMap()
|
||||
return s
|
||||
}
|
||||
|
||||
// LoadTasksFromFile loads tasks from the persister into the in-memory map
|
||||
// and schedules cron entries. Named "FromFile" for backward compat; in DB
|
||||
// mode it loads from the database.
|
||||
func (s *AgentJobService) LoadTasksFromFile() error {
|
||||
tasks, err := s.persister.LoadTasks(s.userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, task := range tasks {
|
||||
s.tasks.Set(task.ID, task)
|
||||
if task.Enabled && task.Cron != "" {
|
||||
if err := s.ScheduleCronTask(task); err != nil {
|
||||
xlog.Warn("Failed to schedule cron task on load", "error", err, "task_id", task.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return s.hydrateTasks(s.appConfig.Context)
|
||||
}
|
||||
|
||||
// SaveTasksToFile flushes the current tasks map via the persister. File
|
||||
@@ -293,8 +324,12 @@ func (s *AgentJobService) CreateTask(task schema.Task) (string, error) {
|
||||
task.Enabled = true // Default to enabled
|
||||
}
|
||||
|
||||
// Store task
|
||||
s.tasks.Set(id, task)
|
||||
// Store task: Set updates the in-memory map, write-throughs to the persister
|
||||
// (file or DB), and broadcasts the create to peer replicas. Background ctx
|
||||
// because CreateTask carries no request ctx (mirrors the finetune service).
|
||||
if err := s.tasks.Set(context.Background(), task); err != nil {
|
||||
return "", fmt.Errorf("failed to persist task: %w", err)
|
||||
}
|
||||
|
||||
// Schedule cron if enabled and has cron expression
|
||||
if task.Enabled && task.Cron != "" {
|
||||
@@ -303,16 +338,15 @@ func (s *AgentJobService) CreateTask(task schema.Task) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
s.saveTasks(task)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// UpdateTask updates an existing task
|
||||
func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
||||
if !s.tasks.Exists(id) {
|
||||
existing, ok := s.tasks.Get(id)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||
}
|
||||
existing := s.tasks.Get(id)
|
||||
|
||||
// Preserve ID and CreatedAt
|
||||
task.ID = id
|
||||
@@ -324,8 +358,10 @@ func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
||||
s.UnscheduleCronTask(id)
|
||||
}
|
||||
|
||||
// Store updated task
|
||||
s.tasks.Set(id, task)
|
||||
// Store updated task: write-through + broadcast (see CreateTask).
|
||||
if err := s.tasks.Set(context.Background(), task); err != nil {
|
||||
return fmt.Errorf("failed to persist task: %w", err)
|
||||
}
|
||||
|
||||
// Schedule new cron if enabled and has cron expression
|
||||
if task.Enabled && task.Cron != "" {
|
||||
@@ -334,24 +370,22 @@ func (s *AgentJobService) UpdateTask(id string, task schema.Task) error {
|
||||
}
|
||||
}
|
||||
|
||||
s.saveTasks(task)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTask deletes a task
|
||||
func (s *AgentJobService) DeleteTask(id string) error {
|
||||
if !s.tasks.Exists(id) {
|
||||
if _, ok := s.tasks.Get(id); !ok {
|
||||
return fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||
}
|
||||
|
||||
// Unschedule cron
|
||||
s.UnscheduleCronTask(id)
|
||||
|
||||
// Remove from memory
|
||||
s.tasks.Delete(id)
|
||||
|
||||
if err := s.persister.DeleteTask(id); err != nil {
|
||||
xlog.Warn("Failed to delete task from persister", "error", err, "task_id", id)
|
||||
// Delete removes from the in-memory map, deletes from the persister, and
|
||||
// broadcasts the removal to peer replicas.
|
||||
if err := s.tasks.Delete(context.Background(), id); err != nil {
|
||||
xlog.Warn("Failed to delete task from store", "error", err, "task_id", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -359,8 +393,8 @@ func (s *AgentJobService) DeleteTask(id string) error {
|
||||
|
||||
// GetTask retrieves a task by ID
|
||||
func (s *AgentJobService) GetTask(id string) (*schema.Task, error) {
|
||||
task := s.tasks.Get(id)
|
||||
if task.ID == "" {
|
||||
task, ok := s.tasks.Get(id)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: %s", ErrTaskNotFound, id)
|
||||
}
|
||||
return &task, nil
|
||||
@@ -368,7 +402,7 @@ func (s *AgentJobService) GetTask(id string) (*schema.Task, error) {
|
||||
|
||||
// ListTasks returns all tasks, sorted by creation date (newest first)
|
||||
func (s *AgentJobService) ListTasks() []schema.Task {
|
||||
tasks := s.tasks.Values()
|
||||
tasks := s.tasks.List()
|
||||
// Sort by CreatedAt descending (newest first), then by Name for stability
|
||||
slices.SortFunc(tasks, func(a, b schema.Task) int {
|
||||
if a.CreatedAt.Equal(b.CreatedAt) {
|
||||
@@ -397,8 +431,8 @@ func (s *AgentJobService) buildPrompt(templateStr string, params map[string]stri
|
||||
// ExecuteJob creates and queues a job for execution
|
||||
// multimedia can be nil for backward compatibility
|
||||
func (s *AgentJobService) ExecuteJob(taskID string, params map[string]string, triggeredBy string, multimedia *schema.MultimediaAttachment) (string, error) {
|
||||
task := s.tasks.Get(taskID)
|
||||
if task.ID == "" {
|
||||
task, ok := s.tasks.Get(taskID)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w: %s", ErrTaskNotFound, taskID)
|
||||
}
|
||||
|
||||
@@ -1451,6 +1485,12 @@ func (s *AgentJobService) Stop() error {
|
||||
if s.cronScheduler != nil {
|
||||
s.cronScheduler.Stop()
|
||||
}
|
||||
// Release the tasks SyncedMap subscription / background workers.
|
||||
if s.tasks != nil {
|
||||
if err := s.tasks.Close(); err != nil {
|
||||
xlog.Warn("Error closing tasks sync map", "error", err)
|
||||
}
|
||||
}
|
||||
xlog.Info("AgentJobService stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,24 +14,38 @@ import (
|
||||
)
|
||||
|
||||
// fileJobPersister persists tasks and jobs to JSON files.
|
||||
// It holds references to the service's syncmaps and serializes the entire
|
||||
// map contents on each save (bulk write). Reads at runtime return nil
|
||||
// (the in-memory map is the authoritative source); LoadTasks/LoadJobs
|
||||
// are used only at startup to bootstrap the syncmaps.
|
||||
//
|
||||
// Jobs serialize the service's in-memory jobs syncmap on each save (bulk write).
|
||||
// Tasks are kept in this persister's own taskSet map instead: the tasks SyncedMap
|
||||
// calls SaveTask/DeleteTask while holding its internal lock (write-through), so
|
||||
// reading back the SyncedMap here would re-enter that lock and deadlock. The
|
||||
// self-contained taskSet, seeded by LoadTasks, lets a per-task write rewrite the
|
||||
// whole bulk file without touching the SyncedMap.
|
||||
//
|
||||
// Runtime reads (GetJob/ListJobs) return nil (the in-memory state is the
|
||||
// authoritative source); LoadTasks/LoadJobs bootstrap state at startup.
|
||||
type fileJobPersister struct {
|
||||
tasks *xsync.SyncedMap[string, schema.Task]
|
||||
jobs *xsync.SyncedMap[string, schema.Job]
|
||||
tasksFile string
|
||||
jobsFile string
|
||||
mu sync.Mutex
|
||||
// taskSet is the persister's own view of all tasks, seeded by LoadTasks and
|
||||
// updated by SaveTask/DeleteTask. The bulk JSON file is rewritten from it.
|
||||
taskSet map[string]schema.Task
|
||||
}
|
||||
|
||||
func (p *fileJobPersister) SaveTask(_ string, _ schema.Task) error {
|
||||
return p.saveTasksToFile()
|
||||
func (p *fileJobPersister) SaveTask(_ string, task schema.Task) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.taskSet[task.ID] = task
|
||||
return p.writeTasksLocked()
|
||||
}
|
||||
|
||||
func (p *fileJobPersister) DeleteTask(_ string) error {
|
||||
return p.saveTasksToFile()
|
||||
func (p *fileJobPersister) DeleteTask(taskID string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
delete(p.taskSet, taskID)
|
||||
return p.writeTasksLocked()
|
||||
}
|
||||
|
||||
func (p *fileJobPersister) SaveJob(_ string, _ schema.Job) error {
|
||||
@@ -43,7 +57,9 @@ func (p *fileJobPersister) DeleteJob(_ string) error {
|
||||
}
|
||||
|
||||
func (p *fileJobPersister) FlushTasks() error {
|
||||
return p.saveTasksToFile()
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.writeTasksLocked()
|
||||
}
|
||||
|
||||
func (p *fileJobPersister) FlushJobs() error {
|
||||
@@ -83,6 +99,12 @@ func (p *fileJobPersister) LoadTasks(_ string) ([]schema.Task, error) {
|
||||
return nil, fmt.Errorf("failed to parse tasks file: %w", err)
|
||||
}
|
||||
|
||||
// Seed the in-memory set so subsequent per-task SaveTask/DeleteTask merge into
|
||||
// (rather than overwrite) the persisted tasks when the bulk file is rewritten.
|
||||
for _, t := range tf.Tasks {
|
||||
p.taskSet[t.ID] = t
|
||||
}
|
||||
|
||||
xlog.Info("Loaded tasks from file", "count", len(tf.Tasks))
|
||||
return tf.Tasks, nil
|
||||
}
|
||||
@@ -118,19 +140,20 @@ func (p *fileJobPersister) CleanupOldJobs(_ time.Duration) (int64, error) {
|
||||
return 0, nil // cleanup handled via in-memory filtering
|
||||
}
|
||||
|
||||
// saveTasksToFile serializes the entire tasks map to the JSON file.
|
||||
func (p *fileJobPersister) saveTasksToFile() error {
|
||||
// writeTasksLocked serializes the persister's task set to the JSON file. Callers
|
||||
// must hold p.mu.
|
||||
func (p *fileJobPersister) writeTasksLocked() error {
|
||||
if p.tasksFile == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
tf := schema.TasksFile{
|
||||
Tasks: p.tasks.Values(),
|
||||
tasks := make([]schema.Task, 0, len(p.taskSet))
|
||||
for _, t := range p.taskSet {
|
||||
tasks = append(tasks, t)
|
||||
}
|
||||
|
||||
tf := schema.TasksFile{Tasks: tasks}
|
||||
|
||||
data, err := json.MarshalIndent(tf, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal tasks: %w", err)
|
||||
|
||||
@@ -20,28 +20,26 @@ var _ = Describe("JobPersister", func() {
|
||||
Context("fileJobPersister", func() {
|
||||
var (
|
||||
p *fileJobPersister
|
||||
tasks *xsync.SyncedMap[string, schema.Task]
|
||||
jobsMap *xsync.SyncedMap[string, schema.Job]
|
||||
tmpDir string
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
tmpDir = GinkgoT().TempDir()
|
||||
tasks = xsync.NewSyncedMap[string, schema.Task]()
|
||||
jobsMap = xsync.NewSyncedMap[string, schema.Job]()
|
||||
p = &fileJobPersister{
|
||||
tasks: tasks,
|
||||
jobs: jobsMap,
|
||||
tasksFile: filepath.Join(tmpDir, "tasks.json"),
|
||||
jobsFile: filepath.Join(tmpDir, "jobs.json"),
|
||||
// taskSet is the persister's own task view (decoupled from the tasks
|
||||
// SyncedMap to avoid re-entering its lock during write-through).
|
||||
taskSet: make(map[string]schema.Task),
|
||||
}
|
||||
})
|
||||
|
||||
It("SaveTask writes all tasks to file", func() {
|
||||
tasks.Set("t1", schema.Task{ID: "t1", Name: "Task One", Model: "m", Prompt: "p"})
|
||||
tasks.Set("t2", schema.Task{ID: "t2", Name: "Task Two", Model: "m", Prompt: "p"})
|
||||
|
||||
Expect(p.SaveTask("", schema.Task{})).To(Succeed())
|
||||
Expect(p.SaveTask("", schema.Task{ID: "t1", Name: "Task One", Model: "m", Prompt: "p"})).To(Succeed())
|
||||
Expect(p.SaveTask("", schema.Task{ID: "t2", Name: "Task Two", Model: "m", Prompt: "p"})).To(Succeed())
|
||||
|
||||
// Verify file contents
|
||||
data, err := os.ReadFile(p.tasksFile)
|
||||
@@ -52,11 +50,9 @@ var _ = Describe("JobPersister", func() {
|
||||
})
|
||||
|
||||
It("DeleteTask writes updated tasks to file", func() {
|
||||
tasks.Set("t1", schema.Task{ID: "t1", Name: "Keep"})
|
||||
tasks.Set("t2", schema.Task{ID: "t2", Name: "Delete"})
|
||||
Expect(p.SaveTask("", schema.Task{ID: "t1", Name: "Keep"})).To(Succeed())
|
||||
Expect(p.SaveTask("", schema.Task{ID: "t2", Name: "Delete"})).To(Succeed())
|
||||
|
||||
// Simulate deletion from memory (caller does this before calling persister)
|
||||
tasks.Delete("t2")
|
||||
Expect(p.DeleteTask("t2")).To(Succeed())
|
||||
|
||||
data, err := os.ReadFile(p.tasksFile)
|
||||
|
||||
152
core/services/agentpool/task_sync_test.go
Normal file
152
core/services/agentpool/task_sync_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package agentpool
|
||||
|
||||
// White-box tests (package agentpool) so a spec can build two AgentJobService
|
||||
// instances sharing one in-memory bus and assert that agent *tasks* converge
|
||||
// across replicas - the bug this migration fixes (ListTasks used to read
|
||||
// in-memory only, so a task created on replica A was invisible on replica B).
|
||||
// Jobs are deliberately untouched here: they already converge via the dispatcher
|
||||
// + DB read-through.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
||||
"github.com/mudler/LocalAI/core/services/testutil"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
)
|
||||
|
||||
// newTaskSyncService builds an AgentJobService wired to the given bus and a
|
||||
// throwaway data dir (so the file persister has somewhere to write). Model/config
|
||||
// loaders are nil because the task sync paths under test never touch them.
|
||||
func newTaskSyncService(bus messaging.MessagingClient) *AgentJobService {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
sysState := &system.SystemState{}
|
||||
sysState.Model.ModelsPath = tmpDir
|
||||
appConfig := config.NewApplicationConfig(
|
||||
config.WithDynamicConfigDir(tmpDir),
|
||||
config.WithContext(context.Background()),
|
||||
)
|
||||
appConfig.SystemState = sysState
|
||||
|
||||
svc := NewAgentJobServiceWithPaths(appConfig, nil, nil, nil,
|
||||
// Distinct per-replica files so the file persister write-through never
|
||||
// crosses replicas: convergence here must be proven via the bus alone.
|
||||
tmpDir+"/tasks.json", tmpDir+"/jobs.json")
|
||||
svc.SetTaskSyncNATS(bus)
|
||||
return svc
|
||||
}
|
||||
|
||||
var _ = Describe("AgentJobService task cross-replica sync", func() {
|
||||
Describe("two replicas sharing one bus", func() {
|
||||
var (
|
||||
bus *testutil.FakeBus
|
||||
a, b *AgentJobService
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
// One shared bus, two replicas: exactly the distributed topology where a
|
||||
// round-robin request may land on a replica that did not originate the
|
||||
// change.
|
||||
bus = testutil.NewFakeBus()
|
||||
a = newTaskSyncService(bus)
|
||||
b = newTaskSyncService(bus)
|
||||
// Start hydrates (empty here) and subscribes both replicas to deltas.
|
||||
Expect(a.Start(context.Background())).To(Succeed())
|
||||
Expect(b.Start(context.Background())).To(Succeed())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(a.Stop()).To(Succeed())
|
||||
Expect(b.Stop()).To(Succeed())
|
||||
})
|
||||
|
||||
It("makes a task created on A visible via B's GetTask and ListTasks", func() {
|
||||
id, err := a.CreateTask(schema.Task{Name: "Shared", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
got, err := b.GetTask(id)
|
||||
Expect(err).NotTo(HaveOccurred(), "B must see a task A just created")
|
||||
Expect(got.Name).To(Equal("Shared"))
|
||||
|
||||
listed := b.ListTasks()
|
||||
Expect(listed).To(HaveLen(1))
|
||||
Expect(listed[0].ID).To(Equal(id))
|
||||
})
|
||||
|
||||
It("propagates a task update from A to B", func() {
|
||||
id, err := a.CreateTask(schema.Task{Name: "Before", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(a.UpdateTask(id, schema.Task{Name: "After", Model: "m", Prompt: "p"})).To(Succeed())
|
||||
|
||||
got, err := b.GetTask(id)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(got.Name).To(Equal("After"), "an update on A must be visible on B")
|
||||
})
|
||||
|
||||
It("removes a task from B when it is deleted on A", func() {
|
||||
id, err := a.CreateTask(schema.Task{Name: "Doomed", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = b.GetTask(id)
|
||||
Expect(err).NotTo(HaveOccurred(), "precondition: B must have the task before the delete")
|
||||
|
||||
Expect(a.DeleteTask(id)).To(Succeed())
|
||||
|
||||
_, err = b.GetTask(id)
|
||||
Expect(err).To(HaveOccurred(), "a delete on A must remove the task from B")
|
||||
Expect(b.ListTasks()).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not re-broadcast a delta it received (echo-loop guard)", func() {
|
||||
subject := messaging.SubjectSyncStateDelta("agent.tasks")
|
||||
|
||||
_, err := a.CreateTask(schema.Task{Name: "Once", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Exactly one publish: A's create. B applies it without re-publishing,
|
||||
// otherwise this would be 2+ and a real bus would storm.
|
||||
Expect(bus.PublishCount(subject)).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ListTasks ordering and scoping", func() {
|
||||
var svc *AgentJobService
|
||||
|
||||
BeforeEach(func() {
|
||||
svc = newTaskSyncService(testutil.NewFakeBus())
|
||||
Expect(svc.Start(context.Background())).To(Succeed())
|
||||
})
|
||||
AfterEach(func() { Expect(svc.Stop()).To(Succeed()) })
|
||||
|
||||
It("sorts newest-first, breaking ties by name", func() {
|
||||
// CreateTask stamps CreatedAt with time.Now(); space them out so ordering
|
||||
// is deterministic rather than relying on the sub-millisecond gap.
|
||||
oldID, err := svc.CreateTask(schema.Task{Name: "Old", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
newID, err := svc.CreateTask(schema.Task{Name: "New", Model: "m", Prompt: "p"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
listed := svc.ListTasks()
|
||||
Expect(listed).To(HaveLen(2))
|
||||
Expect(listed[0].ID).To(Equal(newID), "newest first")
|
||||
Expect(listed[1].ID).To(Equal(oldID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("compile-time adapter contract", func() {
|
||||
It("satisfies syncstate.Store for tasks", func() {
|
||||
// Mirrors the var assertion in task_syncstore.go; keeps the type
|
||||
// referenced from a spec so drift surfaces here too.
|
||||
var _ syncstate.Store[string, schema.Task] = (*taskStoreAdapter)(nil)
|
||||
Expect(&taskStoreAdapter{}).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
47
core/services/agentpool/task_syncstore.go
Normal file
47
core/services/agentpool/task_syncstore.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package agentpool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/syncstate"
|
||||
)
|
||||
|
||||
// taskStoreAdapter bridges the existing JobPersister (file- or DB-backed) to the
|
||||
// generic syncstate.Store the tasks SyncedMap consumes. Only tasks are migrated:
|
||||
// jobs already converge across replicas via the dispatcher (NATS) plus the DB
|
||||
// read-through in ListJobs/GetJob, whereas ListTasks read in-memory only and so
|
||||
// went stale on replicas that did not originate the change.
|
||||
//
|
||||
// The adapter reads svc.persister and svc.userID live (rather than capturing
|
||||
// them) because both are configured by setters - SetDistributedJobStore swaps the
|
||||
// file persister for the DB one, SetUserID scopes per-user queries - AFTER the
|
||||
// service, and thus this adapter, is constructed. Reading them at call time means
|
||||
// the SyncedMap never has to be rebuilt when the persister is swapped.
|
||||
//
|
||||
// The SyncedMap value type is schema.Task: the exact shape ListTasks returns, so
|
||||
// reads need no conversion and REST responses are provably unchanged.
|
||||
type taskStoreAdapter struct {
|
||||
svc *AgentJobService
|
||||
}
|
||||
|
||||
// compile-time assertion that the adapter satisfies the component's Store.
|
||||
var _ syncstate.Store[string, schema.Task] = (*taskStoreAdapter)(nil)
|
||||
|
||||
// List hydrates the map from durable storage on Start/reconnect: the file's task
|
||||
// list (standalone) or every task row (DB / distributed).
|
||||
func (a *taskStoreAdapter) List(_ context.Context) ([]schema.Task, error) {
|
||||
return a.svc.persister.LoadTasks(a.svc.userID)
|
||||
}
|
||||
|
||||
// Upsert write-through persists a single task created/updated locally; the
|
||||
// SyncedMap then broadcasts the delta to peers.
|
||||
func (a *taskStoreAdapter) Upsert(_ context.Context, task schema.Task) error {
|
||||
return a.svc.persister.SaveTask(a.svc.userID, task)
|
||||
}
|
||||
|
||||
// Delete write-through removes a task locally; the SyncedMap then broadcasts the
|
||||
// removal to peers.
|
||||
func (a *taskStoreAdapter) Delete(_ context.Context, id string) error {
|
||||
return a.svc.persister.DeleteTask(id)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/mudler/LocalAGI/webui/collections"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -28,6 +29,9 @@ type UserServicesManager struct {
|
||||
// Shared distributed backends (set once, inherited by per-user job services)
|
||||
jobDispatcher DistributedDispatcher
|
||||
jobDBStore *jobs.JobStore
|
||||
// jobNats keeps per-user agent tasks consistent across replicas (nil in
|
||||
// standalone). Inherited by each per-user AgentJobService.
|
||||
jobNats messaging.MessagingClient
|
||||
}
|
||||
|
||||
// NewUserServicesManager creates a new UserServicesManager.
|
||||
@@ -162,6 +166,10 @@ func (m *UserServicesManager) GetJobs(userID string) (*AgentJobService, error) {
|
||||
if m.jobDispatcher != nil {
|
||||
svc.SetDistributedBackends(m.jobDispatcher)
|
||||
}
|
||||
// Inherit the NATS client so per-user tasks broadcast across replicas. Must be
|
||||
// set before the hydrate below (LoadFromDB / LoadTasksFromFile) so the tasks
|
||||
// SyncedMap is rebuilt with the client while it is still empty.
|
||||
svc.SetTaskSyncNATS(m.jobNats)
|
||||
if m.jobDBStore != nil {
|
||||
svc.SetDistributedJobStore(m.jobDBStore)
|
||||
// Load tasks/jobs from DB immediately (per-user services skip Start())
|
||||
@@ -189,6 +197,12 @@ func (m *UserServicesManager) SetJobDBStore(s *jobs.JobStore) {
|
||||
m.jobDBStore = s
|
||||
}
|
||||
|
||||
// SetJobSyncNATS sets the NATS client used to keep per-user agent tasks consistent
|
||||
// across replicas.
|
||||
func (m *UserServicesManager) SetJobSyncNATS(nats messaging.MessagingClient) {
|
||||
m.jobNats = nats
|
||||
}
|
||||
|
||||
// ListAllUserIDs returns all user IDs that have scoped data directories.
|
||||
func (m *UserServicesManager) ListAllUserIDs() ([]string, error) {
|
||||
return m.storage.ListUserDirs()
|
||||
|
||||
Reference in New Issue
Block a user