From 85edf9d2f24ec3adaab46e88d4ef2222dbde0c49 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 7 May 2026 16:44:22 +0000 Subject: [PATCH] feat(whisper): honor ctx cancellation and return codes.Canceled A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute graph step, returns non-zero, and the bridge returns 2 -> AudioTranscription maps that to codes.Canceled. Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto --- backend/go/whisper/gowhisper.go | 27 +++++-- backend/go/whisper/gowhisper_test.go | 108 +++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 backend/go/whisper/gowhisper_test.go diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index c108329d3..065fa713a 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -12,6 +12,8 @@ import ( "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/utils" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -94,7 +96,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) { }, nil } -func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { +func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) { dir, err := os.MkdirTemp("", "whisper") if err != nil { return pb.TranscriptResult{}, err @@ -107,14 +109,12 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque return pb.TranscriptResult{}, err } - // Open samples fh, err := 
os.Open(convertedPath) if err != nil { return pb.TranscriptResult{}, err } defer fh.Close() - // Read samples d := wav.NewDecoder(fh) buf, err := d.FullPCMBuffer() if err != nil { @@ -122,8 +122,6 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque } data := buf.AsFloat32Buffer().Data - // whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate - // for the converted file produced by AudioToWav above. var duration float32 if buf.Format != nil && buf.Format.SampleRate > 0 { duration = float32(len(data)) / float32(buf.Format.SampleRate) @@ -131,7 +129,24 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque segsLen := uintptr(0xdeadbeef) segsLenPtr := unsafe.Pointer(&segsLen) - if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 { + // Watcher: flips the C-side abort flag when ctx is cancelled. The deferred + // close(done) only signals the watcher to exit; it does not wait for it. + // NOTE(review): a late CppSetAbort(1) can still race with the next call. 
+ done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + CppSetAbort(1) + case <-done: + } + }() + defer close(done) + + ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt) + if ret == 2 { + return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled") + } + if ret != 0 { return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe") } diff --git a/backend/go/whisper/gowhisper_test.go b/backend/go/whisper/gowhisper_test.go new file mode 100644 index 000000000..e90fe837a --- /dev/null +++ b/backend/go/whisper/gowhisper_test.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "os" + "sync" + "testing" + "time" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/ebitengine/purego" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var ( + libLoadOnce sync.Once + libLoadErr error +) + +// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the +// bridge without spinning up the gRPC server. Skips the test cleanly if the +// shared library isn't present (e.g. running before `make backends/whisper`). 
+func ensureLibLoaded(t *testing.T) { + t.Helper() + libLoadOnce.Do(func() { + libName := os.Getenv("WHISPER_LIBRARY") + if libName == "" { + libName = "./libgowhisper-fallback.so" + } + if _, err := os.Stat(libName); err != nil { + libLoadErr = err + return + } + gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err != nil { + libLoadErr = err + return + } + purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model") + purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe") + purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text") + purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0") + purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1") + purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens") + purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id") + purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next") + purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort") + }) + if libLoadErr != nil { + t.Skipf("whisper library not loadable: %v", libLoadErr) + } +} + +// TestAudioTranscriptionCancel ensures a context cancel mid-flight aborts +// whisper_full and surfaces codes.Canceled. The follow-up call asserts that +// the C-side abort flag resets cleanly so the next request still succeeds. +// +// Skipped unless WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH are set. 
+func TestAudioTranscriptionCancel(t *testing.T) { + modelPath := os.Getenv("WHISPER_MODEL_PATH") + audioPath := os.Getenv("WHISPER_AUDIO_PATH") + if modelPath == "" || audioPath == "" { + t.Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this test") + } + ensureLibLoaded(t) + + w := &Whisper{} + if err := w.Load(&pb.ModelOptions{ModelFile: modelPath}); err != nil { + t.Fatalf("Load: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + start := time.Now() + _, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{ + Dst: audioPath, + Threads: 4, + Language: "en", + }) + elapsed := time.Since(start) + if err == nil { + t.Fatalf("expected error, got nil (transcription completed in %s — try a longer audio file)", elapsed) + } + if st, ok := status.FromError(err); !ok || st.Code() != codes.Canceled { + t.Fatalf("expected codes.Canceled, got %v", err) + } + if elapsed > 5*time.Second { + t.Fatalf("cancellation took %s, expected <5s", elapsed) + } + + // Subsequent transcription must succeed — proves g_abort reset. + res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{ + Dst: audioPath, + Threads: 4, + Language: "en", + }) + if err != nil { + t.Fatalf("post-cancel transcription failed: %v", err) + } + if res.Text == "" { + t.Fatalf("post-cancel transcription returned empty text") + } +}