mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-14 19:58:44 -04:00
fix(distributed): track in-flight for non-LLM inference methods InFlightTrackingClient only wrapped a subset of the grpc.Backend inference methods (Predict, Embeddings, TTS, AudioTranscription, Detect, Rerank, ...). Methods like VAD were left as embedded passthrough, so track() never ran for them. In distributed mode every model is loaded with in_flight=1 as a reservation; that reservation is only released by the OnFirstComplete callback, which fires after the first *tracked* inference call completes. A VAD-only model (e.g. silero-vad) never calls a tracked method, so the reservation is never released and in-flight stays pinned at 1 forever - which also blocks the router's idle-eviction logic. Wrap the remaining unary inference methods (VAD, Diarize, Face*, Voice*, TokenClassify, Score, AudioEncode, AudioDecode, AudioTransform) with the same track()/reconcile() pattern. The three bidi-stream constructors (AudioTransformStream, AudioToAudioStream, Forward) are deliberately left as passthrough - their inference spans the stream lifetime, not the constructor call, so track() there would fire onFirstComplete before any data flows. Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
435 lines
14 KiB
Go
435 lines
14 KiB
Go
package nodes
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
ggrpc "google.golang.org/grpc"
|
|
)
|
|
|
|
// --- Fakes ---
|
|
|
|
// fakeInFlightTracker is a test double for InFlightTracker. It only counts
// how often each method is invoked so specs can assert on the totals.
type fakeInFlightTracker struct {
	// mu is held by the methods while they touch the fields below.
	mu sync.Mutex

	increments int // number of IncrementInFlight calls
	decrements int // number of DecrementInFlight calls
	removed    int // number of RemoveNodeModel calls

	// incrementErr is returned verbatim by IncrementInFlight; nil means success.
	incrementErr error
}

// RemoveNodeModel records one removal and always succeeds. All arguments are
// ignored.
func (t *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
	t.mu.Lock()
	t.removed += 1
	t.mu.Unlock()
	return nil
}

// IncrementInFlight records one increment and returns the configured
// incrementErr. The call is counted even when that error is non-nil.
func (t *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
	t.mu.Lock()
	t.increments += 1
	result := t.incrementErr
	t.mu.Unlock()
	return result
}

// DecrementInFlight records one decrement and always succeeds. All arguments
// are ignored.
func (t *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string, _ int) error {
	t.mu.Lock()
	t.decrements += 1
	t.mu.Unlock()
	return nil
}
|
|
|
|
// fakeGRPCBackend implements grpc.Backend with stub methods.
|
|
// Only the methods we test (Predict, PredictStream) have real behavior;
|
|
// the rest panic if called unexpectedly.
|
|
type fakeGRPCBackend struct {
|
|
predictReply *pb.Reply
|
|
predictErr error
|
|
streamReplies []*pb.Reply
|
|
streamErr error
|
|
}
|
|
|
|
// IsBusy always reports false for the fake backend.
func (f *fakeGRPCBackend) IsBusy() bool {
	return false
}
|
|
// HealthCheck always reports true with a nil error.
func (f *fakeGRPCBackend) HealthCheck(_ context.Context) (bool, error) {
	return true, nil
}
|
|
func (f *fakeGRPCBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Predict(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.Reply, error) {
|
|
return f.predictReply, f.predictErr
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) PredictStream(_ context.Context, _ *pb.PredictOptions, fn func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
|
|
for _, r := range f.streamReplies {
|
|
fn(r)
|
|
}
|
|
return f.streamErr
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Embeddings(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
|
return &pb.EmbeddingResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GenerateImage(_ context.Context, _ *pb.GenerateImageRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GenerateVideo(_ context.Context, _ *pb.GenerateVideoRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TTS(_ context.Context, _ *pb.TTSRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TTSStream(_ context.Context, _ *pb.TTSRequest, _ func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) SoundGeneration(_ context.Context, _ *pb.SoundGenerationRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Detect(_ context.Context, _ *pb.DetectOptions, _ ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
|
return &pb.DetectResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FaceVerify(_ context.Context, _ *pb.FaceVerifyRequest, _ ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) {
|
|
return &pb.FaceVerifyResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FaceAnalyze(_ context.Context, _ *pb.FaceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) {
|
|
return &pb.FaceAnalyzeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceVerify(_ context.Context, _ *pb.VoiceVerifyRequest, _ ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) {
|
|
return &pb.VoiceVerifyResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceAnalyze(_ context.Context, _ *pb.VoiceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) {
|
|
return &pb.VoiceAnalyzeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceEmbed(_ context.Context, _ *pb.VoiceEmbedRequest, _ ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) {
|
|
return &pb.VoiceEmbedResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
|
return &pb.TranscriptResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
|
return &pb.TokenizationResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Status(_ context.Context) (*pb.StatusResponse, error) {
|
|
return &pb.StatusResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresSet(_ context.Context, _ *pb.StoresSetOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresDelete(_ context.Context, _ *pb.StoresDeleteOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresGet(_ context.Context, _ *pb.StoresGetOptions, _ ...ggrpc.CallOption) (*pb.StoresGetResult, error) {
|
|
return &pb.StoresGetResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresFind(_ context.Context, _ *pb.StoresFindOptions, _ ...ggrpc.CallOption) (*pb.StoresFindResult, error) {
|
|
return &pb.StoresFindResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Rerank(_ context.Context, _ *pb.RerankRequest, _ ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
|
return &pb.RerankResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequest, _ ...ggrpc.CallOption) (*pb.MetricsResponse, error) {
|
|
return &pb.MetricsResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
|
|
return &pb.VADResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
|
|
return &pb.DiarizeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
|
|
return &pb.AudioEncodeResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioDecode(_ context.Context, _ *pb.AudioDecodeRequest, _ ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
|
|
return &pb.AudioDecodeResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTransform(_ context.Context, _ *pb.AudioTransformRequest, _ ...ggrpc.CallOption) (*pb.AudioTransformResult, error) {
|
|
return &pb.AudioTransformResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTransformStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioTransformStreamClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioToAudioStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioToAudioStreamClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Forward(_ context.Context, _ ...ggrpc.CallOption) (grpc.ForwardClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ModelMetadata(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
|
return &pb.ModelMetadataResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StartFineTune(_ context.Context, _ *pb.FineTuneRequest, _ ...ggrpc.CallOption) (*pb.FineTuneJobResult, error) {
|
|
return &pb.FineTuneJobResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FineTuneProgress(_ context.Context, _ *pb.FineTuneProgressRequest, _ func(update *pb.FineTuneProgressUpdate), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StopFineTune(_ context.Context, _ *pb.FineTuneStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ListCheckpoints(_ context.Context, _ *pb.ListCheckpointsRequest, _ ...ggrpc.CallOption) (*pb.ListCheckpointsResponse, error) {
|
|
return &pb.ListCheckpointsResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ExportModel(_ context.Context, _ *pb.ExportModelRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StartQuantization(_ context.Context, _ *pb.QuantizationRequest, _ ...ggrpc.CallOption) (*pb.QuantizationJobResult, error) {
|
|
return &pb.QuantizationJobResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) QuantizationProgress(_ context.Context, _ *pb.QuantizationProgressRequest, _ func(update *pb.QuantizationProgressUpdate), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Free(_ context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// --- Tests ---
|
|
|
|
var _ = Describe("InFlightTrackingClient", func() {
|
|
var (
|
|
tracker *fakeInFlightTracker
|
|
backend *fakeGRPCBackend
|
|
client *InFlightTrackingClient
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
tracker = &fakeInFlightTracker{}
|
|
backend = &fakeGRPCBackend{
|
|
predictReply: &pb.Reply{Message: []byte("hello")},
|
|
streamReplies: []*pb.Reply{{Message: []byte("chunk")}},
|
|
}
|
|
client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama", 0)
|
|
})
|
|
|
|
Describe("track", func() {
|
|
It("increments and decrements via InFlightTracker", func() {
|
|
done := client.track(context.Background())
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(0))
|
|
done()
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
|
|
It("returns no-op when increment fails", func() {
|
|
tracker.incrementErr = fmt.Errorf("registry down")
|
|
done := client.track(context.Background())
|
|
Expect(tracker.increments).To(Equal(1))
|
|
// Decrement should NOT be called on cleanup since increment failed.
|
|
done()
|
|
Expect(tracker.decrements).To(Equal(0))
|
|
})
|
|
})
|
|
|
|
Describe("Predict", func() {
|
|
It("calls track and delegates to backend", func() {
|
|
reply, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(reply.Message).To(Equal([]byte("hello")))
|
|
|
|
// track was called and cleaned up (defer).
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
})
|
|
|
|
Describe("PredictStream", func() {
|
|
It("calls track and delegates to backend", func() {
|
|
var replies []*pb.Reply
|
|
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(r *pb.Reply) {
|
|
replies = append(replies, r)
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(replies).To(HaveLen(1))
|
|
Expect(replies[0].Message).To(Equal([]byte("chunk")))
|
|
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
})
|
|
|
|
Describe("non-LLM inference methods track in-flight", func() {
|
|
// silero-vad and friends only ever expose a single non-Predict method.
|
|
// If that method isn't wrapped, the load-time reservation released by
|
|
// onFirstComplete never fires and in-flight is stuck at 1 forever.
|
|
assertTracked := func(call func() error) {
|
|
var firstFired int
|
|
client.OnFirstComplete(func() { firstFired++ })
|
|
err := call()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(tracker.increments).To(Equal(1), "method must increment in-flight")
|
|
Expect(tracker.decrements).To(Equal(1), "method must decrement in-flight")
|
|
Expect(firstFired).To(Equal(1), "method must release the load-time reservation")
|
|
}
|
|
|
|
It("VAD", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VAD(context.Background(), &pb.VADRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("Diarize", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.Diarize(context.Background(), &pb.DiarizeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceVerify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceVerify(context.Background(), &pb.VoiceVerifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceAnalyze", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceAnalyze(context.Background(), &pb.VoiceAnalyzeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceEmbed", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceEmbed(context.Background(), &pb.VoiceEmbedRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("FaceVerify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.FaceVerify(context.Background(), &pb.FaceVerifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("FaceAnalyze", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.FaceAnalyze(context.Background(), &pb.FaceAnalyzeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("TokenClassify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.TokenClassify(context.Background(), &pb.TokenClassifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("Score", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.Score(context.Background(), &pb.ScoreRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioEncode", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioEncode(context.Background(), &pb.AudioEncodeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioDecode", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioDecode(context.Background(), &pb.AudioDecodeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioTransform", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioTransform(context.Background(), &pb.AudioTransformRequest{})
|
|
return err
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("stale model reload (self-heal)", func() {
|
|
It("removes the replica when the backend reports the model is not loaded", func() {
|
|
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(1))
|
|
})
|
|
|
|
It("keeps the replica on an unrelated error", func() {
|
|
backend.predictErr = fmt.Errorf("context deadline exceeded")
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(0))
|
|
})
|
|
|
|
It("does not remove on success", func() {
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(0))
|
|
})
|
|
|
|
It("self-heals on a streamed call too", func() {
|
|
backend.streamErr = fmt.Errorf("whisper: model not loaded")
|
|
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(1))
|
|
})
|
|
})
|
|
})
|