refactor(distributed): make in-flight tracking coverage a compile-time contract (#10476)

PR #10475 fixed SoundDetection in-flight tracking, but the underlying trap
remains: InFlightTrackingClient embedded the whole grpc.Backend interface
"for passthrough of untracked methods", so any newly added inference method
is silently satisfied by the embedded passthrough and never wrapped with
track(). That leaves onFirstComplete unfired and in-flight stuck at 1 - the
exact SoundDetection bug, waiting to recur for the next backend method.

Close the gap at the type level instead of relying on reviewers to remember:

- Split grpc.Backend into two composed sub-interfaces: InferenceBackend
  (methods that are one discrete inference call and must be tracked) and
  ControlBackend (control-plane calls plus the streaming constructors whose
  work spans the returned stream, safe to pass through). The classification
  now lives next to the interface it documents.
- InFlightTrackingClient embeds only grpc.ControlBackend and implements every
  InferenceBackend method explicitly, delegating to an inner InferenceBackend.
  A `var _ grpc.Backend = (*InFlightTrackingClient)(nil)` assertion makes the
  package fail to compile if any inference method is left unwrapped.

Now adding a method to InferenceBackend is a build error (at the assertion and
every call site: "does not implement grpc.Backend (missing method X)"), not a
silent runtime leak - and the obvious fix is to copy a neighbouring wrapper,
which calls track(). No runtime guard or reviewer vigilance required.

Pure refactor: the composed Backend interface is identical to the old flat
one, so all implementers and consumers are unaffected (verified with a full
`go build ./...`). Behaviour is unchanged; the existing nodes suite passes.


Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
LocalAI [bot]
2026-06-24 11:08:29 +02:00
committed by GitHub
parent fc618dcee6
commit e5620989dd
2 changed files with 109 additions and 63 deletions

View File

@@ -19,25 +19,40 @@ import (
// Per-replica: a single tracker instance is bound to (nodeID, modelName, replicaIndex).
// The router constructs one tracker per Route() result, so each in-flight tick lands
// on the correct row even when multiple replicas of the same model live on the same node.
//
// Embedding only grpc.ControlBackend (not the whole grpc.Backend) is what makes
// the in-flight accounting safe by construction: the control-plane methods pass
// through untracked, while every grpc.InferenceBackend method must be declared
// explicitly below to satisfy grpc.Backend. Adding an inference method to the
// interface therefore breaks this file's build (see the var assertion below)
// until it is wrapped with track() - so a new inference path can't be added
// without an in-flight accounting decision.
type InFlightTrackingClient struct {
grpc.Backend // embed for passthrough of untracked methods
registry InFlightTracker
nodeID string
modelName string
replicaIndex int
grpc.ControlBackend // passthrough for control-plane / streaming-constructor methods
inner grpc.InferenceBackend // tracked inference methods delegate here
registry InFlightTracker
nodeID string
modelName string
replicaIndex int
firstOnce sync.Once // guards onFirstComplete
onFirstComplete func() // called once after the first tracked inference call completes
}
// Compile-time contract: *InFlightTrackingClient must implement the FULL backend
// surface. Because it embeds only ControlBackend, this fails to compile if any
// InferenceBackend method is left unwrapped.
var _ grpc.Backend = (*InFlightTrackingClient)(nil)
// NewInFlightTrackingClient wraps a gRPC backend client with in-flight tracking.
func NewInFlightTrackingClient(inner grpc.Backend, registry InFlightTracker, nodeID, modelName string, replicaIndex int) *InFlightTrackingClient {
return &InFlightTrackingClient{
Backend: inner,
registry: registry,
nodeID: nodeID,
modelName: modelName,
replicaIndex: replicaIndex,
ControlBackend: inner,
inner: inner,
registry: registry,
nodeID: nodeID,
modelName: modelName,
replicaIndex: replicaIndex,
}
}
@@ -91,160 +106,162 @@ func (c *InFlightTrackingClient) reconcile(err error) error {
func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
defer c.track(ctx)()
reply, err := c.Backend.Predict(ctx, in, opts...)
reply, err := c.inner.Predict(ctx, in, opts...)
return reply, c.reconcile(err)
}
func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
return c.reconcile(c.inner.PredictStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
defer c.track(ctx)()
res, err := c.Backend.Embeddings(ctx, in, opts...)
res, err := c.inner.Embeddings(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
res, err := c.Backend.GenerateImage(ctx, in, opts...)
res, err := c.inner.GenerateImage(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
res, err := c.inner.GenerateVideo(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
res, err := c.Backend.TTS(ctx, in, opts...)
res, err := c.inner.TTS(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
return c.reconcile(c.inner.TTSStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
res, err := c.inner.SoundGeneration(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
defer c.track(ctx)()
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
res, err := c.inner.AudioTranscription(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
return c.reconcile(c.inner.AudioTranscriptionStream(ctx, in, f, opts...))
}
func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.Detect(ctx, in, opts...)
res, err := c.inner.Detect(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Depth(ctx context.Context, in *pb.DepthRequest, opts ...ggrpc.CallOption) (*pb.DepthResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.Depth(ctx, in, opts...)
res, err := c.inner.Depth(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
defer c.track(ctx)()
res, err := c.Backend.Rerank(ctx, in, opts...)
res, err := c.inner.Rerank(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) VAD(ctx context.Context, in *pb.VADRequest, opts ...ggrpc.CallOption) (*pb.VADResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.VAD(ctx, in, opts...)
res, err := c.inner.VAD(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...ggrpc.CallOption) (*pb.DiarizeResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.Diarize(ctx, in, opts...)
res, err := c.inner.Diarize(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) FaceVerify(ctx context.Context, in *pb.FaceVerifyRequest, opts ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.FaceVerify(ctx, in, opts...)
res, err := c.inner.FaceVerify(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) FaceAnalyze(ctx context.Context, in *pb.FaceAnalyzeRequest, opts ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.FaceAnalyze(ctx, in, opts...)
res, err := c.inner.FaceAnalyze(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) VoiceVerify(ctx context.Context, in *pb.VoiceVerifyRequest, opts ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.VoiceVerify(ctx, in, opts...)
res, err := c.inner.VoiceVerify(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) VoiceAnalyze(ctx context.Context, in *pb.VoiceAnalyzeRequest, opts ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.VoiceAnalyze(ctx, in, opts...)
res, err := c.inner.VoiceAnalyze(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) VoiceEmbed(ctx context.Context, in *pb.VoiceEmbedRequest, opts ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.VoiceEmbed(ctx, in, opts...)
res, err := c.inner.VoiceEmbed(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.TokenClassify(ctx, in, opts...)
res, err := c.inner.TokenClassify(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) Score(ctx context.Context, in *pb.ScoreRequest, opts ...ggrpc.CallOption) (*pb.ScoreResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.Score(ctx, in, opts...)
res, err := c.inner.Score(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...ggrpc.CallOption) (*pb.SoundDetectionResponse, error) {
defer c.track(ctx)()
res, err := c.Backend.SoundDetection(ctx, in, opts...)
res, err := c.inner.SoundDetection(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) {
defer c.track(ctx)()
res, err := c.Backend.AudioEncode(ctx, in, opts...)
res, err := c.inner.AudioEncode(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) {
defer c.track(ctx)()
res, err := c.Backend.AudioDecode(ctx, in, opts...)
res, err := c.inner.AudioDecode(ctx, in, opts...)
return res, c.reconcile(err)
}
func (c *InFlightTrackingClient) AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...ggrpc.CallOption) (*pb.AudioTransformResult, error) {
defer c.track(ctx)()
res, err := c.Backend.AudioTransform(ctx, in, opts...)
res, err := c.inner.AudioTransform(ctx, in, opts...)
return res, c.reconcile(err)
}
// AudioTransformStream, AudioToAudioStream and Forward are deliberately left as
// embedded passthrough: they return a stream client and the inference spans the
// stream's lifetime, not the constructor call. Wrapping the constructor with
// track() would increment and immediately decrement (and fire onFirstComplete)
// before any audio flows. Tracking those correctly needs the done() func tied to
// stream close, which the current Backend interface doesn't surface here.
// AudioTransformStream, AudioToAudioStream and Forward live in grpc.ControlBackend
// and are passed through via the embedded field, NOT tracked: they return a stream
// client and the inference spans the stream's lifetime, not the constructor call.
// Wrapping the constructor with track() would increment and immediately decrement
// (and fire onFirstComplete) before any audio flows. Tracking those correctly needs
// the done() func tied to stream close, which the Backend interface doesn't surface
// here. If they ever need tracking, move them to grpc.InferenceBackend - the build
// will then force an explicit wrapper here.

View File

@@ -41,11 +41,34 @@ func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool
}
}
// Backend is the full client surface of a model backend. It is deliberately
// composed of two sub-interfaces so that wrappers can get a COMPILE-TIME
// guarantee about which methods they must account for:
//
// - InferenceBackend - methods that each perform one discrete inference call
// (the call begins on entry and ends on return). A wrapper that does
// per-call accounting - e.g. the distributed router's in-flight tracker,
// core/services/nodes.InFlightTrackingClient - embeds only ControlBackend
// and implements every InferenceBackend method explicitly. Adding a method
// to InferenceBackend therefore breaks that wrapper's build until it is
// implemented: inference can't be added without an accounting decision.
// - ControlBackend - everything that is NOT a discrete inference call:
// lifecycle/control-plane operations and the streaming constructors whose
// work spans the returned stream rather than the constructor call. These
// are safe to pass through untracked.
//
// Keep the two sets disjoint; every backend method belongs to exactly one.
type Backend interface {
IsBusy() bool
HealthCheck(ctx context.Context) (bool, error)
InferenceBackend
ControlBackend
}
// InferenceBackend is the subset of Backend whose methods each map to a single
// inference call. Wrappers that account for in-flight work must implement these
// explicitly (see Backend). Do NOT add methods that return a stream client or
// that are control-plane only - those belong in ControlBackend.
type InferenceBackend interface {
Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
@@ -53,6 +76,8 @@ type Backend interface {
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error
Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
Depth(ctx context.Context, in *pb.DepthRequest, opts ...grpc.CallOption) (*pb.DepthResponse, error)
FaceVerify(ctx context.Context, in *pb.FaceVerifyRequest, opts ...grpc.CallOption) (*pb.FaceVerifyResponse, error)
@@ -60,8 +85,25 @@ type Backend interface {
VoiceVerify(ctx context.Context, in *pb.VoiceVerifyRequest, opts ...grpc.CallOption) (*pb.VoiceVerifyResponse, error)
VoiceAnalyze(ctx context.Context, in *pb.VoiceAnalyzeRequest, opts ...grpc.CallOption) (*pb.VoiceAnalyzeResponse, error)
VoiceEmbed(ctx context.Context, in *pb.VoiceEmbedRequest, opts ...grpc.CallOption) (*pb.VoiceEmbedResponse, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error)
Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error)
VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error)
SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...grpc.CallOption) (*pb.SoundDetectionResponse, error)
AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error)
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...grpc.CallOption) (*pb.AudioTransformResult, error)
}
// ControlBackend is the subset of Backend that is NOT per-call inference:
// lifecycle/control-plane operations and the streaming constructors whose work
// spans the returned stream rather than the constructor call. In-flight-tracking
// wrappers embed this directly and pass it through untracked (see Backend).
type ControlBackend interface {
IsBusy() bool
HealthCheck(ctx context.Context) (bool, error)
LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error)
@@ -70,24 +112,11 @@ type Backend interface {
StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error)
Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error)
GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error)
SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...grpc.CallOption) (*pb.SoundDetectionResponse, error)
AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error)
AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error)
AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...grpc.CallOption) (*pb.AudioTransformResult, error)
// Streaming constructors: these return a stream client immediately; the
// actual inference spans the stream's lifetime, not this call, so they are
// NOT tracked as a single in-flight unit.
AudioTransformStream(ctx context.Context, opts ...grpc.CallOption) (AudioTransformStreamClient, error)
AudioToAudioStream(ctx context.Context, opts ...grpc.CallOption) (AudioToAudioStreamClient, error)