diff --git a/core/services/nodes/inflight.go b/core/services/nodes/inflight.go index 60a31a1c6..5ea9d435a 100644 --- a/core/services/nodes/inflight.go +++ b/core/services/nodes/inflight.go @@ -2,6 +2,7 @@ package nodes import ( "context" + "strings" "sync" "time" @@ -64,64 +65,101 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() { } } +// reconcile self-heals stale routing: when a backend reports that the model is +// no longer loaded (the process survived but the model was evicted, while the +// registry still lists it as loaded), it drops the replica row so the next +// request triggers a fresh load instead of routing back here. Without this the +// model stays unreachable until the controller restarts. The original error is +// returned unchanged. +func (c *InFlightTrackingClient) reconcile(err error) error { + if !isModelNotLoaded(err) { + return err + } + rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil { + xlog.Warn("Failed to drop stale replica after model-not-loaded", + "node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr) + } else { + xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads", + "node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex) + } + return err +} + +// isModelNotLoaded reports whether err is a backend "model not loaded" response. +// Backends phrase it as ": model not loaded", so match on the suffix. +func isModelNotLoaded(err error) bool { + return err != nil && strings.Contains(strings.ToLower(err.Error()), "model not loaded") +} + // --- Tracked inference methods --- func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) { defer c.track(ctx)() - return c.Backend.Predict(ctx, in, opts...) + reply, err := c.Backend.Predict(ctx, in, opts...) + return reply, c.reconcile(err) } func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.Backend.PredictStream(ctx, in, f, opts...) + return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) { defer c.track(ctx)() - return c.Backend.Embeddings(ctx, in, opts...) + res, err := c.Backend.Embeddings(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - return c.Backend.GenerateImage(ctx, in, opts...) + res, err := c.Backend.GenerateImage(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - return c.Backend.GenerateVideo(ctx, in, opts...) + res, err := c.Backend.GenerateVideo(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - return c.Backend.TTS(ctx, in, opts...) + res, err := c.Backend.TTS(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.Backend.TTSStream(ctx, in, f, opts...) + return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - return c.Backend.SoundGeneration(ctx, in, opts...) + res, err := c.Backend.SoundGeneration(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) { defer c.track(ctx)() - return c.Backend.AudioTranscription(ctx, in, opts...) + res, err := c.Backend.AudioTranscription(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...) + return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) { defer c.track(ctx)() - return c.Backend.Detect(ctx, in, opts...) + res, err := c.Backend.Detect(ctx, in, opts...) + return res, c.reconcile(err) } func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) { defer c.track(ctx)() - return c.Backend.Rerank(ctx, in, opts...) + res, err := c.Backend.Rerank(ctx, in, opts...) + return res, c.reconcile(err) } diff --git a/core/services/nodes/inflight_test.go b/core/services/nodes/inflight_test.go index 60a5299cc..b689faaf8 100644 --- a/core/services/nodes/inflight_test.go +++ b/core/services/nodes/inflight_test.go @@ -20,9 +20,17 @@ type fakeInFlightTracker struct { mu sync.Mutex increments int decrements int + removed int incrementErr error } +func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error { + f.mu.Lock() + defer f.mu.Unlock() + f.removed++ + return nil +} + func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error { f.mu.Lock() defer f.mu.Unlock() @@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() { Expect(tracker.decrements).To(Equal(1)) }) }) + + Describe("stale model reload (self-heal)", func() { + It("removes the replica when the backend reports the model is not loaded", func() { + backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded") + _, err := client.Predict(context.Background(), &pb.PredictOptions{}) + Expect(err).To(HaveOccurred()) + Expect(tracker.removed).To(Equal(1)) + }) + + It("keeps the replica on an unrelated error", func() { + backend.predictErr = fmt.Errorf("context deadline exceeded") + _, err := client.Predict(context.Background(), &pb.PredictOptions{}) + Expect(err).To(HaveOccurred()) + Expect(tracker.removed).To(Equal(0)) + }) + + It("does not remove on success", func() { + _, err := client.Predict(context.Background(), &pb.PredictOptions{}) + Expect(err).ToNot(HaveOccurred()) + Expect(tracker.removed).To(Equal(0)) + }) + + It("self-heals on a streamed call too", func() { + backend.streamErr = fmt.Errorf("whisper: model not loaded") + err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {}) + Expect(err).To(HaveOccurred()) + Expect(tracker.removed).To(Equal(1)) + }) + }) }) diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go index 4e82d56cf..38aa34b49 100644 --- a/core/services/nodes/interfaces.go +++ b/core/services/nodes/interfaces.go @@ -78,6 +78,9 @@ type ModelLookup interface { type InFlightTracker interface { IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error + // RemoveNodeModel drops a stale replica row so the next request reloads the + // model instead of routing back to a node where it is no longer loaded. + RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error } // NodeManager is used by HTTP endpoints for node registration and lifecycle.