refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as first parameter to the AIModel interface methods
that wrap whisper-style transcription. Server-side gRPC handler now
forwards the per-RPC ctx (server-streaming uses stream.Context()).
Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter;
none uses it yet — the actual cancellation primitive lands in the next
commit so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 16:36:33 +00:00
parent 9fac275fd3
commit 93c48e19f0
9 changed files with 24 additions and 15 deletions

View File

@@ -2,6 +2,7 @@ package main
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {
// Transcription
// =============================================================
func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if s.onlineRecognizer != 0 {
return s.runOnlineASR(req, nil)
}
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
// Closes `results` before returning so the server wrapper's reader
// goroutine can exit.
func (s *SherpaBackend) AudioTranscriptionStream(
_ context.Context,
req *pb.TranscriptRequest,
results chan *pb.TranscriptStreamResponse,
) error {

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"os"
"path/filepath"
"testing"
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
})
It("rejects AudioTranscription", func() {
_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
_, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/nonexistent.wav",
})
Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if v.asrModel == "" {
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
}
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
// transcription, emit each segment's content as a delta, then close
// with a final_result whose Text equals the concatenated deltas (the
// e2e harness asserts those match).
func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
defer close(results)
res, err := v.AudioTranscription(req)
res, err := v.AudioTranscription(ctx, req)
if err != nil {
return err
}

View File

@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
})
It("rejects AudioTranscription without a loaded ASR model", func() {
_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{
_, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav",
})
Expect(err).To(HaveOccurred())
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
ch := make(chan *pb.TranscriptStreamResponse, 4)
err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{
err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav",
}, ch)
Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"strings"
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
return nil
}
func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "voxtral")
if err != nil {
return pb.TranscriptResult{}, err

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -92,7 +93,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
}, nil
}
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err

View File

@@ -3,6 +3,7 @@ package base
// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"context"
"fmt"
"os"
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
}
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
return fmt.Errorf("unimplemented")
}

View File

@@ -1,6 +1,8 @@
package grpc
import (
"context"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
@@ -22,8 +24,8 @@ type AIModel interface {
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
TTS(*pb.TTSRequest) error
TTSStream(*pb.TTSRequest, chan []byte) error
SoundGeneration(*pb.SoundGenerationRequest) error

View File

@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
s.llm.Lock()
defer s.llm.Unlock()
}
result, err := s.llm.AudioTranscription(in)
result, err := s.llm.AudioTranscription(ctx, in)
if err != nil {
return nil, err
}
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
done <- true
}()
err := s.llm.AudioTranscriptionStream(in, resultChan)
err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
<-done
return err