feat(whisper): honor ctx cancellation and return codes.Canceled

A watcher goroutine watches ctx.Done() during AudioTranscription and
calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return
true at the next compute graph step, returns non-zero, and the bridge
returns 2 -> AudioTranscription maps that to codes.Canceled.

Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH)
that asserts cancellation latency under 5s and proves the abort flag
resets cleanly so the next transcription succeeds.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 16:44:22 +00:00
parent da4a35bb97
commit 85edf9d2f2
2 changed files with 129 additions and 6 deletions

View File

@@ -12,6 +12,8 @@ import (
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
@@ -94,7 +96,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
}, nil
}
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
@@ -107,14 +109,12 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
return pb.TranscriptResult{}, err
}
// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()
// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
@@ -122,8 +122,6 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
}
data := buf.AsFloat32Buffer().Data
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
// for the converted file produced by AudioToWav above.
var duration float32
if buf.Format != nil && buf.Format.SampleRate > 0 {
duration = float32(len(data)) / float32(buf.Format.SampleRate)
@@ -131,7 +129,24 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
segsLen := uintptr(0xdeadbeef)
segsLenPtr := unsafe.Pointer(&segsLen)
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
// Watcher: flips the C-side abort flag when ctx is cancelled. Joined
// synchronously via close(done) so a stale watcher cannot poison the
// next call.
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
CppSetAbort(1)
case <-done:
}
}()
defer close(done)
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
if ret == 2 {
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
}
if ret != 0 {
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
}

View File

@@ -0,0 +1,108 @@
package main
import (
"context"
"os"
"sync"
"testing"
"time"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/ebitengine/purego"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
libLoadOnce sync.Once
libLoadErr error
)
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
// bridge without spinning up the gRPC server. Skips the test cleanly if the
// shared library isn't present (e.g. running before `make backends/whisper`).
func ensureLibLoaded(t *testing.T) {
t.Helper()
libLoadOnce.Do(func() {
libName := os.Getenv("WHISPER_LIBRARY")
if libName == "" {
libName = "./libgowhisper-fallback.so"
}
if _, err := os.Stat(libName); err != nil {
libLoadErr = err
return
}
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libLoadErr = err
return
}
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens")
purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id")
purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next")
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
})
if libLoadErr != nil {
t.Skipf("whisper library not loadable: %v", libLoadErr)
}
}
// TestAudioTranscriptionCancel ensures a context cancel mid-flight aborts
// whisper_full and surfaces codes.Canceled. The follow-up call asserts that
// the C-side abort flag resets cleanly so the next request still succeeds.
//
// Skipped unless WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH are set.
func TestAudioTranscriptionCancel(t *testing.T) {
modelPath := os.Getenv("WHISPER_MODEL_PATH")
audioPath := os.Getenv("WHISPER_AUDIO_PATH")
if modelPath == "" || audioPath == "" {
t.Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this test")
}
ensureLibLoaded(t)
w := &Whisper{}
if err := w.Load(&pb.ModelOptions{ModelFile: modelPath}); err != nil {
t.Fatalf("Load: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
start := time.Now()
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
Dst: audioPath,
Threads: 4,
Language: "en",
})
elapsed := time.Since(start)
if err == nil {
t.Fatalf("expected error, got nil (transcription completed in %s — try a longer audio file)", elapsed)
}
if st, ok := status.FromError(err); !ok || st.Code() != codes.Canceled {
t.Fatalf("expected codes.Canceled, got %v", err)
}
if elapsed > 5*time.Second {
t.Fatalf("cancellation took %s, expected <5s", elapsed)
}
// Subsequent transcription must succeed — proves g_abort reset.
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
Dst: audioPath,
Threads: 4,
Language: "en",
})
if err != nil {
t.Fatalf("post-cancel transcription failed: %v", err)
}
if res.Text == "" {
t.Fatalf("post-cancel transcription returned empty text")
}
}