diff --git a/backend/go/sherpa-onnx/backend.go b/backend/go/sherpa-onnx/backend.go index d73474af6..91b797aa0 100644 --- a/backend/go/sherpa-onnx/backend.go +++ b/backend/go/sherpa-onnx/backend.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/binary" "fmt" "os" @@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error { // Transcription // ============================================================= -func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) { if s.onlineRecognizer != 0 { return s.runOnlineASR(req, nil) } @@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc // Closes `results` before returning so the server wrapper's reader // goroutine can exit. func (s *SherpaBackend) AudioTranscriptionStream( + _ context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse, ) error { diff --git a/backend/go/sherpa-onnx/backend_test.go b/backend/go/sherpa-onnx/backend_test.go index 46ad6d3a2..b70bc3e67 100644 --- a/backend/go/sherpa-onnx/backend_test.go +++ b/backend/go/sherpa-onnx/backend_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "path/filepath" "testing" @@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() { }) It("rejects AudioTranscription", func() { - _, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{ + _, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/nonexistent.wav", }) Expect(err).To(HaveOccurred()) diff --git a/backend/go/vibevoice-cpp/govibevoicecpp.go b/backend/go/vibevoice-cpp/govibevoicecpp.go index 242f00c31..cf4945416 100644 --- a/backend/go/vibevoice-cpp/govibevoicecpp.go +++ b/backend/go/vibevoice-cpp/govibevoicecpp.go @@ -1,6 +1,7 @@ package main import ( + 
"context" "encoding/json" "fmt" "io" @@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) { return len(p), nil } -func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) { if v.asrModel == "" { return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded") } @@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro // transcription, emit each segment's content as a delta, then close // with a final_result whose Text equals the concatenated deltas (the // e2e harness asserts those match). -func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { +func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { defer close(results) - res, err := v.AudioTranscription(req) + res, err := v.AudioTranscription(ctx, req) if err != nil { return err } diff --git a/backend/go/vibevoice-cpp/vibevoicecpp_test.go b/backend/go/vibevoice-cpp/vibevoicecpp_test.go index ffb36629e..f58b24365 100644 --- a/backend/go/vibevoice-cpp/vibevoicecpp_test.go +++ b/backend/go/vibevoice-cpp/vibevoicecpp_test.go @@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() { }) It("rejects AudioTranscription without a loaded ASR model", func() { - _, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{ + _, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/some.wav", }) Expect(err).To(HaveOccurred()) @@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() { It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() { ch := make(chan *pb.TranscriptStreamResponse, 4) - err := 
(&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{ + err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{ Dst: "/tmp/some.wav", }, ch) Expect(err).To(HaveOccurred()) diff --git a/backend/go/voxtral/govoxtral.go b/backend/go/voxtral/govoxtral.go index 9a296a589..ac9c5148a 100644 --- a/backend/go/voxtral/govoxtral.go +++ b/backend/go/voxtral/govoxtral.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "strings" @@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error { return nil } -func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { dir, err := os.MkdirTemp("", "voxtral") if err != nil { return pb.TranscriptResult{}, err diff --git a/backend/go/whisper/cpp/gowhisper.cpp b/backend/go/whisper/cpp/gowhisper.cpp index f1756d780..77dbc74f9 100644 --- a/backend/go/whisper/cpp/gowhisper.cpp +++ b/backend/go/whisper/cpp/gowhisper.cpp @@ -1,12 +1,23 @@ #include "gowhisper.h" #include "ggml-backend.h" #include "whisper.h" +#include <atomic> #include <vector> static struct whisper_vad_context *vctx; static struct whisper_context *ctx; static std::vector flat_segs; +static std::atomic<int> g_abort{0}; + +static bool abort_cb(void * /*user_data*/) { + return g_abort.load(std::memory_order_relaxed) != 0; +} + +extern "C" void set_abort(int v) { + g_abort.store(v, std::memory_order_relaxed); +} + static void ggml_log_cb(enum ggml_log_level level, const char *log, void *data) { const char *level_str; @@ -124,10 +135,20 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz, wparams.tdrz_enable = tdrz; wparams.initial_prompt = prompt; + // Reset stale abort flag from any prior cancelled call, then install the + // ggml abort hook so a subsequent set_abort(1) from Go aborts the next + // compute graph step. 
+ g_abort.store(0, std::memory_order_relaxed); + wparams.abort_callback = abort_cb; + wparams.abort_callback_user_data = nullptr; + fprintf(stderr, "info: Enable tdrz: %d\n", tdrz); fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt); if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) { + if (g_abort.load(std::memory_order_relaxed)) { + return 2; // aborted by client + } fprintf(stderr, "error: transcription failed\n"); return 1; } diff --git a/backend/go/whisper/cpp/gowhisper.h b/backend/go/whisper/cpp/gowhisper.h index 0e061cf93..b8c7b6cb6 100644 --- a/backend/go/whisper/cpp/gowhisper.h +++ b/backend/go/whisper/cpp/gowhisper.h @@ -15,4 +15,5 @@ int64_t get_segment_t1(int i); int n_tokens(int i); int32_t get_token_id(int i, int j); bool get_segment_speaker_turn_next(int i); +void set_abort(int v); } diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index 56d794f0f..feeb49c3e 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -1,16 +1,20 @@ package main import ( + "context" "fmt" "os" "path/filepath" "strings" + "sync" "unsafe" "github.com/go-audio/wav" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/utils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -24,6 +28,7 @@ var ( CppNTokens func(i int) int CppGetTokenID func(i int, j int) int CppGetSegmentSpeakerTurnNext func(i int) bool + CppSetAbort func(v int) ) type Whisper struct { @@ -92,7 +97,11 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) { }, nil } -func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { + if err := ctx.Err(); err != nil { + return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled") + } + dir, err := os.MkdirTemp("", 
"whisper") if err != nil { return pb.TranscriptResult{}, err @@ -105,14 +114,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR return pb.TranscriptResult{}, err } - // Open samples fh, err := os.Open(convertedPath) if err != nil { return pb.TranscriptResult{}, err } defer fh.Close() - // Read samples d := wav.NewDecoder(fh) buf, err := d.FullPCMBuffer() if err != nil { @@ -120,8 +127,6 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR } data := buf.AsFloat32Buffer().Data - // whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate - // for the converted file produced by AudioToWav above. var duration float32 if buf.Format != nil && buf.Format.SampleRate > 0 { duration = float32(len(data)) / float32(buf.Format.SampleRate) @@ -129,7 +134,31 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR segsLen := uintptr(0xdeadbeef) segsLenPtr := unsafe.Pointer(&segsLen) - if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 { + // Watcher: flips the C-side abort flag when ctx is cancelled. The + // goroutine is joined synchronously (close(done) signals it to exit, + // wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire + // after the function returns and corrupt the next transcription call. 
+ done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + CppSetAbort(1) + case <-done: + } + }() + defer func() { + close(done) + wg.Wait() + }() + + ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt) + if ret == 2 { + return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled") + } + if ret != 0 { return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe") } diff --git a/backend/go/whisper/gowhisper_test.go b/backend/go/whisper/gowhisper_test.go new file mode 100644 index 000000000..1b89e6615 --- /dev/null +++ b/backend/go/whisper/gowhisper_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/ebitengine/purego" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestWhisper(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Whisper Backend Suite") +} + +var ( + libLoadOnce sync.Once + libLoadErr error +) + +// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the +// bridge without spinning up the gRPC server. Skips the current spec when the +// shared library isn't present (e.g. running before `make backends/whisper`). 
+func ensureLibLoaded() { + libLoadOnce.Do(func() { + libName := os.Getenv("WHISPER_LIBRARY") + if libName == "" { + libName = "./libgowhisper-fallback.so" + } + if _, err := os.Stat(libName); err != nil { + libLoadErr = err + return + } + gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libLoadErr = err + return + } + purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model") + purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe") + purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text") + purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0") + purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1") + purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens") + purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id") + purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next") + purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort") + }) + if libLoadErr != nil { + Skip("whisper library not loadable: " + libLoadErr.Error()) + } +} + +// fixturesOrSkip returns the model + audio paths or skips the spec if either +// env var is unset. The test never runs in default CI — it requires a real +// whisper model and a long audio file (~3 minutes) on disk. 
+func fixturesOrSkip() (string, string) { + modelPath := os.Getenv("WHISPER_MODEL_PATH") + audioPath := os.Getenv("WHISPER_AUDIO_PATH") + if modelPath == "" || audioPath == "" { + Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this spec") + } + return modelPath, audioPath +} + +var _ = Describe("Whisper", func() { + Context("AudioTranscription cancellation", func() { + It("returns codes.Canceled and resets the abort flag for the next call", func() { + modelPath, audioPath := fixturesOrSkip() + ensureLibLoaded() + + w := &Whisper{} + Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed()) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + start := time.Now() + _, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{ + Dst: audioPath, + Threads: 4, + Language: "en", + }) + elapsed := time.Since(start) + + Expect(err).To(HaveOccurred(), "transcription completed in %s without cancel — try a longer audio file", elapsed) + st, ok := status.FromError(err) + Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err) + Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err) + Expect(elapsed).To(BeNumerically("<", 5*time.Second), "cancellation took %s, expected <5s", elapsed) + + // Subsequent transcription must succeed — proves g_abort reset. 
+ res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{ + Dst: audioPath, + Threads: 4, + Language: "en", + }) + Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed") + Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text") + }) + }) +}) diff --git a/backend/go/whisper/main.go b/backend/go/whisper/main.go index 794c0a228..df35e9792 100644 --- a/backend/go/whisper/main.go +++ b/backend/go/whisper/main.go @@ -41,6 +41,7 @@ func main() { {&CppNTokens, "n_tokens"}, {&CppGetTokenID, "get_token_id"}, {&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"}, + {&CppSetAbort, "set_abort"}, } for _, lf := range libFuncs { diff --git a/core/backend/audio_transform.go b/core/backend/audio_transform.go index ea44ab6fd..3dbc8c833 100644 --- a/core/backend/audio_transform.go +++ b/core/backend/audio_transform.go @@ -40,6 +40,7 @@ type AudioTransformOutputs struct { // required; `referencePath` is optional (empty => backend zero-fills the // reference channel). 
func ModelAudioTransform( + ctx context.Context, audioPath, referencePath string, opts AudioTransformOptions, loader *model.ModelLoader, @@ -81,7 +82,7 @@ func ModelAudioTransform( startTime = time.Now() } - res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{ + res, err := transformModel.AudioTransform(ctx, &proto.AudioTransformRequest{ AudioPath: audioPath, ReferencePath: referencePath, Dst: dst, diff --git a/core/backend/detection.go b/core/backend/detection.go index 1a98c47a9..13a923e9f 100644 --- a/core/backend/detection.go +++ b/core/backend/detection.go @@ -12,6 +12,7 @@ import ( ) func Detection( + ctx context.Context, sourceFile string, prompt string, points []float32, @@ -38,7 +39,7 @@ func Detection( startTime = time.Now() } - res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{ + res, err := detectionModel.Detect(ctx, &proto.DetectOptions{ Src: sourceFile, Prompt: prompt, Points: points, diff --git a/core/backend/diarization.go b/core/backend/diarization.go index d311d4c45..ba973d773 100644 --- a/core/backend/diarization.go +++ b/core/backend/diarization.go @@ -63,7 +63,7 @@ func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig, // ModelDiarization runs the Diarize RPC against the configured backend // and returns a normalized schema.DiarizationResult. 
-func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) { +func ModelDiarization(ctx context.Context, req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) { m, err := loadDiarizationModel(ml, modelConfig, appConfig) if err != nil { return nil, err @@ -74,7 +74,7 @@ func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig threads = uint32(*modelConfig.Threads) } - r, err := m.Diarize(context.Background(), req.toProto(threads)) + r, err := m.Diarize(ctx, req.toProto(threads)) if err != nil { return nil, err } diff --git a/core/backend/face_analyze.go b/core/backend/face_analyze.go index 7293b09c7..24d70ac40 100644 --- a/core/backend/face_analyze.go +++ b/core/backend/face_analyze.go @@ -12,6 +12,7 @@ import ( ) func FaceAnalyze( + ctx context.Context, img string, actions []string, antiSpoofing bool, @@ -35,7 +36,7 @@ func FaceAnalyze( startTime = time.Now() } - res, err := faceModel.FaceAnalyze(context.Background(), &proto.FaceAnalyzeRequest{ + res, err := faceModel.FaceAnalyze(ctx, &proto.FaceAnalyzeRequest{ Img: img, Actions: actions, AntiSpoofing: antiSpoofing, diff --git a/core/backend/face_embed.go b/core/backend/face_embed.go index 77bbb4a7c..dc9fecad0 100644 --- a/core/backend/face_embed.go +++ b/core/backend/face_embed.go @@ -14,6 +14,7 @@ import ( // backend picks the highest-confidence face and returns its // L2-normalized embedding. 
func FaceEmbed( + ctx context.Context, imgBase64 string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -32,7 +33,7 @@ func FaceEmbed( predictOpts := gRPCPredictOpts(modelConfig, loader.ModelPath) predictOpts.Images = []string{imgBase64} - res, err := faceModel.Embeddings(context.Background(), predictOpts) + res, err := faceModel.Embeddings(ctx, predictOpts) if err != nil { return nil, err } diff --git a/core/backend/face_verify.go b/core/backend/face_verify.go index 43b128e79..15b7dcdaf 100644 --- a/core/backend/face_verify.go +++ b/core/backend/face_verify.go @@ -12,6 +12,7 @@ import ( ) func FaceVerify( + ctx context.Context, img1, img2 string, threshold float32, antiSpoofing bool, @@ -35,7 +36,7 @@ func FaceVerify( startTime = time.Now() } - res, err := faceModel.FaceVerify(context.Background(), &proto.FaceVerifyRequest{ + res, err := faceModel.FaceVerify(ctx, &proto.FaceVerifyRequest{ Img1: img1, Img2: img2, Threshold: threshold, diff --git a/core/backend/rerank.go b/core/backend/rerank.go index 4b8f8b288..9672a1ca8 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -11,7 +11,7 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { +func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { opts := ModelOptions(modelConfig, appConfig) rerankModel, err := loader.Load(opts...) 
if err != nil { @@ -29,7 +29,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig * startTime = time.Now() } - res, err := rerankModel.Rerank(context.Background(), request) + res, err := rerankModel.Rerank(ctx, request) if appConfig.EnableTracing { errStr := "" diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index f7f4d2f82..dccc4df74 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -15,6 +15,7 @@ import ( ) func SoundGeneration( + ctx context.Context, text string, duration *float32, temperature *float32, @@ -101,7 +102,7 @@ func SoundGeneration( startTime = time.Now() } - res, err := soundGenModel.SoundGeneration(context.Background(), req) + res, err := soundGenModel.SoundGeneration(ctx, req) if appConfig.EnableTracing { errStr := "" diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index 4a9289eec..45d60a406 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -10,6 +10,7 @@ import ( ) func TokenMetrics( + ctx context.Context, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -26,7 +27,7 @@ func TokenMetrics( return nil, fmt.Errorf("could not loadmodel model") } - res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{}) + res, err := model.GetTokenMetrics(ctx, &proto.MetricsRequest{}) return res, err } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 6d0b4c63d..be651516b 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi return transcriptionModel, nil } -func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { - return 
ModelTranscriptionWithOptions(TranscriptionRequest{ +func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { + return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{ Audio: audio, Language: language, Translate: translate, @@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt }, ml, modelConfig, appConfig) } -func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) if err != nil { return nil, err @@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad audioSnippet = trace.AudioSnippet(req.Audio) } - r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads))) + r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads))) if err != nil { if appConfig.EnableTracing { errData := map[string]any{ @@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct { // invokes onChunk for each event the backend produces. Backends that don't // support real streaming should still emit one terminal event with Final set, // which the HTTP layer turns into a single delta + done SSE pair. 
-func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error { +func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error { transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) if err != nil { return err @@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m pbReq := req.toProto(uint32(*modelConfig.Threads)) pbReq.Stream = true - return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) { + return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) { if chunk == nil { return } @@ -187,12 +187,12 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR } var words []schema.TranscriptionWord for _, w := range s.Words { - var word = schema.TranscriptionWord { + var word = schema.TranscriptionWord{ Start: time.Duration(w.Start), End: time.Duration(w.End), Text: w.Text, } - words = append(words, word) + words = append(words, word) tr.Words = append(tr.Words, word) } tr.Segments = append(tr.Segments, diff --git a/core/backend/tts.go b/core/backend/tts.go index 2f3d31193..9af9d0d44 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -21,6 +21,7 @@ import ( ) func ModelTTS( + ctx context.Context, text, voice, language string, @@ -70,7 +71,7 @@ func ModelTTS( startTime = time.Now() } - res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{ + res, err := ttsModel.TTS(ctx, &proto.TTSRequest{ Text: text, Model: modelPath, Voice: voice, @@ -121,6 +122,7 @@ func ModelTTS( } func ModelTTSStream( + ctx context.Context, text, voice, language string, @@ -172,7 +174,7 
@@ func ModelTTSStream( var totalPCMBytes int snippetCapped := false - err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{ + err = ttsModel.TTSStream(ctx, &proto.TTSRequest{ Text: text, Model: modelPath, Voice: voice, diff --git a/core/backend/voice_analyze.go b/core/backend/voice_analyze.go index 47ffebe5e..022692921 100644 --- a/core/backend/voice_analyze.go +++ b/core/backend/voice_analyze.go @@ -12,6 +12,7 @@ import ( ) func VoiceAnalyze( + ctx context.Context, audio string, actions []string, loader *model.ModelLoader, @@ -34,7 +35,7 @@ func VoiceAnalyze( startTime = time.Now() } - res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{ + res, err := voiceModel.VoiceAnalyze(ctx, &proto.VoiceAnalyzeRequest{ Audio: audio, Actions: actions, }) diff --git a/core/backend/voice_embed.go b/core/backend/voice_embed.go index e72842591..6cdc9b6a2 100644 --- a/core/backend/voice_embed.go +++ b/core/backend/voice_embed.go @@ -16,6 +16,7 @@ import ( // OpenAI-compatible and text-only), this call takes an audio path and // returns the backend's speaker-encoder output. 
func VoiceEmbed( + ctx context.Context, audioPath string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -37,7 +38,7 @@ func VoiceEmbed( startTime = time.Now() } - res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{ + res, err := voiceModel.VoiceEmbed(ctx, &proto.VoiceEmbedRequest{ Audio: audioPath, }) diff --git a/core/backend/voice_verify.go b/core/backend/voice_verify.go index 97cc7b9b1..bd4c04808 100644 --- a/core/backend/voice_verify.go +++ b/core/backend/voice_verify.go @@ -12,6 +12,7 @@ import ( ) func VoiceVerify( + ctx context.Context, audio1, audio2 string, threshold float32, antiSpoofing bool, @@ -35,7 +36,7 @@ func VoiceVerify( startTime = time.Now() } - res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{ + res, err := voiceModel.VoiceVerify(ctx, &proto.VoiceVerifyRequest{ Audio1: audio1, Audio2: audio2, Threshold: threshold, diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 4798de628..3eb6cfa4b 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -97,7 +97,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { inputFile = &t.InputFile } - filePath, _, err := backend.SoundGeneration(text, + filePath, _, err := backend.SoundGeneration(context.Background(), text, parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), nil, "", "", nil, "", "", "", nil, diff --git a/core/cli/transcript.go b/core/cli/transcript.go index d62beadf0..06764f4dd 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -71,7 +71,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { } }() - tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts) + tr, err := backend.ModelTranscription(context.Background(), t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts) if err 
!= nil { return err } diff --git a/core/cli/tts.go b/core/cli/tts.go index 72d4ee24b..0f7b8bc6c 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { options.Backend = t.Backend options.Model = t.Model - filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options) + filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options) if err != nil { return err } diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index 7034ea042..eb9152e43 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -44,6 +44,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader bpm = &b } filePath, _, err := backend.SoundGeneration( + c.Request().Context(), input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale, diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 3fc8c8f07..110ae292a 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID) - filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 6dabd35f6..348446b28 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -52,7 +52,7 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml 
*model.ModelLoader, app Documents: input.Documents, } - results, err := backend.Rerank(request, ml, appConfig, *cfg) + results, err := backend.Rerank(c.Request().Context(), request, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/localai/audio_transform.go b/core/http/endpoints/localai/audio_transform.go index b8ce8530d..a11c6595a 100644 --- a/core/http/endpoints/localai/audio_transform.go +++ b/core/http/endpoints/localai/audio_transform.go @@ -109,7 +109,7 @@ func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, } } - out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{ + out, _, err := backend.ModelAudioTransform(c.Request().Context(), audioPath, referencePath, backend.AudioTransformOptions{ Params: params, }, ml, appConfig, *cfg) if err != nil { diff --git a/core/http/endpoints/localai/detection.go b/core/http/endpoints/localai/detection.go index 0a9463e59..0f1c72282 100644 --- a/core/http/endpoints/localai/detection.go +++ b/core/http/endpoints/localai/detection.go @@ -38,7 +38,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC return err } - res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg) + res, err := backend.Detection(c.Request().Context(), image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/face_analyze.go b/core/http/endpoints/localai/face_analyze.go index e4eda3ddd..441b7f0af 100644 --- a/core/http/endpoints/localai/face_analyze.go +++ b/core/http/endpoints/localai/face_analyze.go @@ -35,7 +35,7 @@ func FaceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap } xlog.Debug("FaceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions) - res, err := backend.FaceAnalyze(img, input.Actions, 
input.AntiSpoofing, ml, appConfig, *cfg) + res, err := backend.FaceAnalyze(c.Request().Context(), img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/face_embed.go b/core/http/endpoints/localai/face_embed.go index 7a0f18e34..58524cafa 100644 --- a/core/http/endpoints/localai/face_embed.go +++ b/core/http/endpoints/localai/face_embed.go @@ -41,7 +41,7 @@ func FaceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC } xlog.Debug("FaceEmbed", "model", cfg.Name, "backend", cfg.Backend) - vec, err := backend.FaceEmbed(img, ml, appConfig, *cfg) + vec, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/face_identify.go b/core/http/endpoints/localai/face_identify.go index 527174127..15e7e2c3c 100644 --- a/core/http/endpoints/localai/face_identify.go +++ b/core/http/endpoints/localai/face_identify.go @@ -45,7 +45,7 @@ func FaceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a threshold := cmp.Or(input.Threshold, defaultIdentifyThreshold) xlog.Debug("FaceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold) - probe, err := backend.FaceEmbed(img, ml, appConfig, *cfg) + probe, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/face_register.go b/core/http/endpoints/localai/face_register.go index 308a194a7..fbeb29e0c 100644 --- a/core/http/endpoints/localai/face_register.go +++ b/core/http/endpoints/localai/face_register.go @@ -39,7 +39,7 @@ func FaceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a } xlog.Debug("FaceRegister", "model", cfg.Name, "name", input.Name) - embedding, err := backend.FaceEmbed(img, ml, appConfig, *cfg) + embedding, err := backend.FaceEmbed(c.Request().Context(), 
img, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/face_verify.go b/core/http/endpoints/localai/face_verify.go index 26398b7f8..ef608c57a 100644 --- a/core/http/endpoints/localai/face_verify.go +++ b/core/http/endpoints/localai/face_verify.go @@ -39,7 +39,7 @@ func FaceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app } xlog.Debug("FaceVerify", "model", cfg.Name, "backend", cfg.Backend) - res, err := backend.FaceVerify(img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg) + res, err := backend.FaceVerify(c.Request().Context(), img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index 36b0301b7..215928ab1 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -49,7 +49,7 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a } xlog.Debug("Token Metrics for model", "model", modelFile) - response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg) + response, err := backend.TokenMetrics(c.Request().Context(), modelFile, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 40e488191..fe9199c24 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig c.Response().Header().Set("Connection", "keep-alive") // Stream audio chunks as they're generated - err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error { + err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, 
func(audioChunk []byte) error { _, writeErr := c.Response().Write(audioChunk) if writeErr != nil { return writeErr @@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig } // Non-streaming TTS (existing behavior) - filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/localai/voice_analyze.go b/core/http/endpoints/localai/voice_analyze.go index 4712cd5b0..ff4d3c45d 100644 --- a/core/http/endpoints/localai/voice_analyze.go +++ b/core/http/endpoints/localai/voice_analyze.go @@ -36,7 +36,7 @@ func VoiceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a defer cleanup() xlog.Debug("VoiceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions) - res, err := backend.VoiceAnalyze(audio, input.Actions, ml, appConfig, *cfg) + res, err := backend.VoiceAnalyze(c.Request().Context(), audio, input.Actions, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/voice_embed.go b/core/http/endpoints/localai/voice_embed.go index 1f878efd6..d84a4dc2b 100644 --- a/core/http/endpoints/localai/voice_embed.go +++ b/core/http/endpoints/localai/voice_embed.go @@ -41,7 +41,7 @@ func VoiceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app defer cleanup() xlog.Debug("VoiceEmbed", "model", cfg.Name, "backend", cfg.Backend) - res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg) + res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/voice_identify.go b/core/http/endpoints/localai/voice_identify.go index b048bf96f..eda5aec3d 100644 --- a/core/http/endpoints/localai/voice_identify.go +++ 
b/core/http/endpoints/localai/voice_identify.go @@ -47,7 +47,7 @@ func VoiceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, threshold := cmp.Or(input.Threshold, defaultVoiceIdentifyThreshold) xlog.Debug("VoiceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold) - embed, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg) + embed, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/voice_register.go b/core/http/endpoints/localai/voice_register.go index 27605cd71..d8d97d619 100644 --- a/core/http/endpoints/localai/voice_register.go +++ b/core/http/endpoints/localai/voice_register.go @@ -40,7 +40,7 @@ func VoiceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, defer cleanup() xlog.Debug("VoiceRegister", "model", cfg.Name, "name", input.Name) - res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg) + res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/localai/voice_verify.go b/core/http/endpoints/localai/voice_verify.go index 9e81b8a15..d762ea51b 100644 --- a/core/http/endpoints/localai/voice_verify.go +++ b/core/http/endpoints/localai/voice_verify.go @@ -42,7 +42,7 @@ func VoiceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap defer cleanup2() xlog.Debug("VoiceVerify", "model", cfg.Name, "backend", cfg.Backend) - res, err := backend.VoiceVerify(audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg) + res, err := backend.VoiceVerify(c.Request().Context(), audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg) if err != nil { return mapBackendError(err) } diff --git a/core/http/endpoints/openai/diarization.go b/core/http/endpoints/openai/diarization.go index 2f927ddae..75e9715db 100644 --- a/core/http/endpoints/openai/diarization.go 
+++ b/core/http/endpoints/openai/diarization.go @@ -105,7 +105,7 @@ func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap _ = dstFile.Close() req.Audio = dst - result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig) + result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig) if err != nil { return err } diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index dd9baf1b7..bfeb70739 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques } func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { - return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) + return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) } func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { @@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc } func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { - return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) + return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, 
m.modelLoader, *m.TranscriptionConfig, m.appConfig) } func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { @@ -241,7 +241,7 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im } func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) { - return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig) + return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig) } func (m *wrappedModel) PredictConfig() *config.ModelConfig { diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 81f89b927..e57aac141 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app return streamTranscription(c, req, ml, *config, appConfig) } - tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig) + tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig) if err != nil { // Log before returning so the underlying error survives. 
Echo's // error handler turns this into a 500 with a generic body, which @@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app Words: []schema.TranscriptionWordSeconds{}, Segments: []schema.TranscriptionSegmentSeconds{}, } - for _, word := range(tr.Words) { + for _, word := range tr.Words { trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{ Start: word.Start.Seconds(), End: word.End.Seconds(), Text: word.Text, }) } - for _, seg := range(tr.Segments) { + for _, seg := range tr.Segments { segWords := []schema.TranscriptionWordSeconds{} - for _, word := range(seg.Words) { + for _, word := range seg.Words { segWords = append(segWords, schema.TranscriptionWordSeconds{ Start: word.Start.Seconds(), End: word.End.Seconds(), @@ -174,7 +174,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app }) } trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{ - Id: seg.Id, + Id: seg.Id, Start: seg.Start.Seconds(), End: seg.End.Seconds(), Text: seg.Text, @@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m var assembled strings.Builder var finalResult *schema.TranscriptionResult - err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) { + err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) { if chunk.Delta != "" { assembled.WriteString(chunk.Delta) _ = writeEvent(map[string]any{ diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 1fdf258cd..c5a0fc62f 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -3,6 +3,7 @@ package base // This is a wrapper to satisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( + "context" "fmt" "os" @@ -57,11 +58,11 @@ func (llm 
*Base) GenerateVideo(*pb.GenerateVideoRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) { return pb.TranscriptResult{}, fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { +func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error { return fmt.Errorf("unimplemented") } diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 8fb800aec..82f8af23b 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,6 +1,8 @@ package grpc import ( + "context" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) @@ -22,8 +24,8 @@ type AIModel interface { VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error) VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error) VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error) - AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) - AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error + AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) + AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error TTS(*pb.TTSRequest) error TTSStream(*pb.TTSRequest, chan []byte) error SoundGeneration(*pb.SoundGenerationRequest) error diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index a931e6556..396547ca9 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques s.llm.Lock() defer s.llm.Unlock() } - result, err := s.llm.AudioTranscription(in) + result, err := s.llm.AudioTranscription(ctx, in) if err != nil { return 
nil, err } @@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba done <- true }() - err := s.llm.AudioTranscriptionStream(in, resultChan) + err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan) <-done return err diff --git a/tests/e2e/distributed/distributed_full_flow_test.go b/tests/e2e/distributed/distributed_full_flow_test.go index 84867f804..5215c5617 100644 --- a/tests/e2e/distributed/distributed_full_flow_test.go +++ b/tests/e2e/distributed/distributed_full_flow_test.go @@ -93,7 +93,7 @@ func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error { return nil } -func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (t *testLLM) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) { t.lastAudioDst = req.Dst return pb.TranscriptResult{Text: "transcribed text"}, nil }