feat(whisper): honor client cancellation via ggml abort_callback (#9710)

* refactor(transcription): propagate request ctx through ModelTranscription*

Replaces context.Background() with the HTTP request ctx so client
disconnects start cancelling the gRPC call. No backend-side abort wiring
yet — that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

Follow-up to e65d3e1f which threaded ctx through ModelTranscription
but missed the CLI caller. CLI commands have no request-scoped ctx,
so context.Background() is correct here.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

Same ctx-plumbing pattern applied to the rest of the audio path. CLI
callers use context.Background() since there is no request scope; HTTP
callers use c.Request().Context().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

Replaces remaining context.Background() sites in core/backend with the
caller's ctx. After this commit, every core/backend/*.go entry point
threads the request ctx end-to-end to the gRPC client.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as first parameter to the AIModel interface methods
that wrap whisper-style transcription. Server-side gRPC handler now
forwards the per-RPC ctx (server-streaming uses stream.Context()).
Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter;
none uses it yet; the actual cancellation primitive lands in the next
commit, so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

Installs a std::atomic<int> flag, wires it into
whisper_full_params.abort_callback, and exposes a set_abort(int) C
symbol so Go can flip the flag from a goroutine watching the request
context. transcribe() now distinguishes abort (return 2) from real
whisper_full failure (return 1).

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

Adds the Go-side binding for the new C export so the next commit can
call CppSetAbort(1) from a watcher goroutine on ctx.Done().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): honor ctx cancellation and return codes.Canceled

A watcher goroutine waits on ctx.Done() during AudioTranscription and
calls CppSetAbort(1) on cancellation. whisper_full sees abort_callback
return true at the next compute graph step and returns non-zero; the
bridge turns that into return code 2, which AudioTranscription maps to
codes.Canceled.

Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH)
that asserts cancellation latency under 5s and proves the abort flag
resets cleanly so the next transcription succeeds.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

Follow-up to 85edf9d2. The previous commit used `defer close(done)` and
claimed the watcher was "joined synchronously", but close() only
signals; it does not block until the goroutine exits. That left a
window in which a late CppSetAbort(1) from a cancelled call could land
on the next call, after its C-side g_abort reset but before
whisper_full() began polling the abort callback, corrupting the second
transcription.

Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher
has actually returned from its select.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

If ctx is already Done() at entry, return codes.Canceled immediately
instead of running the full transcription. The C-side g_abort reset
happens at the start of transcribe() and would otherwise overwrite a
watcher-set abort flag from an already-cancelled ctx, producing a
spurious successful transcription on a request the client has already
abandoned.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

Phase B (93c48e19) added context.Context to AIModel.AudioTranscription
but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint
caught it: *testLLM did not implement grpc.AIModel because the method
signature lacked the ctx parameter, which broke the distributed test
suite compilation and cascaded through every backend-build job that
runs `go build ./...`.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

Project policy (.agents/coding-style.md, enforced by golangci-lint
forbidigo) is that all Go tests must use Ginkgo v2 + Gomega — no
stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the
cancellation test to a Describe/It block with Skip(...) for env
gating and Expect/HaveOccurred for assertions.

Same coverage: cancel mid-flight returns codes.Canceled within 5s and
a follow-up transcription succeeds, proving the C-side g_abort flag
resets cleanly.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
LocalAI [bot]
2026-05-08 01:44:47 +02:00
committed by GitHub
parent 806130bbc0
commit 2be07f61da
51 changed files with 260 additions and 75 deletions

View File

@@ -2,6 +2,7 @@ package main
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@@ -998,7 +999,7 @@ func (s *SherpaBackend) loadOnlineASR(opts *pb.ModelOptions) error {
// Transcription
// =============================================================
-func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (s *SherpaBackend) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if s.onlineRecognizer != 0 {
return s.runOnlineASR(req, nil)
}
@@ -1056,6 +1057,7 @@ func (s *SherpaBackend) AudioTranscription(req *pb.TranscriptRequest) (pb.Transc
// Closes `results` before returning so the server wrapper's reader
// goroutine can exit.
func (s *SherpaBackend) AudioTranscriptionStream(
_ context.Context,
req *pb.TranscriptRequest,
results chan *pb.TranscriptStreamResponse,
) error {

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"os"
"path/filepath"
"testing"
@@ -79,7 +80,7 @@ var _ = Describe("Sherpa-ONNX", func() {
})
It("rejects AudioTranscription", func() {
-_, err := (&SherpaBackend{}).AudioTranscription(&pb.TranscriptRequest{
+_, err := (&SherpaBackend{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/nonexistent.wav",
})
Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -480,7 +481,7 @@ func (w *byteWriter) Write(p []byte) (int, error) {
return len(p), nil
}
-func (v *VibevoiceCpp) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (v *VibevoiceCpp) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if v.asrModel == "" {
return pb.TranscriptResult{}, fmt.Errorf("vibevoice-cpp: AudioTranscription requested but no ASR model was loaded")
}
@@ -623,9 +624,9 @@ func (v *VibevoiceCpp) Diarize(req *pb.DiarizeRequest) (pb.DiarizeResponse, erro
// transcription, emit each segment's content as a delta, then close
// with a final_result whose Text equals the concatenated deltas (the
// e2e harness asserts those match).
-func (v *VibevoiceCpp) AudioTranscriptionStream(req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
+func (v *VibevoiceCpp) AudioTranscriptionStream(ctx context.Context, req *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error {
defer close(results)
-res, err := v.AudioTranscription(req)
+res, err := v.AudioTranscription(ctx, req)
if err != nil {
return err
}

View File

@@ -107,7 +107,7 @@ var _ = Describe("VibeVoice-cpp", func() {
})
It("rejects AudioTranscription without a loaded ASR model", func() {
-_, err := (&VibevoiceCpp{}).AudioTranscription(&pb.TranscriptRequest{
+_, err := (&VibevoiceCpp{}).AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav",
})
Expect(err).To(HaveOccurred())
@@ -255,7 +255,7 @@ var _ = Describe("VibeVoice-cpp", func() {
It("closes the channel and errors on AudioTranscriptionStream without a loaded model", func() {
ch := make(chan *pb.TranscriptStreamResponse, 4)
-err := (&VibevoiceCpp{}).AudioTranscriptionStream(&pb.TranscriptRequest{
+err := (&VibevoiceCpp{}).AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{
Dst: "/tmp/some.wav",
}, ch)
Expect(err).To(HaveOccurred())

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"strings"
@@ -27,7 +28,7 @@ func (v *Voxtral) Load(opts *pb.ModelOptions) error {
return nil
}
-func (v *Voxtral) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (v *Voxtral) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "voxtral")
if err != nil {
return pb.TranscriptResult{}, err

View File

@@ -1,12 +1,23 @@
#include "gowhisper.h"
#include "ggml-backend.h"
#include "whisper.h"
#include <atomic>
#include <vector>
static struct whisper_vad_context *vctx;
static struct whisper_context *ctx;
static std::vector<float> flat_segs;
static std::atomic<int> g_abort{0};
static bool abort_cb(void * /*user_data*/) {
return g_abort.load(std::memory_order_relaxed) != 0;
}
extern "C" void set_abort(int v) {
g_abort.store(v, std::memory_order_relaxed);
}
static void ggml_log_cb(enum ggml_log_level level, const char *log,
void *data) {
const char *level_str;
@@ -124,10 +135,20 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
wparams.tdrz_enable = tdrz;
wparams.initial_prompt = prompt;
// Reset stale abort flag from any prior cancelled call, then install the
// ggml abort hook so a subsequent set_abort(1) from Go aborts the next
// compute graph step.
g_abort.store(0, std::memory_order_relaxed);
wparams.abort_callback = abort_cb;
wparams.abort_callback_user_data = nullptr;
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
if (g_abort.load(std::memory_order_relaxed)) {
return 2; // aborted by client
}
fprintf(stderr, "error: transcription failed\n");
return 1;
}

View File

@@ -15,4 +15,5 @@ int64_t get_segment_t1(int i);
int n_tokens(int i);
int32_t get_token_id(int i, int j);
bool get_segment_speaker_turn_next(int i);
void set_abort(int v);
}

View File

@@ -1,16 +1,20 @@
package main
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"unsafe"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
@@ -24,6 +28,7 @@ var (
CppNTokens func(i int) int
CppGetTokenID func(i int, j int) int
CppGetSegmentSpeakerTurnNext func(i int) bool
CppSetAbort func(v int)
)
type Whisper struct {
@@ -92,7 +97,11 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
}, nil
}
-func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+if err := ctx.Err(); err != nil {
+return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
+}
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
@@ -105,14 +114,12 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
@@ -120,8 +127,6 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
}
data := buf.AsFloat32Buffer().Data
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
// for the converted file produced by AudioToWav above.
var duration float32
if buf.Format != nil && buf.Format.SampleRate > 0 {
duration = float32(len(data)) / float32(buf.Format.SampleRate)
@@ -129,7 +134,31 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
segsLen := uintptr(0xdeadbeef)
segsLenPtr := unsafe.Pointer(&segsLen)
-if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
+// Watcher: flips the C-side abort flag when ctx is cancelled. The
+// goroutine is joined synchronously (close(done) signals it to exit,
+// wg.Wait() blocks until it has) so a late CppSetAbort(1) cannot fire
+// after the function returns and corrupt the next transcription call.
+done := make(chan struct{})
+var wg sync.WaitGroup
+wg.Add(1)
+go func() {
+defer wg.Done()
+select {
+case <-ctx.Done():
+CppSetAbort(1)
+case <-done:
+}
+}()
+defer func() {
+close(done)
+wg.Wait()
+}()
+ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
+if ret == 2 {
+return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
+}
+if ret != 0 {
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
}

View File

@@ -0,0 +1,112 @@
package main
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/ebitengine/purego"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestWhisper(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Whisper Backend Suite")
}
var (
libLoadOnce sync.Once
libLoadErr error
)
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
// bridge without spinning up the gRPC server. Skips the current spec when the
// shared library isn't present (e.g. running before `make backends/whisper`).
func ensureLibLoaded() {
libLoadOnce.Do(func() {
libName := os.Getenv("WHISPER_LIBRARY")
if libName == "" {
libName = "./libgowhisper-fallback.so"
}
if _, err := os.Stat(libName); err != nil {
libLoadErr = err
return
}
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libLoadErr = err
return
}
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens")
purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id")
purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next")
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
})
if libLoadErr != nil {
Skip("whisper library not loadable: " + libLoadErr.Error())
}
}
// fixturesOrSkip returns the model + audio paths or skips the spec if either
// env var is unset. The test never runs in default CI — it requires a real
// whisper model and a long audio file (~3 minutes) on disk.
func fixturesOrSkip() (string, string) {
modelPath := os.Getenv("WHISPER_MODEL_PATH")
audioPath := os.Getenv("WHISPER_AUDIO_PATH")
if modelPath == "" || audioPath == "" {
Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this spec")
}
return modelPath, audioPath
}
var _ = Describe("Whisper", func() {
Context("AudioTranscription cancellation", func() {
It("returns codes.Canceled and resets the abort flag for the next call", func() {
modelPath, audioPath := fixturesOrSkip()
ensureLibLoaded()
w := &Whisper{}
Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed())
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
start := time.Now()
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
Dst: audioPath,
Threads: 4,
Language: "en",
})
elapsed := time.Since(start)
Expect(err).To(HaveOccurred(), "transcription completed in %s without cancel — try a longer audio file", elapsed)
st, ok := status.FromError(err)
Expect(ok).To(BeTrue(), "expected gRPC status error, got %v", err)
Expect(st.Code()).To(Equal(codes.Canceled), "expected codes.Canceled, got %v", err)
Expect(elapsed).To(BeNumerically("<", 5*time.Second), "cancellation took %s, expected <5s", elapsed)
// Subsequent transcription must succeed — proves g_abort reset.
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: audioPath,
Threads: 4,
Language: "en",
})
Expect(err).ToNot(HaveOccurred(), "post-cancel transcription failed")
Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text")
})
})
})

View File

@@ -41,6 +41,7 @@ func main() {
{&CppNTokens, "n_tokens"},
{&CppGetTokenID, "get_token_id"},
{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
{&CppSetAbort, "set_abort"},
}
for _, lf := range libFuncs {

View File

@@ -40,6 +40,7 @@ type AudioTransformOutputs struct {
// required; `referencePath` is optional (empty => backend zero-fills the
// reference channel).
func ModelAudioTransform(
ctx context.Context,
audioPath, referencePath string,
opts AudioTransformOptions,
loader *model.ModelLoader,
@@ -81,7 +82,7 @@ func ModelAudioTransform(
startTime = time.Now()
}
-res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{
+res, err := transformModel.AudioTransform(ctx, &proto.AudioTransformRequest{
AudioPath: audioPath,
ReferencePath: referencePath,
Dst: dst,

View File

@@ -12,6 +12,7 @@ import (
)
func Detection(
ctx context.Context,
sourceFile string,
prompt string,
points []float32,
@@ -38,7 +39,7 @@ func Detection(
startTime = time.Now()
}
-res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{
+res, err := detectionModel.Detect(ctx, &proto.DetectOptions{
Src: sourceFile,
Prompt: prompt,
Points: points,

View File

@@ -63,7 +63,7 @@ func loadDiarizationModel(ml *model.ModelLoader, modelConfig config.ModelConfig,
// ModelDiarization runs the Diarize RPC against the configured backend
// and returns a normalized schema.DiarizationResult.
-func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
+func ModelDiarization(ctx context.Context, req DiarizationRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.DiarizationResult, error) {
m, err := loadDiarizationModel(ml, modelConfig, appConfig)
if err != nil {
return nil, err
@@ -74,7 +74,7 @@ func ModelDiarization(req DiarizationRequest, ml *model.ModelLoader, modelConfig
threads = uint32(*modelConfig.Threads)
}
-r, err := m.Diarize(context.Background(), req.toProto(threads))
+r, err := m.Diarize(ctx, req.toProto(threads))
if err != nil {
return nil, err
}

View File

@@ -12,6 +12,7 @@ import (
)
func FaceAnalyze(
ctx context.Context,
img string,
actions []string,
antiSpoofing bool,
@@ -35,7 +36,7 @@ func FaceAnalyze(
startTime = time.Now()
}
-res, err := faceModel.FaceAnalyze(context.Background(), &proto.FaceAnalyzeRequest{
+res, err := faceModel.FaceAnalyze(ctx, &proto.FaceAnalyzeRequest{
Img: img,
Actions: actions,
AntiSpoofing: antiSpoofing,

View File

@@ -14,6 +14,7 @@ import (
// backend picks the highest-confidence face and returns its
// L2-normalized embedding.
func FaceEmbed(
ctx context.Context,
imgBase64 string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -32,7 +33,7 @@ func FaceEmbed(
predictOpts := gRPCPredictOpts(modelConfig, loader.ModelPath)
predictOpts.Images = []string{imgBase64}
-res, err := faceModel.Embeddings(context.Background(), predictOpts)
+res, err := faceModel.Embeddings(ctx, predictOpts)
if err != nil {
return nil, err
}

View File

@@ -12,6 +12,7 @@ import (
)
func FaceVerify(
ctx context.Context,
img1, img2 string,
threshold float32,
antiSpoofing bool,
@@ -35,7 +36,7 @@ func FaceVerify(
startTime = time.Now()
}
-res, err := faceModel.FaceVerify(context.Background(), &proto.FaceVerifyRequest{
+res, err := faceModel.FaceVerify(ctx, &proto.FaceVerifyRequest{
Img1: img1,
Img2: img2,
Threshold: threshold,

View File

@@ -11,7 +11,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
-func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
+func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
opts := ModelOptions(modelConfig, appConfig)
rerankModel, err := loader.Load(opts...)
if err != nil {
@@ -29,7 +29,7 @@ func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *
startTime = time.Now()
}
-res, err := rerankModel.Rerank(context.Background(), request)
+res, err := rerankModel.Rerank(ctx, request)
if appConfig.EnableTracing {
errStr := ""

View File

@@ -15,6 +15,7 @@ import (
)
func SoundGeneration(
ctx context.Context,
text string,
duration *float32,
temperature *float32,
@@ -101,7 +102,7 @@ func SoundGeneration(
startTime = time.Now()
}
-res, err := soundGenModel.SoundGeneration(context.Background(), req)
+res, err := soundGenModel.SoundGeneration(ctx, req)
if appConfig.EnableTracing {
errStr := ""

View File

@@ -10,6 +10,7 @@ import (
)
func TokenMetrics(
ctx context.Context,
modelFile string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -26,7 +27,7 @@ func TokenMetrics(
return nil, fmt.Errorf("could not loadmodel model")
}
-res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{})
+res, err := model.GetTokenMetrics(ctx, &proto.MetricsRequest{})
return res, err
}

View File

@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
return transcriptionModel, nil
}
-func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
-return ModelTranscriptionWithOptions(TranscriptionRequest{
+func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
Audio: audio,
Language: language,
Translate: translate,
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
}, ml, modelConfig, appConfig)
}
-func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil {
return nil, err
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
audioSnippet = trace.AudioSnippet(req.Audio)
}
-r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads)))
+r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
if err != nil {
if appConfig.EnableTracing {
errData := map[string]any{
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
// invokes onChunk for each event the backend produces. Backends that don't
// support real streaming should still emit one terminal event with Final set,
// which the HTTP layer turns into a single delta + done SSE pair.
-func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
+func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil {
return err
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
pbReq := req.toProto(uint32(*modelConfig.Threads))
pbReq.Stream = true
-return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) {
+return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
if chunk == nil {
return
}
@@ -187,12 +187,12 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
}
var words []schema.TranscriptionWord
for _, w := range s.Words {
-var word = schema.TranscriptionWord {
+var word = schema.TranscriptionWord{
Start: time.Duration(w.Start),
End: time.Duration(w.End),
Text: w.Text,
}
-words = append(words, word)
+words = append(words, word)
+tr.Words = append(tr.Words, word)
}
tr.Segments = append(tr.Segments,

View File

@@ -21,6 +21,7 @@ import (
)
func ModelTTS(
ctx context.Context,
text,
voice,
language string,
@@ -70,7 +71,7 @@ func ModelTTS(
startTime = time.Now()
}
-res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
+res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
@@ -121,6 +122,7 @@ func ModelTTS(
}
func ModelTTSStream(
ctx context.Context,
text,
voice,
language string,
@@ -172,7 +174,7 @@ func ModelTTSStream(
var totalPCMBytes int
snippetCapped := false
-err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
+err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,

View File

@@ -12,6 +12,7 @@ import (
)
func VoiceAnalyze(
ctx context.Context,
audio string,
actions []string,
loader *model.ModelLoader,
@@ -34,7 +35,7 @@ func VoiceAnalyze(
startTime = time.Now()
}
-res, err := voiceModel.VoiceAnalyze(context.Background(), &proto.VoiceAnalyzeRequest{
+res, err := voiceModel.VoiceAnalyze(ctx, &proto.VoiceAnalyzeRequest{
Audio: audio,
Actions: actions,
})

View File

@@ -16,6 +16,7 @@ import (
// OpenAI-compatible and text-only), this call takes an audio path and
// returns the backend's speaker-encoder output.
func VoiceEmbed(
ctx context.Context,
audioPath string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
@@ -37,7 +38,7 @@ func VoiceEmbed(
startTime = time.Now()
}
-res, err := voiceModel.VoiceEmbed(context.Background(), &proto.VoiceEmbedRequest{
+res, err := voiceModel.VoiceEmbed(ctx, &proto.VoiceEmbedRequest{
Audio: audioPath,
})

View File

@@ -12,6 +12,7 @@ import (
)
func VoiceVerify(
ctx context.Context,
audio1, audio2 string,
threshold float32,
antiSpoofing bool,
@@ -35,7 +36,7 @@ func VoiceVerify(
startTime = time.Now()
}
-res, err := voiceModel.VoiceVerify(context.Background(), &proto.VoiceVerifyRequest{
+res, err := voiceModel.VoiceVerify(ctx, &proto.VoiceVerifyRequest{
Audio1: audio1,
Audio2: audio2,
Threshold: threshold,

View File

@@ -97,7 +97,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
inputFile = &t.InputFile
}
-filePath, _, err := backend.SoundGeneration(text,
+filePath, _, err := backend.SoundGeneration(context.Background(), text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor),
nil, "", "", nil, "", "", "", nil,

View File

@@ -71,7 +71,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
}
}()
-tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
+tr, err := backend.ModelTranscription(context.Background(), t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
if err != nil {
return err
}

View File

@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options.Backend = t.Backend
options.Model = t.Model
-filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
+filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

View File

@@ -44,6 +44,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
bpm = &b
}
filePath, _, err := backend.SoundGeneration(
c.Request().Context(),
input.Text, input.Duration, input.Temperature, input.DoSample,
nil, nil,
input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale,

View File

@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
-filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
+filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -52,7 +52,7 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
Documents: input.Documents,
}
-results, err := backend.Rerank(request, ml, appConfig, *cfg)
+results, err := backend.Rerank(c.Request().Context(), request, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -109,7 +109,7 @@ func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
}
}
-out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{
+out, _, err := backend.ModelAudioTransform(c.Request().Context(), audioPath, referencePath, backend.AudioTransformOptions{
Params: params,
}, ml, appConfig, *cfg)
if err != nil {

View File

@@ -38,7 +38,7 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
return err
}
-res, err := backend.Detection(image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
+res, err := backend.Detection(c.Request().Context(), image, input.Prompt, input.Points, input.Boxes, input.Threshold, ml, appConfig, *cfg)
if err != nil {
return mapBackendError(err)
}

View File

@@ -35,7 +35,7 @@ func FaceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
}
xlog.Debug("FaceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
-res, err := backend.FaceAnalyze(img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
+res, err := backend.FaceAnalyze(c.Request().Context(), img, input.Actions, input.AntiSpoofing, ml, appConfig, *cfg)
if err != nil {
return mapBackendError(err)
}

View File

@@ -41,7 +41,7 @@ func FaceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
 }
 xlog.Debug("FaceEmbed", "model", cfg.Name, "backend", cfg.Backend)
-vec, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
+vec, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -45,7 +45,7 @@ func FaceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
 threshold := cmp.Or(input.Threshold, defaultIdentifyThreshold)
 xlog.Debug("FaceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
-probe, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
+probe, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -39,7 +39,7 @@ func FaceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
 }
 xlog.Debug("FaceRegister", "model", cfg.Name, "name", input.Name)
-embedding, err := backend.FaceEmbed(img, ml, appConfig, *cfg)
+embedding, err := backend.FaceEmbed(c.Request().Context(), img, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -39,7 +39,7 @@ func FaceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 }
 xlog.Debug("FaceVerify", "model", cfg.Name, "backend", cfg.Backend)
-res, err := backend.FaceVerify(img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
+res, err := backend.FaceVerify(c.Request().Context(), img1, img2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -49,7 +49,7 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
 }
 xlog.Debug("Token Metrics for model", "model", modelFile)
-response, err := backend.TokenMetrics(modelFile, ml, appConfig, *cfg)
+response, err := backend.TokenMetrics(c.Request().Context(), modelFile, ml, appConfig, *cfg)
 if err != nil {
 return err
 }

View File

@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
 c.Response().Header().Set("Connection", "keep-alive")
 // Stream audio chunks as they're generated
-err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
+err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
 _, writeErr := c.Response().Write(audioChunk)
 if writeErr != nil {
 return writeErr
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
 }
 // Non-streaming TTS (existing behavior)
-filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
+filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
 if err != nil {
 return err
 }

View File

@@ -36,7 +36,7 @@ func VoiceAnalyzeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
 defer cleanup()
 xlog.Debug("VoiceAnalyze", "model", cfg.Name, "backend", cfg.Backend, "actions", input.Actions)
-res, err := backend.VoiceAnalyze(audio, input.Actions, ml, appConfig, *cfg)
+res, err := backend.VoiceAnalyze(c.Request().Context(), audio, input.Actions, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -41,7 +41,7 @@ func VoiceEmbedEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 defer cleanup()
 xlog.Debug("VoiceEmbed", "model", cfg.Name, "backend", cfg.Backend)
-res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
+res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -47,7 +47,7 @@ func VoiceIdentifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
 threshold := cmp.Or(input.Threshold, defaultVoiceIdentifyThreshold)
 xlog.Debug("VoiceIdentify", "model", cfg.Name, "topK", topK, "threshold", threshold)
-embed, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
+embed, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -40,7 +40,7 @@ func VoiceRegisterEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
 defer cleanup()
 xlog.Debug("VoiceRegister", "model", cfg.Name, "name", input.Name)
-res, err := backend.VoiceEmbed(audio, ml, appConfig, *cfg)
+res, err := backend.VoiceEmbed(c.Request().Context(), audio, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -42,7 +42,7 @@ func VoiceVerifyEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
 defer cleanup2()
 xlog.Debug("VoiceVerify", "model", cfg.Name, "backend", cfg.Backend)
-res, err := backend.VoiceVerify(audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
+res, err := backend.VoiceVerify(c.Request().Context(), audio1, audio2, input.Threshold, input.AntiSpoofing, ml, appConfig, *cfg)
 if err != nil {
 return mapBackendError(err)
 }

View File

@@ -105,7 +105,7 @@ func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, ap
 _ = dstFile.Close()
 req.Audio = dst
-result, err := backend.ModelDiarization(req, ml, *modelConfig, appConfig)
+result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig)
 if err != nil {
 return err
 }

View File

@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
 }
 func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
-return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
+return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
 }
 func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
 }
 func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
-return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
+return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
 }
 func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
@@ -241,7 +241,7 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
 }
 func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
-return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
+return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
 }
 func (m *wrappedModel) PredictConfig() *config.ModelConfig {

View File

@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 return streamTranscription(c, req, ml, *config, appConfig)
 }
-tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig)
+tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
 if err != nil {
 // Log before returning so the underlying error survives. Echo's
 // error handler turns this into a 500 with a generic body, which
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 Words: []schema.TranscriptionWordSeconds{},
 Segments: []schema.TranscriptionSegmentSeconds{},
 }
-for _, word := range(tr.Words) {
+for _, word := range tr.Words {
 trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
 Start: word.Start.Seconds(),
 End: word.End.Seconds(),
 Text: word.Text,
 })
 }
-for _, seg := range(tr.Segments) {
+for _, seg := range tr.Segments {
 segWords := []schema.TranscriptionWordSeconds{}
-for _, word := range(seg.Words) {
+for _, word := range seg.Words {
 segWords = append(segWords, schema.TranscriptionWordSeconds{
 Start: word.Start.Seconds(),
 End: word.End.Seconds(),
@@ -174,7 +174,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 })
 }
 trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
-Id: seg.Id,
+Id:    seg.Id,
 Start: seg.Start.Seconds(),
 End: seg.End.Seconds(),
 Text: seg.Text,
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
 var assembled strings.Builder
 var finalResult *schema.TranscriptionResult
-err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
+err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
 if chunk.Delta != "" {
 assembled.WriteString(chunk.Delta)
 _ = writeEvent(map[string]any{

View File

@@ -3,6 +3,7 @@ package base
 // This is a wrapper to satisfy the GRPC service interface
 // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
 import (
+"context"
 "fmt"
 "os"
@@ -57,11 +58,11 @@ func (llm *Base) GenerateVideo(*pb.GenerateVideoRequest) error {
 return fmt.Errorf("unimplemented")
 }
-func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (llm *Base) AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error) {
 return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
 }
-func (llm *Base) AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
+func (llm *Base) AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error {
 return fmt.Errorf("unimplemented")
 }

View File

@@ -1,6 +1,8 @@
 package grpc
 import (
+"context"
 pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 )
@@ -22,8 +24,8 @@ type AIModel interface {
 VoiceVerify(*pb.VoiceVerifyRequest) (pb.VoiceVerifyResponse, error)
 VoiceAnalyze(*pb.VoiceAnalyzeRequest) (pb.VoiceAnalyzeResponse, error)
 VoiceEmbed(*pb.VoiceEmbedRequest) (pb.VoiceEmbedResponse, error)
-AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
-AudioTranscriptionStream(*pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
+AudioTranscription(context.Context, *pb.TranscriptRequest) (pb.TranscriptResult, error)
+AudioTranscriptionStream(context.Context, *pb.TranscriptRequest, chan *pb.TranscriptStreamResponse) error
 TTS(*pb.TTSRequest) error
 TTSStream(*pb.TTSRequest, chan []byte) error
 SoundGeneration(*pb.SoundGenerationRequest) error

View File

@@ -218,7 +218,7 @@ func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
 s.llm.Lock()
 defer s.llm.Unlock()
 }
-result, err := s.llm.AudioTranscription(in)
+result, err := s.llm.AudioTranscription(ctx, in)
 if err != nil {
 return nil, err
 }
@@ -260,7 +260,7 @@ func (s *server) AudioTranscriptionStream(in *pb.TranscriptRequest, stream pb.Ba
 done <- true
 }()
-err := s.llm.AudioTranscriptionStream(in, resultChan)
+err := s.llm.AudioTranscriptionStream(stream.Context(), in, resultChan)
 <-done
 return err
View File

@@ -93,7 +93,7 @@ func (t *testLLM) SoundGeneration(req *pb.SoundGenerationRequest) error {
 return nil
 }
-func (t *testLLM) AudioTranscription(req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
+func (t *testLLM) AudioTranscription(_ context.Context, req *pb.TranscriptRequest) (pb.TranscriptResult, error) {
 t.lastAudioDst = req.Dst
 return pb.TranscriptResult{Text: "transcribed text"}, nil
 }