Files
LocalAI/core/services/nodes/inflight_test.go
LocalAI [bot] fc618dcee6 fix(distributed): track in-flight for SoundDetection requests (#10475)
The distributed router wraps backend clients in InFlightTrackingClient so
the eviction logic knows which replicas are actively serving. Every
inference method must be wrapped: track() increments in-flight on entry
and decrements (plus fires onFirstComplete, which releases the load-time
reservation) on return.

SoundDetection was added after the tracking client and never got a
wrapper, so its calls fell through to the embedded passthrough Backend.
The increment/decrement never ran and, critically, onFirstComplete never
fired, so the reservation set at model load was never released - leaving
in-flight stuck at 1 and the replica permanently ineligible for eviction.

Wrap SoundDetection like the other non-LLM methods and cover it in the
"non-LLM inference methods track in-flight" table test.


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-06-24 10:13:37 +02:00

449 lines
15 KiB
Go

package nodes
import (
"context"
"fmt"
"sync"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
grpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
)
// --- Fakes ---
// fakeInFlightTracker is a test double for InFlightTracker that records how
// many times each of its methods has been invoked.
type fakeInFlightTracker struct {
	// mu serialises updates to the counters made by the tracker methods.
	mu sync.Mutex

	// Call counters for IncrementInFlight, DecrementInFlight and
	// RemoveNodeModel respectively.
	increments, decrements, removed int

	// incrementErr, when non-nil, is what IncrementInFlight returns.
	incrementErr error
}
// RemoveNodeModel records one removal under the lock and always succeeds.
// All arguments are ignored.
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.removed++
	f.mu.Unlock()
	return nil
}
// IncrementInFlight records one increment under the lock and returns the
// configured incrementErr (nil unless a test sets it). The call is counted
// even when an error is returned. All arguments are ignored.
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.increments++
	result := f.incrementErr
	f.mu.Unlock()
	return result
}
// DecrementInFlight records one decrement under the lock and always
// succeeds. All arguments are ignored.
func (f *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.decrements++
	f.mu.Unlock()
	return nil
}
// fakeGRPCBackend implements grpc.Backend with stub methods.
// Only Predict and PredictStream have configurable behavior (driven by the
// fields below); the remaining methods ignore their arguments and return
// fixed canned values with a nil error. None of them panic.
type fakeGRPCBackend struct {
// predictReply and predictErr are returned verbatim by Predict.
predictReply *pb.Reply
predictErr error
// streamReplies are handed to the PredictStream callback in order, after
// which PredictStream returns streamErr.
streamReplies []*pb.Reply
streamErr error
}
// IsBusy always reports that the fake backend is idle.
func (f *fakeGRPCBackend) IsBusy() bool {
	return false
}
// HealthCheck always reports a healthy backend with a nil error.
func (f *fakeGRPCBackend) HealthCheck(_ context.Context) (bool, error) {
	return true, nil
}
// LoadModel is a stub: it ignores its arguments and returns an empty result
// with a nil error.
func (f *fakeGRPCBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// Predict ignores its arguments and returns the reply/error pair configured
// on the fake (predictReply, predictErr).
func (f *fakeGRPCBackend) Predict(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.Reply, error) {
	reply, err := f.predictReply, f.predictErr
	return reply, err
}
// PredictStream feeds every configured stream reply to fn, in order, and
// then returns the configured streamErr. The context, options and call
// options are ignored.
func (f *fakeGRPCBackend) PredictStream(_ context.Context, _ *pb.PredictOptions, fn func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
	chunks := f.streamReplies
	for i := range chunks {
		fn(chunks[i])
	}
	return f.streamErr
}
// Embeddings is a stub: it ignores its arguments and returns an empty
// embedding result with a nil error.
func (f *fakeGRPCBackend) Embeddings(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
	return new(pb.EmbeddingResult), nil
}
// GenerateImage is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) GenerateImage(_ context.Context, _ *pb.GenerateImageRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// GenerateVideo is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) GenerateVideo(_ context.Context, _ *pb.GenerateVideoRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// TTS is a stub: it ignores its arguments and returns an empty result with a
// nil error.
func (f *fakeGRPCBackend) TTS(_ context.Context, _ *pb.TTSRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// TTSStream is a no-op stub: it ignores its arguments, never invokes the
// reply callback, and returns nil.
func (f *fakeGRPCBackend) TTSStream(_ context.Context, _ *pb.TTSRequest, _ func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
return nil
}
// SoundGeneration is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) SoundGeneration(_ context.Context, _ *pb.SoundGenerationRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// Detect is a stub: it ignores its arguments and returns an empty detection
// response with a nil error.
func (f *fakeGRPCBackend) Detect(_ context.Context, _ *pb.DetectOptions, _ ...ggrpc.CallOption) (*pb.DetectResponse, error) {
	return new(pb.DetectResponse), nil
}
// SoundDetection is a stub: it ignores its arguments and returns an empty
// sound-detection response with a nil error.
func (f *fakeGRPCBackend) SoundDetection(_ context.Context, _ *pb.SoundDetectionRequest, _ ...ggrpc.CallOption) (*pb.SoundDetectionResponse, error) {
	return new(pb.SoundDetectionResponse), nil
}
// Depth is a stub: it ignores its arguments and returns an empty depth
// response with a nil error.
func (f *fakeGRPCBackend) Depth(_ context.Context, _ *pb.DepthRequest, _ ...ggrpc.CallOption) (*pb.DepthResponse, error) {
	return new(pb.DepthResponse), nil
}
// FaceVerify is a stub: it ignores its arguments and returns an empty
// response with a nil error.
func (f *fakeGRPCBackend) FaceVerify(_ context.Context, _ *pb.FaceVerifyRequest, _ ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) {
	return new(pb.FaceVerifyResponse), nil
}
// FaceAnalyze is a stub: it ignores its arguments and returns an empty
// response with a nil error.
func (f *fakeGRPCBackend) FaceAnalyze(_ context.Context, _ *pb.FaceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) {
	return new(pb.FaceAnalyzeResponse), nil
}
// VoiceVerify is a stub: it ignores its arguments and returns an empty
// response with a nil error.
func (f *fakeGRPCBackend) VoiceVerify(_ context.Context, _ *pb.VoiceVerifyRequest, _ ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) {
	return new(pb.VoiceVerifyResponse), nil
}
// VoiceAnalyze is a stub: it ignores its arguments and returns an empty
// response with a nil error.
func (f *fakeGRPCBackend) VoiceAnalyze(_ context.Context, _ *pb.VoiceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) {
	return new(pb.VoiceAnalyzeResponse), nil
}
// VoiceEmbed is a stub: it ignores its arguments and returns an empty
// response with a nil error.
func (f *fakeGRPCBackend) VoiceEmbed(_ context.Context, _ *pb.VoiceEmbedRequest, _ ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) {
	return new(pb.VoiceEmbedResponse), nil
}
// AudioTranscription is a stub: it ignores its arguments and returns an
// empty transcript result with a nil error.
func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
	return new(pb.TranscriptResult), nil
}
// AudioTranscriptionStream is a no-op stub: it ignores its arguments, never
// invokes the chunk callback, and returns nil.
func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
return nil
}
// TokenizeString is a stub: it ignores its arguments and returns an empty
// tokenization response with a nil error.
func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
	return new(pb.TokenizationResponse), nil
}
// Status is a stub: it ignores the context and returns an empty status
// response with a nil error.
func (f *fakeGRPCBackend) Status(_ context.Context) (*pb.StatusResponse, error) {
	return new(pb.StatusResponse), nil
}
// StoresSet is a stub: it ignores its arguments and returns an empty result
// with a nil error.
func (f *fakeGRPCBackend) StoresSet(_ context.Context, _ *pb.StoresSetOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// StoresDelete is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) StoresDelete(_ context.Context, _ *pb.StoresDeleteOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// StoresGet is a stub: it ignores its arguments and returns an empty get
// result with a nil error.
func (f *fakeGRPCBackend) StoresGet(_ context.Context, _ *pb.StoresGetOptions, _ ...ggrpc.CallOption) (*pb.StoresGetResult, error) {
	return new(pb.StoresGetResult), nil
}
// StoresFind is a stub: it ignores its arguments and returns an empty find
// result with a nil error.
func (f *fakeGRPCBackend) StoresFind(_ context.Context, _ *pb.StoresFindOptions, _ ...ggrpc.CallOption) (*pb.StoresFindResult, error) {
	return new(pb.StoresFindResult), nil
}
// Rerank is a stub: it ignores its arguments and returns an empty rerank
// result with a nil error.
func (f *fakeGRPCBackend) Rerank(_ context.Context, _ *pb.RerankRequest, _ ...ggrpc.CallOption) (*pb.RerankResult, error) {
	return new(pb.RerankResult), nil
}
// GetTokenMetrics is a stub: it ignores its arguments and returns an empty
// metrics response with a nil error.
func (f *fakeGRPCBackend) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequest, _ ...ggrpc.CallOption) (*pb.MetricsResponse, error) {
	return new(pb.MetricsResponse), nil
}
// VAD is a stub: it ignores its arguments and returns an empty VAD response
// with a nil error.
func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
	return new(pb.VADResponse), nil
}
// Diarize is a stub: it ignores its arguments and returns an empty
// diarization response with a nil error.
func (f *fakeGRPCBackend) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
	return new(pb.DiarizeResponse), nil
}
// AudioEncode is a stub: it ignores its arguments and returns an empty
// encode result with a nil error.
func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
	return new(pb.AudioEncodeResult), nil
}
// AudioDecode is a stub: it ignores its arguments and returns an empty
// decode result with a nil error.
func (f *fakeGRPCBackend) AudioDecode(_ context.Context, _ *pb.AudioDecodeRequest, _ ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
	return new(pb.AudioDecodeResult), nil
}
// AudioTransform is a stub: it ignores its arguments and returns an empty
// transform result with a nil error.
func (f *fakeGRPCBackend) AudioTransform(_ context.Context, _ *pb.AudioTransformRequest, _ ...ggrpc.CallOption) (*pb.AudioTransformResult, error) {
	return new(pb.AudioTransformResult), nil
}
// AudioTransformStream is a stub that ignores its arguments and returns a nil
// stream client together with a nil error; callers must not use the returned
// client.
func (f *fakeGRPCBackend) AudioTransformStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioTransformStreamClient, error) {
return nil, nil
}
// AudioToAudioStream is a stub that ignores its arguments and returns a nil
// stream client together with a nil error; callers must not use the returned
// client.
func (f *fakeGRPCBackend) AudioToAudioStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioToAudioStreamClient, error) {
return nil, nil
}
// Forward is a stub that ignores its arguments and returns a nil forward
// client together with a nil error; callers must not use the returned client.
func (f *fakeGRPCBackend) Forward(_ context.Context, _ ...ggrpc.CallOption) (grpc.ForwardClient, error) {
return nil, nil
}
// ModelMetadata is a stub: it ignores its arguments and returns an empty
// metadata response with a nil error.
func (f *fakeGRPCBackend) ModelMetadata(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.ModelMetadataResponse, error) {
	return new(pb.ModelMetadataResponse), nil
}
// StartFineTune is a stub: it ignores its arguments and returns an empty job
// result with a nil error.
func (f *fakeGRPCBackend) StartFineTune(_ context.Context, _ *pb.FineTuneRequest, _ ...ggrpc.CallOption) (*pb.FineTuneJobResult, error) {
	return new(pb.FineTuneJobResult), nil
}
// FineTuneProgress is a no-op stub: it ignores its arguments, never invokes
// the update callback, and returns nil.
func (f *fakeGRPCBackend) FineTuneProgress(_ context.Context, _ *pb.FineTuneProgressRequest, _ func(update *pb.FineTuneProgressUpdate), _ ...ggrpc.CallOption) error {
return nil
}
// StopFineTune is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) StopFineTune(_ context.Context, _ *pb.FineTuneStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// ListCheckpoints is a stub: it ignores its arguments and returns an empty
// checkpoint listing with a nil error.
func (f *fakeGRPCBackend) ListCheckpoints(_ context.Context, _ *pb.ListCheckpointsRequest, _ ...ggrpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	return new(pb.ListCheckpointsResponse), nil
}
// ExportModel is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) ExportModel(_ context.Context, _ *pb.ExportModelRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// StartQuantization is a stub: it ignores its arguments and returns an empty
// job result with a nil error.
func (f *fakeGRPCBackend) StartQuantization(_ context.Context, _ *pb.QuantizationRequest, _ ...ggrpc.CallOption) (*pb.QuantizationJobResult, error) {
	return new(pb.QuantizationJobResult), nil
}
// QuantizationProgress is a no-op stub: it ignores its arguments, never
// invokes the update callback, and returns nil.
func (f *fakeGRPCBackend) QuantizationProgress(_ context.Context, _ *pb.QuantizationProgressRequest, _ func(update *pb.QuantizationProgressUpdate), _ ...ggrpc.CallOption) error {
return nil
}
// StopQuantization is a stub: it ignores its arguments and returns an empty
// result with a nil error.
func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return new(pb.Result), nil
}
// Free is a no-op stub: it ignores the context and returns nil.
func (f *fakeGRPCBackend) Free(_ context.Context) error {
return nil
}
// TokenClassify is a stub: it ignores its arguments and returns an empty
// response with a nil error. It returns a non-nil message, like the other
// unary stubs on this fake, so a caller that reads the response after a
// successful call never dereferences nil.
func (f *fakeGRPCBackend) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) {
	return &pb.TokenClassifyResponse{}, nil
}
// Score is a stub: it ignores its arguments and returns an empty response
// with a nil error. It returns a non-nil message, like the other unary stubs
// on this fake, so a caller that reads the response after a successful call
// never dereferences nil.
func (f *fakeGRPCBackend) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) {
	return &pb.ScoreResponse{}, nil
}
// --- Tests ---
var _ = Describe("InFlightTrackingClient", func() {
// Fixtures shared by every spec in this container. They are rebuilt before
// each spec so the tracker counters always start from zero.
var (
	tracker *fakeInFlightTracker
	backend *fakeGRPCBackend
	client  *InFlightTrackingClient
)

BeforeEach(func() {
	tracker = new(fakeInFlightTracker)

	// Canned responses served by the fake backend's Predict/PredictStream.
	backend = new(fakeGRPCBackend)
	backend.predictReply = &pb.Reply{Message: []byte("hello")}
	backend.streamReplies = []*pb.Reply{{Message: []byte("chunk")}}

	client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama", 0)
})
Describe("track", func() {
	It("increments and decrements via InFlightTracker", func() {
		release := client.track(context.Background())

		// Entering the tracked section bumps only the increment counter.
		Expect(tracker.increments).To(Equal(1))
		Expect(tracker.decrements).To(Equal(0))

		// Running the returned cleanup performs the matching decrement.
		release()
		Expect(tracker.decrements).To(Equal(1))
	})

	It("returns no-op when increment fails", func() {
		tracker.incrementErr = fmt.Errorf("registry down")

		release := client.track(context.Background())
		Expect(tracker.increments).To(Equal(1))

		// The increment failed, so the cleanup must not issue a decrement.
		release()
		Expect(tracker.decrements).To(Equal(0))
	})
})
Describe("Predict", func() {
	It("calls track and delegates to backend", func() {
		ctx := context.Background()

		got, err := client.Predict(ctx, &pb.PredictOptions{})
		Expect(err).NotTo(HaveOccurred())
		Expect(got.Message).To(Equal([]byte("hello")))

		// Exactly one increment on entry and one decrement once the call
		// has returned.
		Expect(tracker.increments).To(Equal(1))
		Expect(tracker.decrements).To(Equal(1))
	})
})
Describe("PredictStream", func() {
	It("calls track and delegates to backend", func() {
		// Collect every chunk the client forwards to the callback.
		var received []*pb.Reply
		collect := func(chunk *pb.Reply) {
			received = append(received, chunk)
		}

		err := client.PredictStream(context.Background(), &pb.PredictOptions{}, collect)
		Expect(err).NotTo(HaveOccurred())

		Expect(received).To(HaveLen(1))
		Expect(received[0].Message).To(Equal([]byte("chunk")))

		Expect(tracker.increments).To(Equal(1))
		Expect(tracker.decrements).To(Equal(1))
	})
})
Describe("non-LLM inference methods track in-flight", func() {
// silero-vad and friends only ever expose a single non-Predict method.
// If that method isn't wrapped, the load-time reservation released by
// onFirstComplete never fires and in-flight is stuck at 1 forever.
assertTracked := func(call func() error) {
var firstFired int
client.OnFirstComplete(func() { firstFired++ })
err := call()
Expect(err).ToNot(HaveOccurred())
Expect(tracker.increments).To(Equal(1), "method must increment in-flight")
Expect(tracker.decrements).To(Equal(1), "method must decrement in-flight")
Expect(firstFired).To(Equal(1), "method must release the load-time reservation")
}
It("VAD", func() {
assertTracked(func() error {
_, err := client.VAD(context.Background(), &pb.VADRequest{})
return err
})
})
It("Diarize", func() {
assertTracked(func() error {
_, err := client.Diarize(context.Background(), &pb.DiarizeRequest{})
return err
})
})
It("VoiceVerify", func() {
assertTracked(func() error {
_, err := client.VoiceVerify(context.Background(), &pb.VoiceVerifyRequest{})
return err
})
})
It("VoiceAnalyze", func() {
assertTracked(func() error {
_, err := client.VoiceAnalyze(context.Background(), &pb.VoiceAnalyzeRequest{})
return err
})
})
It("VoiceEmbed", func() {
assertTracked(func() error {
_, err := client.VoiceEmbed(context.Background(), &pb.VoiceEmbedRequest{})
return err
})
})
It("FaceVerify", func() {
assertTracked(func() error {
_, err := client.FaceVerify(context.Background(), &pb.FaceVerifyRequest{})
return err
})
})
It("FaceAnalyze", func() {
assertTracked(func() error {
_, err := client.FaceAnalyze(context.Background(), &pb.FaceAnalyzeRequest{})
return err
})
})
It("TokenClassify", func() {
assertTracked(func() error {
_, err := client.TokenClassify(context.Background(), &pb.TokenClassifyRequest{})
return err
})
})
It("Score", func() {
assertTracked(func() error {
_, err := client.Score(context.Background(), &pb.ScoreRequest{})
return err
})
})
It("AudioEncode", func() {
assertTracked(func() error {
_, err := client.AudioEncode(context.Background(), &pb.AudioEncodeRequest{})
return err
})
})
It("AudioDecode", func() {
assertTracked(func() error {
_, err := client.AudioDecode(context.Background(), &pb.AudioDecodeRequest{})
return err
})
})
It("AudioTransform", func() {
assertTracked(func() error {
_, err := client.AudioTransform(context.Background(), &pb.AudioTransformRequest{})
return err
})
})
It("SoundDetection", func() {
assertTracked(func() error {
_, err := client.SoundDetection(context.Background(), &pb.SoundDetectionRequest{})
return err
})
})
})
Describe("stale model reload (self-heal)", func() {
	// predict issues one Predict call through the tracking client and
	// returns only its error. It reads client when a spec runs, so it always
	// sees the fixture built for that spec.
	predict := func() error {
		_, err := client.Predict(context.Background(), &pb.PredictOptions{})
		return err
	}

	It("removes the replica when the backend reports the model is not loaded", func() {
		backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")

		Expect(predict()).To(HaveOccurred())
		Expect(tracker.removed).To(Equal(1))
	})

	It("keeps the replica on an unrelated error", func() {
		backend.predictErr = fmt.Errorf("context deadline exceeded")

		Expect(predict()).To(HaveOccurred())
		Expect(tracker.removed).To(Equal(0))
	})

	It("does not remove on success", func() {
		Expect(predict()).ToNot(HaveOccurred())
		Expect(tracker.removed).To(Equal(0))
	})

	It("self-heals on a streamed call too", func() {
		backend.streamErr = fmt.Errorf("whisper: model not loaded")

		discard := func(*pb.Reply) {}
		err := client.PredictStream(context.Background(), &pb.PredictOptions{}, discard)

		Expect(err).To(HaveOccurred())
		Expect(tracker.removed).To(Equal(1))
	})
})
})