diff --git a/core/services/nodes/inflight.go b/core/services/nodes/inflight.go index b51ef6001..bfc71b999 100644 --- a/core/services/nodes/inflight.go +++ b/core/services/nodes/inflight.go @@ -19,25 +19,40 @@ import ( // Per-replica: a single tracker instance is bound to (nodeID, modelName, replicaIndex). // The router constructs one tracker per Route() result, so each in-flight tick lands // on the correct row even when multiple replicas of the same model live on the same node. +// +// Embedding only grpc.ControlBackend (not the whole grpc.Backend) is what makes +// the in-flight accounting safe by construction: the control-plane methods pass +// through untracked, while every grpc.InferenceBackend method must be declared +// explicitly below to satisfy grpc.Backend. Adding an inference method to the +// interface therefore breaks this file's build (see the var assertion below) +// until it is wrapped with track() - so a new inference path can't be added +// without an in-flight accounting decision. type InFlightTrackingClient struct { - grpc.Backend // embed for passthrough of untracked methods - registry InFlightTracker - nodeID string - modelName string - replicaIndex int + grpc.ControlBackend // passthrough for control-plane / streaming-constructor methods + inner grpc.InferenceBackend // tracked inference methods delegate here + registry InFlightTracker + nodeID string + modelName string + replicaIndex int firstOnce sync.Once // guards onFirstComplete onFirstComplete func() // called once after the first tracked inference call completes } +// Compile-time contract: *InFlightTrackingClient must implement the FULL backend +// surface. Because it embeds only ControlBackend, this fails to compile if any +// InferenceBackend method is left unwrapped. +var _ grpc.Backend = (*InFlightTrackingClient)(nil) + // NewInFlightTrackingClient wraps a gRPC backend client with in-flight tracking. func NewInFlightTrackingClient(inner grpc.Backend, registry InFlightTracker, nodeID, modelName string, replicaIndex int) *InFlightTrackingClient { return &InFlightTrackingClient{ - Backend: inner, - registry: registry, - nodeID: nodeID, - modelName: modelName, - replicaIndex: replicaIndex, + ControlBackend: inner, + inner: inner, + registry: registry, + nodeID: nodeID, + modelName: modelName, + replicaIndex: replicaIndex, } } @@ -91,160 +106,162 @@ func (c *InFlightTrackingClient) reconcile(err error) error { func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) { defer c.track(ctx)() - reply, err := c.Backend.Predict(ctx, in, opts...) + reply, err := c.inner.Predict(ctx, in, opts...) return reply, c.reconcile(err) } func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...)) + return c.reconcile(c.inner.PredictStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) { defer c.track(ctx)() - res, err := c.Backend.Embeddings(ctx, in, opts...) + res, err := c.inner.Embeddings(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - res, err := c.Backend.GenerateImage(ctx, in, opts...) + res, err := c.inner.GenerateImage(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - res, err := c.Backend.GenerateVideo(ctx, in, opts...) + res, err := c.inner.GenerateVideo(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - res, err := c.Backend.TTS(ctx, in, opts...) + res, err := c.inner.TTS(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...)) + return c.reconcile(c.inner.TTSStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) { defer c.track(ctx)() - res, err := c.Backend.SoundGeneration(ctx, in, opts...) + res, err := c.inner.SoundGeneration(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) { defer c.track(ctx)() - res, err := c.Backend.AudioTranscription(ctx, in, opts...) + res, err := c.inner.AudioTranscription(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error { defer c.track(ctx)() - return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)) + return c.reconcile(c.inner.AudioTranscriptionStream(ctx, in, f, opts...)) } func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) { defer c.track(ctx)() - res, err := c.Backend.Detect(ctx, in, opts...) + res, err := c.inner.Detect(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) Depth(ctx context.Context, in *pb.DepthRequest, opts ...ggrpc.CallOption) (*pb.DepthResponse, error) { defer c.track(ctx)() - res, err := c.Backend.Depth(ctx, in, opts...) + res, err := c.inner.Depth(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) { defer c.track(ctx)() - res, err := c.Backend.Rerank(ctx, in, opts...) + res, err := c.inner.Rerank(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) VAD(ctx context.Context, in *pb.VADRequest, opts ...ggrpc.CallOption) (*pb.VADResponse, error) { defer c.track(ctx)() - res, err := c.Backend.VAD(ctx, in, opts...) + res, err := c.inner.VAD(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...ggrpc.CallOption) (*pb.DiarizeResponse, error) { defer c.track(ctx)() - res, err := c.Backend.Diarize(ctx, in, opts...) + res, err := c.inner.Diarize(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) FaceVerify(ctx context.Context, in *pb.FaceVerifyRequest, opts ...ggrpc.CallOption) (*pb.FaceVerifyResponse, error) { defer c.track(ctx)() - res, err := c.Backend.FaceVerify(ctx, in, opts...) + res, err := c.inner.FaceVerify(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) FaceAnalyze(ctx context.Context, in *pb.FaceAnalyzeRequest, opts ...ggrpc.CallOption) (*pb.FaceAnalyzeResponse, error) { defer c.track(ctx)() - res, err := c.Backend.FaceAnalyze(ctx, in, opts...) + res, err := c.inner.FaceAnalyze(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) VoiceVerify(ctx context.Context, in *pb.VoiceVerifyRequest, opts ...ggrpc.CallOption) (*pb.VoiceVerifyResponse, error) { defer c.track(ctx)() - res, err := c.Backend.VoiceVerify(ctx, in, opts...) + res, err := c.inner.VoiceVerify(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) VoiceAnalyze(ctx context.Context, in *pb.VoiceAnalyzeRequest, opts ...ggrpc.CallOption) (*pb.VoiceAnalyzeResponse, error) { defer c.track(ctx)() - res, err := c.Backend.VoiceAnalyze(ctx, in, opts...) + res, err := c.inner.VoiceAnalyze(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) VoiceEmbed(ctx context.Context, in *pb.VoiceEmbedRequest, opts ...ggrpc.CallOption) (*pb.VoiceEmbedResponse, error) { defer c.track(ctx)() - res, err := c.Backend.VoiceEmbed(ctx, in, opts...) + res, err := c.inner.VoiceEmbed(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...ggrpc.CallOption) (*pb.TokenClassifyResponse, error) { defer c.track(ctx)() - res, err := c.Backend.TokenClassify(ctx, in, opts...) + res, err := c.inner.TokenClassify(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) Score(ctx context.Context, in *pb.ScoreRequest, opts ...ggrpc.CallOption) (*pb.ScoreResponse, error) { defer c.track(ctx)() - res, err := c.Backend.Score(ctx, in, opts...) + res, err := c.inner.Score(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...ggrpc.CallOption) (*pb.SoundDetectionResponse, error) { defer c.track(ctx)() - res, err := c.Backend.SoundDetection(ctx, in, opts...) + res, err := c.inner.SoundDetection(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...ggrpc.CallOption) (*pb.AudioEncodeResult, error) { defer c.track(ctx)() - res, err := c.Backend.AudioEncode(ctx, in, opts...) + res, err := c.inner.AudioEncode(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...ggrpc.CallOption) (*pb.AudioDecodeResult, error) { defer c.track(ctx)() - res, err := c.Backend.AudioDecode(ctx, in, opts...) + res, err := c.inner.AudioDecode(ctx, in, opts...) return res, c.reconcile(err) } func (c *InFlightTrackingClient) AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...ggrpc.CallOption) (*pb.AudioTransformResult, error) { defer c.track(ctx)() - res, err := c.Backend.AudioTransform(ctx, in, opts...) + res, err := c.inner.AudioTransform(ctx, in, opts...) return res, c.reconcile(err) } -// AudioTransformStream, AudioToAudioStream and Forward are deliberately left as -// embedded passthrough: they return a stream client and the inference spans the -// stream's lifetime, not the constructor call. Wrapping the constructor with -// track() would increment and immediately decrement (and fire onFirstComplete) -// before any audio flows. Tracking those correctly needs the done() func tied to -// stream close, which the current Backend interface doesn't surface here. +// AudioTransformStream, AudioToAudioStream and Forward live in grpc.ControlBackend +// and are passed through via the embedded field, NOT tracked: they return a stream +// client and the inference spans the stream's lifetime, not the constructor call. +// Wrapping the constructor with track() would increment and immediately decrement +// (and fire onFirstComplete) before any audio flows. Tracking those correctly needs +// the done() func tied to stream close, which the Backend interface doesn't surface +// here. If they ever need tracking, move them to grpc.InferenceBackend - the build +// will then force an explicit wrapper here. diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index f4cd511ac..838ab9865 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -41,11 +41,34 @@ func buildClient(address string, parallel bool, wd WatchDog, enableWatchDog bool } } +// Backend is the full client surface of a model backend. It is deliberately +// composed of two sub-interfaces so that wrappers can get a COMPILE-TIME +// guarantee about which methods they must account for: +// +// - InferenceBackend - methods that each perform one discrete inference call +// (the call begins on entry and ends on return). A wrapper that does +// per-call accounting - e.g. the distributed router's in-flight tracker, +// core/services/nodes.InFlightTrackingClient - embeds only ControlBackend +// and implements every InferenceBackend method explicitly. Adding a method +// to InferenceBackend therefore breaks that wrapper's build until it is +// implemented: inference can't be added without an accounting decision. +// - ControlBackend - everything that is NOT a discrete inference call: +// lifecycle/control-plane operations and the streaming constructors whose +// work spans the returned stream rather than the constructor call. These +// are safe to pass through untracked. +// +// Keep the two sets disjoint; every backend method belongs to exactly one. type Backend interface { - IsBusy() bool - HealthCheck(ctx context.Context) (bool, error) + InferenceBackend + ControlBackend +} + +// InferenceBackend is the subset of Backend whose methods each map to a single +// inference call. Wrappers that account for in-flight work must implement these +// explicitly (see Backend). Do NOT add methods that return a stream client or +// that are control-plane only - those belong in ControlBackend. +type InferenceBackend interface { Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) - LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) @@ -53,6 +76,8 @@ type Backend interface { TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) + AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) + AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error) Depth(ctx context.Context, in *pb.DepthRequest, opts ...grpc.CallOption) (*pb.DepthResponse, error) FaceVerify(ctx context.Context, in *pb.FaceVerifyRequest, opts ...grpc.CallOption) (*pb.FaceVerifyResponse, error) @@ -60,8 +85,25 @@ type Backend interface { VoiceVerify(ctx context.Context, in *pb.VoiceVerifyRequest, opts ...grpc.CallOption) (*pb.VoiceVerifyResponse, error) VoiceAnalyze(ctx context.Context, in *pb.VoiceAnalyzeRequest, opts ...grpc.CallOption) (*pb.VoiceAnalyzeResponse, error) VoiceEmbed(ctx context.Context, in *pb.VoiceEmbedRequest, opts ...grpc.CallOption) (*pb.VoiceEmbedResponse, error) - AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) - AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...grpc.CallOption) error + Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) + TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error) + Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error) + VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) + Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) + SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...grpc.CallOption) (*pb.SoundDetectionResponse, error) + AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) + AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) + AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...grpc.CallOption) (*pb.AudioTransformResult, error) +} + +// ControlBackend is the subset of Backend that is NOT per-call inference: +// lifecycle/control-plane operations and the streaming constructors whose work +// spans the returned stream rather than the constructor call. In-flight-tracking +// wrappers embed this directly and pass it through untracked (see Backend). +type ControlBackend interface { + IsBusy() bool + HealthCheck(ctx context.Context) (bool, error) + LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) Status(ctx context.Context) (*pb.StatusResponse, error) @@ -70,24 +112,11 @@ type Backend interface { StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error) - Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) - - TokenClassify(ctx context.Context, in *pb.TokenClassifyRequest, opts ...grpc.CallOption) (*pb.TokenClassifyResponse, error) - - Score(ctx context.Context, in *pb.ScoreRequest, opts ...grpc.CallOption) (*pb.ScoreResponse, error) - GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) - VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) - - Diarize(ctx context.Context, in *pb.DiarizeRequest, opts ...grpc.CallOption) (*pb.DiarizeResponse, error) - - SoundDetection(ctx context.Context, in *pb.SoundDetectionRequest, opts ...grpc.CallOption) (*pb.SoundDetectionResponse, error) - - AudioEncode(ctx context.Context, in *pb.AudioEncodeRequest, opts ...grpc.CallOption) (*pb.AudioEncodeResult, error) - AudioDecode(ctx context.Context, in *pb.AudioDecodeRequest, opts ...grpc.CallOption) (*pb.AudioDecodeResult, error) - - AudioTransform(ctx context.Context, in *pb.AudioTransformRequest, opts ...grpc.CallOption) (*pb.AudioTransformResult, error) + // Streaming constructors: these return a stream client immediately; the + // actual inference spans the stream's lifetime, not this call, so they are + // NOT tracked as a single in-flight unit. AudioTransformStream(ctx context.Context, opts ...grpc.CallOption) (AudioTransformStreamClient, error) AudioToAudioStream(ctx context.Context, opts ...grpc.CallOption) (AudioToAudioStreamClient, error)