mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-26 17:37:07 -04:00
The distributed router wraps backend clients in InFlightTrackingClient so the eviction logic knows which replicas are actively serving. Every inference method must be wrapped: track() increments in-flight on entry and decrements (plus fires onFirstComplete, which releases the load-time reservation) on return. SoundDetection was added after the tracking client and never got a wrapper, so its calls fell through to the embedded passthrough Backend. The increment/decrement never ran and, critically, onFirstComplete never fired, so the reservation set at model load was never released - leaving in-flight stuck at 1 and the replica permanently ineligible for eviction. Wrap SoundDetection like the other non-LLM methods and cover it in the "non-LLM inference methods track in-flight" table test. Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
449 lines
15 KiB
Go
449 lines
15 KiB
Go
package nodes
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
|
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
|
ggrpc "google.golang.org/grpc"
|
|
)
|
|
|
|
// --- Fakes ---
|
|
|
|
// fakeInFlightTracker is an InFlightTracker test double. It only counts how
// often each method is invoked so specs can assert on the tracking behaviour
// of InFlightTrackingClient. Counters are guarded by mu.
type fakeInFlightTracker struct {
	mu sync.Mutex

	// Per-method call counters.
	increments, decrements, removed int

	// incrementErr is handed back by every IncrementInFlight call, letting a
	// spec simulate a failing registry.
	incrementErr error
}

// RemoveNodeModel records one replica removal and always succeeds.
func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.removed++
	f.mu.Unlock()
	return nil
}

// IncrementInFlight records one increment and returns the configured
// incrementErr (nil unless a spec sets it).
func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.increments++
	err := f.incrementErr
	f.mu.Unlock()
	return err
}

// DecrementInFlight records one decrement and always succeeds.
func (f *fakeInFlightTracker) DecrementInFlight(_ context.Context, _, _ string, _ int) error {
	f.mu.Lock()
	f.decrements++
	f.mu.Unlock()
	return nil
}
|
|
|
|
// fakeGRPCBackend implements grpc.Backend with stub methods.
|
|
// Only the methods we test (Predict, PredictStream) have real behavior;
|
|
// the rest panic if called unexpectedly.
|
|
type fakeGRPCBackend struct {
|
|
predictReply *pb.Reply
|
|
predictErr error
|
|
streamReplies []*pb.Reply
|
|
streamErr error
|
|
}
|
|
|
|
// IsBusy always reports the fake backend as idle.
func (f *fakeGRPCBackend) IsBusy() bool {
	return false
}
|
|
// HealthCheck always reports the fake backend as healthy.
func (f *fakeGRPCBackend) HealthCheck(_ context.Context) (bool, error) {
	return true, nil
}
|
|
func (f *fakeGRPCBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Predict(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.Reply, error) {
|
|
return f.predictReply, f.predictErr
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) PredictStream(_ context.Context, _ *pb.PredictOptions, fn func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
|
|
for _, r := range f.streamReplies {
|
|
fn(r)
|
|
}
|
|
return f.streamErr
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Embeddings(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
|
|
return &pb.EmbeddingResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GenerateImage(_ context.Context, _ *pb.GenerateImageRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GenerateVideo(_ context.Context, _ *pb.GenerateVideoRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TTS(_ context.Context, _ *pb.TTSRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TTSStream(_ context.Context, _ *pb.TTSRequest, _ func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) SoundGeneration(_ context.Context, _ *pb.SoundGenerationRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Detect(_ context.Context, _ *pb.DetectOptions, _ ...ggrpc.CallOption) (*pb.DetectResponse, error) {
|
|
return &pb.DetectResponse{}, nil
|
|
}
|
|
func (f *fakeGRPCBackend) SoundDetection(_ context.Context, _ *pb.SoundDetectionRequest, _ ...ggrpc.CallOption) (*pb.SoundDetectionResponse, error) {
|
|
return &pb.SoundDetectionResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Depth(_ context.Context, _ *pb.DepthRequest, _ ...ggrpc.CallOption) (*pb.DepthResponse, error) {
|
|
return &pb.DepthResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FaceVerify(_ context.Context, _ *pb.FaceVerifyRequest, _ ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) {
|
|
return &pb.FaceVerifyResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FaceAnalyze(_ context.Context, _ *pb.FaceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) {
|
|
return &pb.FaceAnalyzeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceVerify(_ context.Context, _ *pb.VoiceVerifyRequest, _ ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) {
|
|
return &pb.VoiceVerifyResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceAnalyze(_ context.Context, _ *pb.VoiceAnalyzeRequest, _ ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) {
|
|
return &pb.VoiceAnalyzeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VoiceEmbed(_ context.Context, _ *pb.VoiceEmbedRequest, _ ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) {
|
|
return &pb.VoiceEmbedResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
|
|
return &pb.TranscriptResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTranscriptionStream(_ context.Context, _ *pb.TranscriptRequest, _ func(chunk *pb.TranscriptStreamResponse), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
|
|
return &pb.TokenizationResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Status(_ context.Context) (*pb.StatusResponse, error) {
|
|
return &pb.StatusResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresSet(_ context.Context, _ *pb.StoresSetOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresDelete(_ context.Context, _ *pb.StoresDeleteOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresGet(_ context.Context, _ *pb.StoresGetOptions, _ ...ggrpc.CallOption) (*pb.StoresGetResult, error) {
|
|
return &pb.StoresGetResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StoresFind(_ context.Context, _ *pb.StoresFindOptions, _ ...ggrpc.CallOption) (*pb.StoresFindResult, error) {
|
|
return &pb.StoresFindResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Rerank(_ context.Context, _ *pb.RerankRequest, _ ...ggrpc.CallOption) (*pb.RerankResult, error) {
|
|
return &pb.RerankResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequest, _ ...ggrpc.CallOption) (*pb.MetricsResponse, error) {
|
|
return &pb.MetricsResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
|
|
return &pb.VADResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Diarize(_ context.Context, _ *pb.DiarizeRequest, _ ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
|
|
return &pb.DiarizeResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
|
|
return &pb.AudioEncodeResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioDecode(_ context.Context, _ *pb.AudioDecodeRequest, _ ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
|
|
return &pb.AudioDecodeResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTransform(_ context.Context, _ *pb.AudioTransformRequest, _ ...ggrpc.CallOption) (*pb.AudioTransformResult, error) {
|
|
return &pb.AudioTransformResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioTransformStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioTransformStreamClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) AudioToAudioStream(_ context.Context, _ ...ggrpc.CallOption) (grpc.AudioToAudioStreamClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Forward(_ context.Context, _ ...ggrpc.CallOption) (grpc.ForwardClient, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ModelMetadata(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.ModelMetadataResponse, error) {
|
|
return &pb.ModelMetadataResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StartFineTune(_ context.Context, _ *pb.FineTuneRequest, _ ...ggrpc.CallOption) (*pb.FineTuneJobResult, error) {
|
|
return &pb.FineTuneJobResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) FineTuneProgress(_ context.Context, _ *pb.FineTuneProgressRequest, _ func(update *pb.FineTuneProgressUpdate), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StopFineTune(_ context.Context, _ *pb.FineTuneStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ListCheckpoints(_ context.Context, _ *pb.ListCheckpointsRequest, _ ...ggrpc.CallOption) (*pb.ListCheckpointsResponse, error) {
|
|
return &pb.ListCheckpointsResponse{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) ExportModel(_ context.Context, _ *pb.ExportModelRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StartQuantization(_ context.Context, _ *pb.QuantizationRequest, _ ...ggrpc.CallOption) (*pb.QuantizationJobResult, error) {
|
|
return &pb.QuantizationJobResult{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) QuantizationProgress(_ context.Context, _ *pb.QuantizationProgressRequest, _ func(update *pb.QuantizationProgressUpdate), _ ...ggrpc.CallOption) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
|
|
return &pb.Result{}, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Free(_ context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) TokenClassify(_ context.Context, _ *pb.TokenClassifyRequest, _ ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeGRPCBackend) Score(_ context.Context, _ *pb.ScoreRequest, _ ...ggrpc.CallOption) (*pb.ScoreResponse, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
// --- Tests ---
|
|
|
|
var _ = Describe("InFlightTrackingClient", func() {
|
|
var (
|
|
tracker *fakeInFlightTracker
|
|
backend *fakeGRPCBackend
|
|
client *InFlightTrackingClient
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
tracker = &fakeInFlightTracker{}
|
|
backend = &fakeGRPCBackend{
|
|
predictReply: &pb.Reply{Message: []byte("hello")},
|
|
streamReplies: []*pb.Reply{{Message: []byte("chunk")}},
|
|
}
|
|
client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama", 0)
|
|
})
|
|
|
|
Describe("track", func() {
|
|
It("increments and decrements via InFlightTracker", func() {
|
|
done := client.track(context.Background())
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(0))
|
|
done()
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
|
|
It("returns no-op when increment fails", func() {
|
|
tracker.incrementErr = fmt.Errorf("registry down")
|
|
done := client.track(context.Background())
|
|
Expect(tracker.increments).To(Equal(1))
|
|
// Decrement should NOT be called on cleanup since increment failed.
|
|
done()
|
|
Expect(tracker.decrements).To(Equal(0))
|
|
})
|
|
})
|
|
|
|
Describe("Predict", func() {
|
|
It("calls track and delegates to backend", func() {
|
|
reply, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(reply.Message).To(Equal([]byte("hello")))
|
|
|
|
// track was called and cleaned up (defer).
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
})
|
|
|
|
Describe("PredictStream", func() {
|
|
It("calls track and delegates to backend", func() {
|
|
var replies []*pb.Reply
|
|
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(r *pb.Reply) {
|
|
replies = append(replies, r)
|
|
})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(replies).To(HaveLen(1))
|
|
Expect(replies[0].Message).To(Equal([]byte("chunk")))
|
|
|
|
Expect(tracker.increments).To(Equal(1))
|
|
Expect(tracker.decrements).To(Equal(1))
|
|
})
|
|
})
|
|
|
|
Describe("non-LLM inference methods track in-flight", func() {
|
|
// silero-vad and friends only ever expose a single non-Predict method.
|
|
// If that method isn't wrapped, the load-time reservation released by
|
|
// onFirstComplete never fires and in-flight is stuck at 1 forever.
|
|
assertTracked := func(call func() error) {
|
|
var firstFired int
|
|
client.OnFirstComplete(func() { firstFired++ })
|
|
err := call()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(tracker.increments).To(Equal(1), "method must increment in-flight")
|
|
Expect(tracker.decrements).To(Equal(1), "method must decrement in-flight")
|
|
Expect(firstFired).To(Equal(1), "method must release the load-time reservation")
|
|
}
|
|
|
|
It("VAD", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VAD(context.Background(), &pb.VADRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("Diarize", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.Diarize(context.Background(), &pb.DiarizeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceVerify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceVerify(context.Background(), &pb.VoiceVerifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceAnalyze", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceAnalyze(context.Background(), &pb.VoiceAnalyzeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("VoiceEmbed", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.VoiceEmbed(context.Background(), &pb.VoiceEmbedRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("FaceVerify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.FaceVerify(context.Background(), &pb.FaceVerifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("FaceAnalyze", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.FaceAnalyze(context.Background(), &pb.FaceAnalyzeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("TokenClassify", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.TokenClassify(context.Background(), &pb.TokenClassifyRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("Score", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.Score(context.Background(), &pb.ScoreRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioEncode", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioEncode(context.Background(), &pb.AudioEncodeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioDecode", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioDecode(context.Background(), &pb.AudioDecodeRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("AudioTransform", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.AudioTransform(context.Background(), &pb.AudioTransformRequest{})
|
|
return err
|
|
})
|
|
})
|
|
|
|
It("SoundDetection", func() {
|
|
assertTracked(func() error {
|
|
_, err := client.SoundDetection(context.Background(), &pb.SoundDetectionRequest{})
|
|
return err
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("stale model reload (self-heal)", func() {
|
|
It("removes the replica when the backend reports the model is not loaded", func() {
|
|
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(1))
|
|
})
|
|
|
|
It("keeps the replica on an unrelated error", func() {
|
|
backend.predictErr = fmt.Errorf("context deadline exceeded")
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(0))
|
|
})
|
|
|
|
It("does not remove on success", func() {
|
|
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(0))
|
|
})
|
|
|
|
It("self-heals on a streamed call too", func() {
|
|
backend.streamErr = fmt.Errorf("whisper: model not loaded")
|
|
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(tracker.removed).To(Equal(1))
|
|
})
|
|
})
|
|
})
|