refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as the first parameter of the AIModel interface
methods that wrap whisper-style transcription (AudioTranscription and
AudioTranscriptionStream). The server-side gRPC handlers now forward the
per-RPC ctx; the server-streaming handler uses stream.Context().
Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter,
but none acts on it yet (vibevoice-cpp's stream method only passes it on
to AudioTranscription, which ignores it). The actual cancellation
primitive lands in the next commit, so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 16:36:33 +00:00
parent 9fac275fd3
commit 93c48e19f0
9 changed files with 24 additions and 15 deletions

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {
// Transcription // Transcription
// ============================================================= // =============================================================
func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if s.onlineRecognizer != 0 { if s.onlineRecognizer != 0 {
return s.runOnlineASR(req, nil) return s.runOnlineASR(req, nil)
} }
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
// Closes `results` before returning so the server wrapper's reader // Closes `results` before returning so the server wrapper's reader
// goroutine can exit. // goroutine can exit.
func (s *SherpaBackend) AudioTranscriptionStream( func (s *SherpaBackend) AudioTranscriptionStream(
_ context.Context,
req *pb.TranscriptRequest, req *pb.TranscriptRequest,
results chan *pb.TranscriptStreamResponse, results chan *pb.TranscriptStreamResponse,
) error { ) error {

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
}) })
It("rejects AudioTranscription", func() { It("rejects AudioTranscription", func() {
_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{ _, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/nonexistent.wav", Dst: "/tmp/nonexistent.wav",
}) })
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if v.asrModel == "" { if v.asrModel == "" {
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded") return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
} }
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
// transcription, emit each segment's content as a delta, then close // transcription, emit each segment's content as a delta, then close
// with a final_result whose Text equals the concatenated deltas (the // with a final_result whose Text equals the concatenated deltas (the
// e2e harness asserts those match). // e2e harness asserts those match).
func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
defer close(results) defer close(results)
res, err := v.AudioTranscription(req) res, err := v.AudioTranscription(ctx, req)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
}) })
It("rejects AudioTranscription without a loaded ASR model", func() { It("rejects AudioTranscription without a loaded ASR model", func() {
_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{ _, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav", Dst: "/tmp/some.wav",
}) })
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() { It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
ch := make(chan *pb.TranscriptStreamResponse, 4) ch := make(chan *pb.TranscriptStreamResponse, 4)
err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{ err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav", Dst: "/tmp/some.wav",
}, ch) }, ch)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
return nil return nil
} }
func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "voxtral") dir, err := os.MkdirTemp("", "voxtral")
if err != nil { if err != nil {
return pb.TranscriptResult{}, err return pb.TranscriptResult{}, err

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -92,7 +93,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
}, nil }, nil
} }
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
if err != nil { if err != nil {
return pb.TranscriptResult{}, err return pb.TranscriptResult{}, err

View File

@@ -3,6 +3,7 @@ package base
// This is a wrapper to satisfy the GRPC service interface // This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import ( import (
"context"
"fmt" "fmt"
"os" "os"
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
return fmt.Errorf("unimplemented") return fmt.Errorf("unimplemented")
} }
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) { func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
return pb.TranscriptResult{}, fmt.Errorf("unimplemented") return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
} }
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
return fmt.Errorf("unimplemented") return fmt.Errorf("unimplemented")
} }

View File

@@ -1,6 +1,8 @@
package grpc package grpc
import ( import (
"context"
pb "github.com/mudler/LocalAI/pkg/grpc/proto" pb "github.com/mudler/LocalAI/pkg/grpc/proto"
) )
@@ -22,8 +24,8 @@ type AIModel interface {
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error) VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error) VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error) VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
TTS(*pb.TTSRequest) error TTS(*pb.TTSRequest) error
TTSStream(*pb.TTSRequest, chan []byte) error TTSStream(*pb.TTSRequest, chan []byte) error
SoundGeneration(*pb.SoundGenerationRequest) error SoundGeneration(*pb.SoundGenerationRequest) error

View File

@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
s.llm.Lock() s.llm.Lock()
defer s.llm.Unlock() defer s.llm.Unlock()
} }
result, err := s.llm.AudioTranscription(in) result, err := s.llm.AudioTranscription(ctx, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
done <- true done <- true
}() }()
err := s.llm.AudioTranscriptionStream(in, resultChan) err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
<-done <-done
return err return err