From 93c48e19f06cfdc0f65f7565b24c1cac406de407 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 7 May 2026 16:36:33 +0000 Subject: [PATCH] refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream} MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds context.Context as first parameter to the AIModel interface methods that wrap whisper-style transcription. Server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none uses it yet — the actual cancellation primitive lands in the next commit so this is pure plumbing. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto --- backend/go/sherpa-onnx/backend.go | 4 +++- backend/go/sherpa-onnx/backend_test.go | 3 ++- backend/go/vibevoice-cpp/govibevoicecpp.go | 7 ++++--- backend/go/vibevoice-cpp/vibevoicecpp_test.go | 4 ++-- backend/go/voxtral/govoxtral.go | 3 ++- backend/go/whisper/gowhisper.go | 3 ++- pkg/grpc/base/base.go | 5 +++-- pkg/grpc/interface.go | 6 ++++-- pkg/grpc/server.go | 4 ++-- 9 files changed, 24 insertions(+), 15 deletions(-) diff --git a/backend/go/sherpa-onnx/backend.go b/backend/go/sherpa-onnx/backend.go index d73474af6..91b797aa0 100644 --- a/backend/go/sherpa-onnx/backend.go +++ b/backend/go/sherpa-onnx/backend.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/binary" "fmt" "os" @@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error { // Transcription // ============================================================= -func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) { if s.onlineRecognizer != 0 { return s.runOnlineASR(req, nil) } @@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc // Closes `results` before returning so the server wrapper's reader // goroutine can exit. func (s *SherpaBackend) AudioTranscriptionStream( + _ context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse, ) error { diff --git a/backend/go/sherpa-onnx/backend_test.go b/backend/go/sherpa-onnx/backend_test.go index 46ad6d3a2..b70bc3e67 100644 --- a/backend/go/sherpa-onnx/backend_test.go +++ b/backend/go/sherpa-onnx/backend_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "path/filepath" "testing" @@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() { }) It("rejects AudioTranscription", func() { - _, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{ + _, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/nonexistent.wav", }) Expect(err).To(HaveOccurred()) diff --git a/backend/go/vibevoice-cpp/govibevoicecpp.go b/backend/go/vibevoice-cpp/govibevoicecpp.go index 242f00c31..cf4945416 100644 --- a/backend/go/vibevoice-cpp/govibevoicecpp.go +++ b/backend/go/vibevoice-cpp/govibevoicecpp.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "io" @@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) { return len(p), nil } -func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) { if v.asrModel == "" { return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded") } @@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro // transcription, emit each segment's content as a delta, then close // with a final_result whose Text equals the concatenated deltas (the // e2e harness asserts those match). -func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { +func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { defer close(results) - res, err := v.AudioTranscription(req) + res, err := v.AudioTranscription(ctx, req) if err != nil { return err } diff --git a/backend/go/vibevoice-cpp/vibevoicecpp_test.go b/backend/go/vibevoice-cpp/vibevoicecpp_test.go index ffb36629e..f58b24365 100644 --- a/backend/go/vibevoice-cpp/vibevoicecpp_test.go +++ b/backend/go/vibevoice-cpp/vibevoicecpp_test.go @@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() { }) It("rejects AudioTranscription without a loaded ASR model", func() { - _, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{ + _, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/some.wav", }) Expect(err).To(HaveOccurred()) @@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() { It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() { ch := make(chan *pb.TranscriptStreamResponse, 4) - err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{ + err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/some.wav", }, ch) Expect(err).To(HaveOccurred()) diff --git a/backend/go/voxtral/govoxtral.go b/backend/go/voxtral/govoxtral.go index 9a296a589..ac9c5148a 100644 --- a/backend/go/voxtral/govoxtral.go +++ b/backend/go/voxtral/govoxtral.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "strings" @@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error { return nil } -func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { dir, err := os.MkdirTemp("", "voxtral") if err != nil { return pb.TranscriptResult{}, err diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index 56d794f0f..0343cb76e 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "path/filepath" @@ -92,7 +93,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) { }, nil } -func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { dir, err := os.MkdirTemp("", "whisper") if err != nil { return pb.TranscriptResult{}, err diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 1fdf258cd..c5a0fc62f 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -3,6 +3,7 @@ package base // This is a wrapper to satisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( + "context" "fmt" "os" @@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) { return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { +func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { return fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 8fb800aec..82f8af23b 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,6 +1,8 @@ package grpc import ( + "context" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) @@ -22,8 +24,8 @@ type AIModel interface { VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error) VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error) VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error) - AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) - AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error + AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) + AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error TTS(*pb.TTSRequest) error TTSStream(*pb.TTSRequest, chan []byte) error SoundGeneration(*pb.SoundGenerationRequest) error diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index a931e6556..396547ca9 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques s.llm.Lock() defer s.llm.Unlock() } - result, err := s.llm.AudioTranscription(in) + result, err := s.llm.AudioTranscription(ctx, in) if err != nil { return nil, err } @@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba done <- true }() - err := s.llm.AudioTranscriptionStream(in, resultChan) + err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan) <-done return err