mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}
Adds context.Context as first parameter to the AIModel interface methods that wrap whisper-style transcription. Server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none uses it yet — the actual cancellation primitive lands in the next commit so this is pure plumbing. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {
|
|||||||
// Transcription
|
// Transcription
|
||||||
// =============================================================
|
// =============================================================
|
||||||
|
|
||||||
func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
if s.onlineRecognizer != 0 {
|
if s.onlineRecognizer != 0 {
|
||||||
return s.runOnlineASR(req, nil)
|
return s.runOnlineASR(req, nil)
|
||||||
}
|
}
|
||||||
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
|
|||||||
// Closes `results` before returning so the server wrapper's reader
|
// Closes `results` before returning so the server wrapper's reader
|
||||||
// goroutine can exit.
|
// goroutine can exit.
|
||||||
func (s *SherpaBackend) AudioTranscriptionStream(
|
func (s *SherpaBackend) AudioTranscriptionStream(
|
||||||
|
_ context.Context,
|
||||||
req *pb.TranscriptRequest,
|
req *pb.TranscriptRequest,
|
||||||
results chan *pb.TranscriptStreamResponse,
|
results chan *pb.TranscriptStreamResponse,
|
||||||
) error {
|
) error {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("rejects AudioTranscription", func() {
|
It("rejects AudioTranscription", func() {
|
||||||
_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
|
_, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||||
Dst: "/tmp/nonexistent.wav",
|
Dst: "/tmp/nonexistent.wav",
|
||||||
})
|
})
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
if v.asrModel == "" {
|
if v.asrModel == "" {
|
||||||
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
|
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
|
||||||
}
|
}
|
||||||
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
|
|||||||
// transcription, emit each segment's content as a delta, then close
|
// transcription, emit each segment's content as a delta, then close
|
||||||
// with a final_result whose Text equals the concatenated deltas (the
|
// with a final_result whose Text equals the concatenated deltas (the
|
||||||
// e2e harness asserts those match).
|
// e2e harness asserts those match).
|
||||||
func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
|
||||||
defer close(results)
|
defer close(results)
|
||||||
res, err := v.AudioTranscription(req)
|
res, err := v.AudioTranscription(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("rejects AudioTranscription without a loaded ASR model", func() {
|
It("rejects AudioTranscription without a loaded ASR model", func() {
|
||||||
_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{
|
_, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||||
Dst: "/tmp/some.wav",
|
Dst: "/tmp/some.wav",
|
||||||
})
|
})
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {
|
|||||||
|
|
||||||
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
|
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
|
||||||
ch := make(chan *pb.TranscriptStreamResponse, 4)
|
ch := make(chan *pb.TranscriptStreamResponse, 4)
|
||||||
err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{
|
err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
|
||||||
Dst: "/tmp/some.wav",
|
Dst: "/tmp/some.wav",
|
||||||
}, ch)
|
}, ch)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
dir, err := os.MkdirTemp("", "voxtral")
|
dir, err := os.MkdirTemp("", "voxtral")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pb.TranscriptResult{}, err
|
return pb.TranscriptResult{}, err
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -92,7 +93,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pb.TranscriptResult{}, err
|
return pb.TranscriptResult{}, err
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package base
|
|||||||
// This is a wrapper to satisfy the GRPC service interface
|
// This is a wrapper to satisfy the GRPC service interface
|
||||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
|
|||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
|
||||||
return fmt.Errorf("unimplemented")
|
return fmt.Errorf("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package grpc
|
package grpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,8 +24,8 @@ type AIModel interface {
|
|||||||
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
|
VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
|
||||||
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
|
VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
|
||||||
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
|
VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
|
||||||
AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
|
AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
|
||||||
AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
|
||||||
TTS(*pb.TTSRequest) error
|
TTS(*pb.TTSRequest) error
|
||||||
TTSStream(*pb.TTSRequest, chan []byte) error
|
TTSStream(*pb.TTSRequest, chan []byte) error
|
||||||
SoundGeneration(*pb.SoundGenerationRequest) error
|
SoundGeneration(*pb.SoundGenerationRequest) error
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
|
|||||||
s.llm.Lock()
|
s.llm.Lock()
|
||||||
defer s.llm.Unlock()
|
defer s.llm.Unlock()
|
||||||
}
|
}
|
||||||
result, err := s.llm.AudioTranscription(in)
|
result, err := s.llm.AudioTranscription(ctx, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
|
|||||||
done <- true
|
done <- true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := s.llm.AudioTranscriptionStream(in, resultChan)
|
err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user