fix(distributed): self-heal stale 'model not loaded' routing

In distributed mode the registry can list a model as loaded on a node
while the worker has evicted it (autonomous LRU eviction, an out-of-band
unload, etc.) yet the backend process survives. The router's cached-node
check only verifies the process is alive (probeHealth), so it routes there
and inference fails with "<backend>: model not loaded" — and stays broken
until the controller restarts and rebuilds its registry.

InFlightTrackingClient now reconciles this: when a tracked inference call
returns a model-not-loaded error, it drops the stale replica row
(RemoveNodeModel) so the next request reloads the model on a healthy node
instead of routing back to the evicted one. The original error is returned
unchanged; only the registry is corrected.

Assisted-by: Claude:claude-opus-4-8 go vet
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-04 22:33:30 +00:00
parent ef80a0e825
commit 9f41e69bc3
3 changed files with 90 additions and 12 deletions

View File

@@ -2,6 +2,7 @@ package nodes
import (
"context"
"strings"
"sync"
"time"
@@ -64,64 +65,101 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
}
}
// reconcile self-heals stale routing: when a backend reports that the model is
// no longer loaded (the process survived but the model was evicted, while the
// registry still lists it as loaded), it drops the replica row so the next
// request triggers a fresh load instead of routing back here. Without this the
// model stays unreachable until the controller restarts. The original error is
// returned unchanged.
func (c *InFlightTrackingClient) reconcile(err error) error {
if !isModelNotLoaded(err) {
return err
}
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
xlog.Warn("Failed to drop stale replica after model-not-loaded",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
} else {
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
}
return err
}
// isModelNotLoaded reports whether err is a backend "model not loaded" response.
// Backends phrase it as "<backend>: model not loaded", so match on the suffix.
func isModelNotLoaded(err error) bool {
return err != nil && strings.Contains(strings.ToLower(err.Error()), "model not loaded")
}
// --- Tracked inference methods ---
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
defer c.track(ctx)()
return c.Backend.Predict(ctx, in, opts...)
reply, err := c.Backend.Predict(ctx, in, opts...)
return reply, c.reconcile(err)
}
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.PredictStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
defer c.track(ctx)()
return c.Backend.Embeddings(ctx, in, opts...)
res, err := c.Backend.Embeddings(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateImage(ctx, in, opts...)
res, err := c.Backend.GenerateImage(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateVideo(ctx, in, opts...)
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.TTS(ctx, in, opts...)
res, err := c.Backend.TTS(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.TTSStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.SoundGeneration(ctx, in, opts...)
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
defer c.track(ctx)()
return c.Backend.AudioTranscription(ctx, in, opts...)
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
defer c.track(ctx)()
return c.Backend.Detect(ctx, in, opts...)
res, err := c.Backend.Detect(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
defer c.track(ctx)()
return c.Backend.Rerank(ctx, in, opts...)
res, err := c.Backend.Rerank(ctx, in, opts...)
return res, c.reconcile(err)
}

View File

@@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
mu sync.Mutex
increments int
decrements int
removed int
incrementErr error
}
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.removed++
return nil
}
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
@@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
Expect(tracker.decrements).To(Equal(1))
})
})
Describe("stale model reload (self-heal)", func() {
It("removes the replica when the backend reports the model is not loaded", func() {
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})
It("keeps the replica on an unrelated error", func() {
backend.predictErr = fmt.Errorf("context deadline exceeded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})
It("does not remove on success", func() {
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).ToNot(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})
It("self-heals on a streamed call too", func() {
backend.streamErr = fmt.Errorf("whisper: model not loaded")
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})
})
})

View File

@@ -78,6 +78,9 @@ type ModelLookup interface {
type InFlightTracker interface {
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
// RemoveNodeModel drops a stale replica row so the next request reloads the
// model instead of routing back to a node where it is no longer loaded.
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
}
// NodeManager is used by HTTP endpoints for node registration and lifecycle.