Files
LocalAI/core/services/nodes/inflight_test.go
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): correctly compute time boundaries to avoid re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood with health checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

242 lines
8.2 KiB
Go

package nodes
import (
"context"
"fmt"
"sync"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
ggrpc "google.golang.org/grpc"
)
// --- Fakes ---

// fakeInFlightTracker is a test double for InFlightTracker that only records
// how often each method is called. IncrementInFlight returns incrementErr
// (nil unless a test sets it), which lets specs simulate a failing registry;
// DecrementInFlight always succeeds. The counters are guarded by mu.
type fakeInFlightTracker struct {
	mu           sync.Mutex
	increments   int   // number of IncrementInFlight calls
	decrements   int   // number of DecrementInFlight calls
	incrementErr error // returned by every IncrementInFlight call
}

// IncrementInFlight counts the call and returns the configured incrementErr.
// The context and both string arguments are ignored.
func (t *fakeInFlightTracker) IncrementInFlight(context.Context, string, string) error {
	t.mu.Lock()
	t.increments++
	err := t.incrementErr
	t.mu.Unlock()
	return err
}

// DecrementInFlight counts the call and always reports success.
// The context and both string arguments are ignored.
func (t *fakeInFlightTracker) DecrementInFlight(context.Context, string, string) error {
	t.mu.Lock()
	t.decrements++
	t.mu.Unlock()
	return nil
}
// fakeGRPCBackend implements grpc.Backend with stub methods so that
// InFlightTrackingClient can be exercised without a real model backend.
//
// Only Predict and PredictStream are configurable:
//   - Predict returns predictReply and predictErr as-is.
//   - PredictStream passes each entry of streamReplies to the callback, in
//     order, and then returns streamErr.
//
// Every other method is an inert stub and does NOT panic. IsBusy reports
// false, HealthCheck reports (true, nil), and the remaining request/response
// methods return an empty response message with a nil error. The
// callback-style methods (TTSStream, FineTuneProgress, QuantizationProgress)
// return nil without invoking their callback. Tests that need to catch
// unexpected calls must therefore assert on that explicitly.
type fakeGRPCBackend struct {
	predictReply  *pb.Reply
	predictErr    error
	streamReplies []*pb.Reply
	streamErr     error
}

func (f *fakeGRPCBackend) IsBusy() bool { return false }

func (f *fakeGRPCBackend) HealthCheck(_ context.Context) (bool, error) { return true, nil }

func (f *fakeGRPCBackend) LoadModel(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

// Predict returns the configured predictReply/predictErr pair.
func (f *fakeGRPCBackend) Predict(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.Reply, error) {
	return f.predictReply, f.predictErr
}

// PredictStream delivers every configured reply to fn synchronously, then
// returns streamErr. Replies are delivered even when streamErr is non-nil.
func (f *fakeGRPCBackend) PredictStream(_ context.Context, _ *pb.PredictOptions, fn func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
	for _, r := range f.streamReplies {
		fn(r)
	}
	return f.streamErr
}

// The methods below are inert stubs that succeed with empty responses.

func (f *fakeGRPCBackend) Embeddings(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
	return &pb.EmbeddingResult{}, nil
}

func (f *fakeGRPCBackend) GenerateImage(_ context.Context, _ *pb.GenerateImageRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) GenerateVideo(_ context.Context, _ *pb.GenerateVideoRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) TTS(_ context.Context, _ *pb.TTSRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) TTSStream(_ context.Context, _ *pb.TTSRequest, _ func(reply *pb.Reply), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) SoundGeneration(_ context.Context, _ *pb.SoundGenerationRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) Detect(_ context.Context, _ *pb.DetectOptions, _ ...ggrpc.CallOption) (*pb.DetectResponse, error) {
	return &pb.DetectResponse{}, nil
}

func (f *fakeGRPCBackend) AudioTranscription(_ context.Context, _ *pb.TranscriptRequest, _ ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
	return &pb.TranscriptResult{}, nil
}

func (f *fakeGRPCBackend) TokenizeString(_ context.Context, _ *pb.PredictOptions, _ ...ggrpc.CallOption) (*pb.TokenizationResponse, error) {
	return &pb.TokenizationResponse{}, nil
}

func (f *fakeGRPCBackend) Status(_ context.Context) (*pb.StatusResponse, error) {
	return &pb.StatusResponse{}, nil
}

func (f *fakeGRPCBackend) StoresSet(_ context.Context, _ *pb.StoresSetOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) StoresDelete(_ context.Context, _ *pb.StoresDeleteOptions, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) StoresGet(_ context.Context, _ *pb.StoresGetOptions, _ ...ggrpc.CallOption) (*pb.StoresGetResult, error) {
	return &pb.StoresGetResult{}, nil
}

func (f *fakeGRPCBackend) StoresFind(_ context.Context, _ *pb.StoresFindOptions, _ ...ggrpc.CallOption) (*pb.StoresFindResult, error) {
	return &pb.StoresFindResult{}, nil
}

func (f *fakeGRPCBackend) Rerank(_ context.Context, _ *pb.RerankRequest, _ ...ggrpc.CallOption) (*pb.RerankResult, error) {
	return &pb.RerankResult{}, nil
}

func (f *fakeGRPCBackend) GetTokenMetrics(_ context.Context, _ *pb.MetricsRequest, _ ...ggrpc.CallOption) (*pb.MetricsResponse, error) {
	return &pb.MetricsResponse{}, nil
}

func (f *fakeGRPCBackend) VAD(_ context.Context, _ *pb.VADRequest, _ ...ggrpc.CallOption) (*pb.VADResponse, error) {
	return &pb.VADResponse{}, nil
}

func (f *fakeGRPCBackend) AudioEncode(_ context.Context, _ *pb.AudioEncodeRequest, _ ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
	return &pb.AudioEncodeResult{}, nil
}

func (f *fakeGRPCBackend) AudioDecode(_ context.Context, _ *pb.AudioDecodeRequest, _ ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
	return &pb.AudioDecodeResult{}, nil
}

func (f *fakeGRPCBackend) ModelMetadata(_ context.Context, _ *pb.ModelOptions, _ ...ggrpc.CallOption) (*pb.ModelMetadataResponse, error) {
	return &pb.ModelMetadataResponse{}, nil
}

func (f *fakeGRPCBackend) StartFineTune(_ context.Context, _ *pb.FineTuneRequest, _ ...ggrpc.CallOption) (*pb.FineTuneJobResult, error) {
	return &pb.FineTuneJobResult{}, nil
}

func (f *fakeGRPCBackend) FineTuneProgress(_ context.Context, _ *pb.FineTuneProgressRequest, _ func(update *pb.FineTuneProgressUpdate), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) StopFineTune(_ context.Context, _ *pb.FineTuneStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) ListCheckpoints(_ context.Context, _ *pb.ListCheckpointsRequest, _ ...ggrpc.CallOption) (*pb.ListCheckpointsResponse, error) {
	return &pb.ListCheckpointsResponse{}, nil
}

func (f *fakeGRPCBackend) ExportModel(_ context.Context, _ *pb.ExportModelRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}

func (f *fakeGRPCBackend) StartQuantization(_ context.Context, _ *pb.QuantizationRequest, _ ...ggrpc.CallOption) (*pb.QuantizationJobResult, error) {
	return &pb.QuantizationJobResult{}, nil
}

func (f *fakeGRPCBackend) QuantizationProgress(_ context.Context, _ *pb.QuantizationProgressRequest, _ func(update *pb.QuantizationProgressUpdate), _ ...ggrpc.CallOption) error {
	return nil
}

func (f *fakeGRPCBackend) StopQuantization(_ context.Context, _ *pb.QuantizationStopRequest, _ ...ggrpc.CallOption) (*pb.Result, error) {
	return &pb.Result{}, nil
}
// --- Tests ---

// These specs check that InFlightTrackingClient increments the tracker once
// per call, delegates to the wrapped backend, and decrements once when the
// call finishes. This includes calls where the backend itself returns an
// error, so failed requests do not leak in-flight slots.
var _ = Describe("InFlightTrackingClient", func() {
	var (
		tracker *fakeInFlightTracker
		backend *fakeGRPCBackend
		client  *InFlightTrackingClient
	)

	BeforeEach(func() {
		tracker = &fakeInFlightTracker{}
		backend = &fakeGRPCBackend{
			predictReply:  &pb.Reply{Message: []byte("hello")},
			streamReplies: []*pb.Reply{{Message: []byte("chunk")}},
		}
		client = NewInFlightTrackingClient(backend, tracker, "node-1", "llama")
	})

	Describe("track", func() {
		It("increments and decrements via InFlightTracker", func() {
			done := client.track(context.Background())
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(0))

			done()
			Expect(tracker.decrements).To(Equal(1))
		})

		It("returns no-op when increment fails", func() {
			tracker.incrementErr = fmt.Errorf("registry down")
			done := client.track(context.Background())
			Expect(tracker.increments).To(Equal(1))

			// Decrement should NOT be called on cleanup since increment failed.
			done()
			Expect(tracker.decrements).To(Equal(0))
		})
	})

	Describe("Predict", func() {
		It("calls track and delegates to backend", func() {
			reply, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).ToNot(HaveOccurred())
			Expect(reply.Message).To(Equal([]byte("hello")))

			// track was called and cleaned up (defer).
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})

		It("still decrements when the backend returns an error", func() {
			backend.predictErr = fmt.Errorf("backend unavailable")

			_, err := client.Predict(context.Background(), &pb.PredictOptions{})
			Expect(err).To(HaveOccurred())

			// The in-flight slot must be released even on failure.
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})
	})

	Describe("PredictStream", func() {
		It("calls track and delegates to backend", func() {
			var replies []*pb.Reply
			err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(r *pb.Reply) {
				replies = append(replies, r)
			})
			Expect(err).ToNot(HaveOccurred())
			Expect(replies).To(HaveLen(1))
			Expect(replies[0].Message).To(Equal([]byte("chunk")))
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})

		It("still decrements when the stream returns an error", func() {
			backend.streamErr = fmt.Errorf("stream interrupted")

			err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
			Expect(err).To(HaveOccurred())

			// The in-flight slot must be released even on failure.
			Expect(tracker.increments).To(Equal(1))
			Expect(tracker.decrements).To(Equal(1))
		})
	})
})