mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 23:36:49 -04:00
fix(distributed): self-heal stale 'model not loaded' routing (#10181)
* fix(distributed): self-heal stale 'model not loaded' routing In distributed mode the registry can list a model as loaded on a node while the worker has evicted it (autonomous LRU eviction, an out-of-band unload, etc.) yet the backend process survives. The router's cached-node check only verifies the process is alive (probeHealth), so it routes there and inference fails with "<backend>: model not loaded" — and stays broken until the controller restarts and rebuilds its registry. InFlightTrackingClient now reconciles this: when a tracked inference call returns a model-not-loaded error, it drops the stale replica row (RemoveNodeModel) so the next request reloads the model on a healthy node instead of routing back to the evicted one. The original error is returned unchanged; only the registry is corrected. Assisted-by: Claude:claude-opus-4-8 go vet Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactor(distributed): typed model-not-loaded error via gRPC status code Replace the controller-side error-string match with a shared, code-aware helper. Go error types don't survive the gRPC boundary, so the signal is carried as a status code (FailedPrecondition): - pkg/grpc/grpcerrors: ModelNotLoaded(backend) constructor + IsModelNotLoaded(err) checker (status-code first, message fallback for backends not yet migrated). - InFlightTrackingClient.reconcile now uses grpcerrors.IsModelNotLoaded. - Migrate the Go backends that emit this error (parakeet-cpp, cloud-proxy, rfdetr-cpp) to the typed constructor. Acting on a false positive is harmless (the model is just reloaded). Assisted-by: Claude:claude-opus-4-8 go vet Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
|
||||
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cloud-proxy: model not loaded")
|
||||
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
|
||||
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modeTranslate {
|
||||
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
|
||||
@@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,
|
||||
|
||||
cfg := c.cfg.Load()
|
||||
if cfg == nil {
|
||||
return errors.New("cloud-proxy: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("cloud-proxy")
|
||||
}
|
||||
if cfg.mode != modePassthrough {
|
||||
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
@@ -230,7 +231,7 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
|
||||
// (L2).
|
||||
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if p.ctxPtr == 0 {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
|
||||
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
@@ -351,7 +352,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
|
||||
defer close(results)
|
||||
|
||||
if p.ctxPtr == 0 {
|
||||
return errors.New("parakeet-cpp: model not loaded")
|
||||
return grpcerrors.ModelNotLoaded("parakeet-cpp")
|
||||
}
|
||||
if opts.Dst == "" {
|
||||
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
ggrpc "google.golang.org/grpc"
|
||||
@@ -64,64 +65,95 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconcile self-heals stale routing: when a backend reports that the model is
|
||||
// no longer loaded (the process survived but the model was evicted, while the
|
||||
// registry still lists it as loaded), it drops the replica row so the next
|
||||
// request triggers a fresh load instead of routing back here. Without this the
|
||||
// model stays unreachable until the controller restarts. The original error is
|
||||
// returned unchanged.
|
||||
func (c *InFlightTrackingClient) reconcile(err error) error {
|
||||
if !grpcerrors.IsModelNotLoaded(err) {
|
||||
return err
|
||||
}
|
||||
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
|
||||
xlog.Warn("Failed to drop stale replica after model-not-loaded",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
|
||||
} else {
|
||||
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
|
||||
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Tracked inference methods ---
|
||||
|
||||
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Predict(ctx, in, opts...)
|
||||
reply, err := c.Backend.Predict(ctx, in, opts...)
|
||||
return reply, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.PredictStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Embeddings(ctx, in, opts...)
|
||||
res, err := c.Backend.Embeddings(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateImage(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateImage(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTS(ctx, in, opts...)
|
||||
res, err := c.Backend.TTS(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.TTSStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
|
||||
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Detect(ctx, in, opts...)
|
||||
res, err := c.Backend.Detect(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
||||
defer c.track(ctx)()
|
||||
return c.Backend.Rerank(ctx, in, opts...)
|
||||
res, err := c.Backend.Rerank(ctx, in, opts...)
|
||||
return res, c.reconcile(err)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
|
||||
mu sync.Mutex
|
||||
increments int
|
||||
decrements int
|
||||
removed int
|
||||
incrementErr error
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removed++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
@@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
|
||||
Expect(tracker.decrements).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("stale model reload (self-heal)", func() {
|
||||
It("removes the replica when the backend reports the model is not loaded", func() {
|
||||
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
|
||||
It("keeps the replica on an unrelated error", func() {
|
||||
backend.predictErr = fmt.Errorf("context deadline exceeded")
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("does not remove on success", func() {
|
||||
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(0))
|
||||
})
|
||||
|
||||
It("self-heals on a streamed call too", func() {
|
||||
backend.streamErr = fmt.Errorf("whisper: model not loaded")
|
||||
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(tracker.removed).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -78,6 +78,9 @@ type ModelLookup interface {
|
||||
type InFlightTracker interface {
|
||||
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
// RemoveNodeModel drops a stale replica row so the next request reloads the
|
||||
// model instead of routing back to a node where it is no longer loaded.
|
||||
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
|
||||
}
|
||||
|
||||
// NodeManager is used by HTTP endpoints for node registration and lifecycle.
|
||||
|
||||
35
pkg/grpc/grpcerrors/errors.go
Normal file
35
pkg/grpc/grpcerrors/errors.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Package grpcerrors defines well-known error signals shared between backends
|
||||
// (which produce them) and the router (which consumes them). Go error types do
|
||||
// not survive the gRPC boundary, so these conditions are carried as gRPC status
|
||||
// codes and detected via the code rather than by matching the error message.
|
||||
package grpcerrors
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// ModelNotLoaded returns the canonical error a backend returns when it has no
|
||||
// model loaded for the request. It carries codes.FailedPrecondition so callers
|
||||
// can detect it across the gRPC boundary without matching the message string.
|
||||
func ModelNotLoaded(backend string) error {
|
||||
return status.Errorf(codes.FailedPrecondition, "%s: model not loaded", backend)
|
||||
}
|
||||
|
||||
// IsModelNotLoaded reports whether err signals that the backend has no model
|
||||
// loaded. It prefers the typed gRPC status code (FailedPrecondition) and falls
|
||||
// back to the message for backends that have not yet adopted ModelNotLoaded.
|
||||
//
|
||||
// Acting on a false positive is harmless: the only consequence upstream is that
|
||||
// the model is reloaded, which is idempotent.
|
||||
func IsModelNotLoaded(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if status.Code(err) == codes.FailedPrecondition {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "model not loaded")
|
||||
}
|
||||
37
pkg/grpc/grpcerrors/errors_test.go
Normal file
37
pkg/grpc/grpcerrors/errors_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package grpcerrors_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestGRPCErrors(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "grpcerrors test suite")
|
||||
}
|
||||
|
||||
var _ = Describe("grpcerrors", func() {
|
||||
DescribeTable("IsModelNotLoaded",
|
||||
func(err error, want bool) {
|
||||
Expect(grpcerrors.IsModelNotLoaded(err)).To(Equal(want))
|
||||
},
|
||||
Entry("nil", nil, false),
|
||||
Entry("typed via constructor", grpcerrors.ModelNotLoaded("parakeet-cpp"), true),
|
||||
Entry("typed code only", status.Error(codes.FailedPrecondition, "anything"), true),
|
||||
Entry("legacy message (Unknown code)", errors.New("parakeet-cpp: model not loaded"), true),
|
||||
Entry("legacy message mixed case", errors.New("Backend: Model Not Loaded"), true),
|
||||
Entry("unrelated error", errors.New("context deadline exceeded"), false),
|
||||
Entry("unrelated grpc code", status.Error(codes.Unavailable, "connection refused"), false),
|
||||
)
|
||||
|
||||
It("ModelNotLoaded carries FailedPrecondition", func() {
|
||||
Expect(status.Code(grpcerrors.ModelNotLoaded("whisper"))).To(Equal(codes.FailedPrecondition))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user