Files
LocalAI/core/services/nodes/inflight_test.go
LocalAI [bot] fba8c9c498 fix(distributed): track in-flight for non-LLM inference methods (VAD, diarize, voice, ...) (#10238)
fix(distributed): track in-flight for non-LLM inference methods

InFlightTrackingClient only wrapped a subset of the grpc.Backend
inference methods (Predict, Embeddings, TTS, AudioTranscription, Detect,
Rerank, ...). Methods like VAD were left as embedded passthrough, so
track() never ran for them.

In distributed mode every model is loaded with in_flight=1 as a
reservation; that reservation is only released by the OnFirstComplete
callback, which fires after the first *tracked* inference call completes.
A VAD-only model (e.g. silero-vad) is only ever invoked through VAD, which
was not tracked, so the reservation is never released and in-flight stays
pinned at 1 forever - which also blocks the router's idle-eviction logic.

Wrap the remaining unary inference methods (VAD, Diarize, Face*, Voice*,
TokenClassify, Score, AudioEncode, AudioDecode, AudioTransform) with the
same track()/reconcile() pattern. The three bidi-stream constructors
(AudioTransformStream, AudioToAudioStream, Forward) are deliberately left
as passthrough - their inference spans the stream lifetime, not the
constructor call, so track() there would fire onFirstComplete before any
data flows.

Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-10 16:29:50 +02:00

435 lines
14 KiB
Go

package nodes
import (
"context"
"fmt"
"sync"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
)
// --- Fakes ---

// fakeInFlightTracker is an InFlightTracker test double that only counts how
// many times each of its methods is invoked. Every field below mu is guarded
// by mu.
type fakeInFlightTracker struct {
	mu sync.Mutex

	increments, decrements, removed int

	// incrementErr, when non-nil, is returned from IncrementInFlight. The
	// call is still counted in increments.
	incrementErr error
}

// RemoveNodeModel records one removal and always succeeds.
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.removed++
	f.mu.Unlock()
	return nil
}

// IncrementInFlight records one increment and returns the configured
// incrementErr (nil unless a test sets it).
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.increments++
	err := f.incrementErr
	f.mu.Unlock()
	return err
}

// DecrementInFlight records one decrement and always succeeds.
func (f *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.decrements++
	f.mu.Unlock()
	return nil
}
// fakeGRPCBackend implements grpc.Backend with stub methods.
// Predict and PredictStream return the canned replies and errors configured
// in the fields below. Every other method is a stub that returns an empty (or
// nil) result and a nil error. Nothing panics, so tests can drive any wrapped
// method through InFlightTrackingClient without a real backend.
type fakeGRPCBackend struct {
	// predictReply is returned by Predict.
	predictReply *pb.Reply
	// predictErr is returned by Predict alongside predictReply.
	predictErr error
	// streamReplies are passed, in order, to the PredictStream callback.
	streamReplies []*pb.Reply
	// streamErr is returned by PredictStream after all streamReplies are delivered.
	streamErr error
}
// IsBusy always reports an idle backend.
func (f *fakeGRPCBackend) IsBusy() bool {
	return false
}

// HealthCheck always reports a healthy backend.
func (f *fakeGRPCBackend) HealthCheck(_ context.Context) (bool, error) {
	return true, nil
}

// LoadModel accepts any options and reports success with an empty Result.
func (f *fakeGRPCBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

// Predict hands back the canned predictReply/predictErr pair unchanged.
func (f *fakeGRPCBackend) Predict(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.Reply, error) {
	reply, err := f.predictReply, f.predictErr
	return reply, err
}

// PredictStream invokes fn once per canned reply, in order, then returns
// streamErr.
func (f *fakeGRPCBackend) PredictStream(_ context.Context, _ *pb.PredictOptions, fn func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
	for i := range f.streamReplies {
		fn(f.streamReplies[i])
	}
	return f.streamErr
}
// The stubs below ignore their inputs. Each returns a freshly allocated empty
// response (or nothing, for the callback-streaming methods) and a nil error.

func (f *fakeGRPCBackend) Embeddings(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
	return new(pb.EmbeddingResult), nil
}

func (f *fakeGRPCBackend) GenerateImage(_ context.Context, _ *pb.GenerateImageRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) GenerateVideo(_ context.Context, _ *pb.GenerateVideoRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) TTS(_ context.Context, _ *pb.TTSRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

// TTSStream never invokes its callback.
func (f *fakeGRPCBackend) TTSStream(_ context.Context, _ *pb.TTSRequest, _ func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) SoundGeneration(_ context.Context, _ *pb.SoundGenerationRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) Detect(_ context.Context, _ *pb.DetectOptions, _ ...ggrpc.CallOption) (*pb.DetectResponse, error) {
	return new(pb.DetectResponse), nil
}

func (f *fakeGRPCBackend) FaceVerify(_ context.Context, _ *pb.FaceVerifyRequest, _ ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) {
	return new(pb.FaceVerifyResponse), nil
}

func (f *fakeGRPCBackend) FaceAnalyze(_ context.Context, _ *pb.FaceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) {
	return new(pb.FaceAnalyzeResponse), nil
}

func (f *fakeGRPCBackend) VoiceVerify(_ context.Context, _ *pb.VoiceVerifyRequest, _ ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) {
	return new(pb.VoiceVerifyResponse), nil
}

func (f *fakeGRPCBackend) VoiceAnalyze(_ context.Context, _ *pb.VoiceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) {
	return new(pb.VoiceAnalyzeResponse), nil
}

func (f *fakeGRPCBackend) VoiceEmbed(_ context.Context, _ *pb.VoiceEmbedRequest, _ ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) {
	return new(pb.VoiceEmbedResponse), nil
}

func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
	return new(pb.TranscriptResult), nil
}

// AudioTranscriptionStream never invokes its callback.
func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
	return new(pb.TokenizationResponse), nil
}

func (f *fakeGRPCBackend) Status(_ context.Context) (*pb.StatusResponse, error) {
	return new(pb.StatusResponse), nil
}

func (f *fakeGRPCBackend) StoresSet(_ context.Context, _ *pb.StoresSetOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) StoresDelete(_ context.Context, _ *pb.StoresDeleteOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) StoresGet(_ context.Context, _ *pb.StoresGetOptions, _ ...ggrpc.CallOption) (*pb.StoresGetResult, error) {
	return new(pb.StoresGetResult), nil
}

func (f *fakeGRPCBackend) StoresFind(_ context.Context, _ *pb.StoresFindOptions, _ ...ggrpc.CallOption) (*pb.StoresFindResult, error) {
	return new(pb.StoresFindResult), nil
}

func (f *fakeGRPCBackend) Rerank(_ context.Context, _ *pb.RerankRequest, _ ...ggrpc.CallOption) (*pb.RerankResult, error) {
	return new(pb.RerankResult), nil
}

func (f *fakeGRPCBackend) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequest, _ ...ggrpc.CallOption) (*pb.MetricsResponse, error) {
	return new(pb.MetricsResponse), nil
}
// Unary audio stubs: each ignores its request and returns a freshly
// allocated empty response with a nil error.

func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
	return new(pb.VADResponse), nil
}

func (f *fakeGRPCBackend) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
	return new(pb.DiarizeResponse), nil
}

func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
	return new(pb.AudioEncodeResult), nil
}

func (f *fakeGRPCBackend) AudioDecode(_ context.Context, _ *pb.AudioDecodeRequest, _ ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
	return new(pb.AudioDecodeResult), nil
}

func (f *fakeGRPCBackend) AudioTransform(_ context.Context, _ *pb.AudioTransformRequest, _ ...ggrpc.CallOption) (*pb.AudioTransformResult, error) {
	return new(pb.AudioTransformResult), nil
}

// Bidi-stream constructors: each returns a nil stream and a nil error. No
// stream is ever opened.

func (f *fakeGRPCBackend) AudioTransformStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioTransformStreamClient, error) {
	return nil, nil
}

func (f *fakeGRPCBackend) AudioToAudioStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioToAudioStreamClient, error) {
	return nil, nil
}

func (f *fakeGRPCBackend) Forward(_ context.Context, _ ...ggrpc.CallOption) (grpc.ForwardClient, error) {
	return nil, nil
}
// Metadata, fine-tuning, quantization and lifecycle stubs. The unary ones
// return a freshly allocated empty result with a nil error. The progress
// methods never invoke their callback.

func (f *fakeGRPCBackend) ModelMetadata(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.ModelMetadataResponse, error) {
	return new(pb.ModelMetadataResponse), nil
}

func (f *fakeGRPCBackend) StartFineTune(_ context.Context, _ *pb.FineTuneRequest, _ ...ggrpc.CallOption) (*pb.FineTuneJobResult, error) {
	return new(pb.FineTuneJobResult), nil
}

func (f *fakeGRPCBackend) FineTuneProgress(_ context.Context, _ *pb.FineTuneProgressRequest, _ func(update *pb.FineTuneProgressUpdate), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) StopFineTune(_ context.Context, _ *pb.FineTuneStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) ListCheckpoints(_ context.Context, _ *pb.ListCheckpointsRequest, _ ...ggrpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	return new(pb.ListCheckpointsResponse), nil
}

func (f *fakeGRPCBackend) ExportModel(_ context.Context, _ *pb.ExportModelRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

func (f *fakeGRPCBackend) StartQuantization(_ context.Context, _ *pb.QuantizationRequest, _ ...ggrpc.CallOption) (*pb.QuantizationJobResult, error) {
	return new(pb.QuantizationJobResult), nil
}

func (f *fakeGRPCBackend) QuantizationProgress(_ context.Context, _ *pb.QuantizationProgressRequest, _ func(update *pb.QuantizationProgressUpdate), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}

// Free always succeeds.
func (f *fakeGRPCBackend) Free(_ context.Context) error {
	return nil
}
// TokenClassify returns an empty response and a nil error.
//
// It returns a non-nil response rather than (nil, nil) for two reasons. It
// matches the other unary stubs of this fake. It also mirrors a real gRPC
// client, which returns a response whenever the error is nil, so wrappers
// under test never see a nil response paired with success.
func (f *fakeGRPCBackend) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) {
	return &pb.TokenClassifyResponse{}, nil
}

// Score returns an empty response and a nil error, for the same reasons as
// TokenClassify.
func (f *fakeGRPCBackend) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) {
	return &pb.ScoreResponse{}, nil
}
// --- Tests ---

var _ = Describe("InFlightTrackingClient", func() {
	var (
		tracker *fakeInFlightTracker
		backend *fakeGRPCBackend
		client  *InFlightTrackingClient
	)

	BeforeEach(func() {
		tracker = &fakeInFlightTracker{}
		backend = &fakeGRPCBackend{
			predictReply:  &pb.Reply{Message: []byte("hello")},
			streamReplies: []*pb.Reply{{Message: []byte("chunk")}},
		}
		client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama", 0)
	})

	Describe("track", func() {
		It("increments and decrements via InFlightTracker", func() {
			done := client.track(context.Background())
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(0))
			done()
			Expect(tracker.decrements).To(Equal(1))
		})

		It("returns no-op when increment fails", func() {
			tracker.incrementErr = fmt.Errorf("registry down")
			done := client.track(context.Background())
			Expect(tracker.increments).To(Equal(1))
			// Decrement should NOT be called on cleanup since increment failed.
			done()
			Expect(tracker.decrements).To(Equal(0))
		})
	})

	Describe("Predict", func() {
		It("calls track and delegates to backend", func() {
			reply, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).ToNot(HaveOccurred())
			Expect(reply.Message).To(Equal([]byte("hello")))
			// track was called and cleaned up (defer).
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})
	})

	Describe("PredictStream", func() {
		It("calls track and delegates to backend", func() {
			var replies []*pb.Reply
			err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(r *pb.Reply) {
				replies = append(replies, r)
			})
			Expect(err).ToNot(HaveOccurred())
			Expect(replies).To(HaveLen(1))
			Expect(replies[0].Message).To(Equal([]byte("chunk")))
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})
	})

	Describe("non-LLM inference methods track in-flight", func() {
		// silero-vad and friends only ever expose a single non-Predict method.
		// If that method isn't wrapped, the load-time reservation released by
		// onFirstComplete never fires and in-flight is stuck at 1 forever.
		assertTracked := func(call func() error) {
			var firstFired int
			client.OnFirstComplete(func() { firstFired++ })
			err := call()
			Expect(err).ToNot(HaveOccurred())
			Expect(tracker.increments).To(Equal(1), "method must increment in-flight")
			Expect(tracker.decrements).To(Equal(1), "method must decrement in-flight")
			Expect(firstFired).To(Equal(1), "method must release the load-time reservation")
		}

		It("VAD", func() {
			assertTracked(func() error {
				_, err := client.VAD(context.Background(), &pb.VADRequest{})
				return err
			})
		})
		It("Diarize", func() {
			assertTracked(func() error {
				_, err := client.Diarize(context.Background(), &pb.DiarizeRequest{})
				return err
			})
		})
		It("VoiceVerify", func() {
			assertTracked(func() error {
				_, err := client.VoiceVerify(context.Background(), &pb.VoiceVerifyRequest{})
				return err
			})
		})
		It("VoiceAnalyze", func() {
			assertTracked(func() error {
				_, err := client.VoiceAnalyze(context.Background(), &pb.VoiceAnalyzeRequest{})
				return err
			})
		})
		It("VoiceEmbed", func() {
			assertTracked(func() error {
				_, err := client.VoiceEmbed(context.Background(), &pb.VoiceEmbedRequest{})
				return err
			})
		})
		It("FaceVerify", func() {
			assertTracked(func() error {
				_, err := client.FaceVerify(context.Background(), &pb.FaceVerifyRequest{})
				return err
			})
		})
		It("FaceAnalyze", func() {
			assertTracked(func() error {
				_, err := client.FaceAnalyze(context.Background(), &pb.FaceAnalyzeRequest{})
				return err
			})
		})
		It("TokenClassify", func() {
			assertTracked(func() error {
				_, err := client.TokenClassify(context.Background(), &pb.TokenClassifyRequest{})
				return err
			})
		})
		It("Score", func() {
			assertTracked(func() error {
				_, err := client.Score(context.Background(), &pb.ScoreRequest{})
				return err
			})
		})
		It("AudioEncode", func() {
			assertTracked(func() error {
				_, err := client.AudioEncode(context.Background(), &pb.AudioEncodeRequest{})
				return err
			})
		})
		It("AudioDecode", func() {
			assertTracked(func() error {
				_, err := client.AudioDecode(context.Background(), &pb.AudioDecodeRequest{})
				return err
			})
		})
		It("AudioTransform", func() {
			assertTracked(func() error {
				_, err := client.AudioTransform(context.Background(), &pb.AudioTransformRequest{})
				return err
			})
		})
	})

	Describe("bidi-stream constructors stay passthrough", func() {
		// Inference on these spans the stream lifetime, not the constructor
		// call. Tracking the constructor would fire onFirstComplete (releasing
		// the load-time reservation) before any data flows, so opening a
		// stream must leave the in-flight counters and the callback untouched.
		assertUntracked := func(call func() error) {
			var firstFired int
			client.OnFirstComplete(func() { firstFired++ })
			err := call()
			Expect(err).ToNot(HaveOccurred())
			Expect(tracker.increments).To(Equal(0), "stream constructor must not increment in-flight")
			Expect(tracker.decrements).To(Equal(0), "stream constructor must not decrement in-flight")
			Expect(firstFired).To(Equal(0), "stream constructor must not release the load-time reservation")
		}

		It("AudioTransformStream", func() {
			assertUntracked(func() error {
				_, err := client.AudioTransformStream(context.Background())
				return err
			})
		})
		It("AudioToAudioStream", func() {
			assertUntracked(func() error {
				_, err := client.AudioToAudioStream(context.Background())
				return err
			})
		})
		It("Forward", func() {
			assertUntracked(func() error {
				_, err := client.Forward(context.Background())
				return err
			})
		})
	})

	Describe("stale model reload (self-heal)", func() {
		It("removes the replica when the backend reports the model is not loaded", func() {
			backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
			_, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).To(HaveOccurred())
			Expect(tracker.removed).To(Equal(1))
		})

		It("keeps the replica on an unrelated error", func() {
			backend.predictErr = fmt.Errorf("context deadline exceeded")
			_, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).To(HaveOccurred())
			Expect(tracker.removed).To(Equal(0))
		})

		It("does not remove on success", func() {
			_, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).ToNot(HaveOccurred())
			Expect(tracker.removed).To(Equal(0))
		})

		It("self-heals on a streamed call too", func() {
			backend.streamErr = fmt.Errorf("whisper: model not loaded")
			err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
			Expect(err).To(HaveOccurred())
			Expect(tracker.removed).To(Equal(1))
		})
	})
})