mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}
Adds context.Context as first parameter to the AIModel interface methods that wrap whisper-style transcription. Server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none uses it yet — the actual cancellation primitive lands in the next commit so this is pure plumbing. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {
|
||||
// Transcription
|
||||
// =============================================================
|
||||
|
||||
func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if s.onlineRecognizer != 0 {
|
||||
return s.runOnlineASR(req, nil)
|
||||
}
|
||||
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
|
||||
// Closes `results` before returning so the server wrapper's reader
|
||||
// goroutine can exit.
|
||||
func (s *SherpaBackend) AudioTranscriptionStream(
|
||||
_ context.Context,
|
||||
req *pb.TranscriptRequest,
|
||||
results chan *pb.TranscriptStreamResponse,
|
||||
) error {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
|
||||
})
|
||||
|
||||
It("rejects AudioTranscription", func() {
|
||||
_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
|
||||
_, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: "/tmp/nonexistent.wav",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
if v.asrModel == "" {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
|
||||
}
|
||||
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
|
||||
// transcription, emit each segment's content as a delta, then close
|
||||
// with a final_result whose Text equals the concatenated deltas (the
|
||||
// e2e harness asserts those match).
|
||||
func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||
defer close(results)
|
||||
res, err := v.AudioTranscription(req)
|
||||
res, err := v.AudioTranscription(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
|
||||
})
|
||||
|
||||
It("rejects AudioTranscription without a loaded ASR model", func() {
|
||||
_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{
|
||||
_, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: "/tmp/some.wav",
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {
|
||||
|
||||
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
|
||||
ch := make(chan *pb.TranscriptStreamResponse, 4)
|
||||
err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{
|
||||
err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: "/tmp/some.wav",
|
||||
}, ch)
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
dir, err := os.MkdirTemp("", "voxtral")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -92,7 +93,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
|
||||
@@ -3,6 +3,7 @@ package base
|
||||
// This is a wrapper to satisfy the GRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||
func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||
return fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
)
|
||||
|
||||
@@ -22,8 +24,8 @@ type AIModel interface {
|
||||
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
|
||||
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
|
||||
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
|
||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||
AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||
AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||
TTS(*pb.TTSRequest) error
|
||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||
|
||||
@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
||||
s.llm.Lock()
|
||||
defer s.llm.Unlock()
|
||||
}
|
||||
result, err := s.llm.AudioTranscription(in)
|
||||
result, err := s.llm.AudioTranscription(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := s.llm.AudioTranscriptionStream(in, resultChan)
|
||||
err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
|
||||
<-done
|
||||
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user