mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
feat(whisper): honor ctx cancellation and return codes.Canceled
A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute graph step, returns non-zero, and the bridge returns 2 -> AudioTranscription maps that to codes.Canceled. Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -94,7 +96,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||
dir, err := os.MkdirTemp("", "whisper")
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
@@ -107,14 +109,12 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
|
||||
// Open samples
|
||||
fh, err := os.Open(convertedPath)
|
||||
if err != nil {
|
||||
return pb.TranscriptResult{}, err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Read samples
|
||||
d := wav.NewDecoder(fh)
|
||||
buf, err := d.FullPCMBuffer()
|
||||
if err != nil {
|
||||
@@ -122,8 +122,6 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
||||
}
|
||||
|
||||
data := buf.AsFloat32Buffer().Data
|
||||
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
|
||||
// for the converted file produced by AudioToWav above.
|
||||
var duration float32
|
||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||
@@ -131,7 +129,24 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
|
||||
// Watcher: flips the C-side abort flag when ctx is cancelled. Joined
|
||||
// synchronously via close(done) so a stale watcher cannot poison the
|
||||
// next call.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
CppSetAbort(1)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||
if ret == 2 {
|
||||
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||
}
|
||||
if ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
|
||||
108
backend/go/whisper/gowhisper_test.go
Normal file
108
backend/go/whisper/gowhisper_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/ebitengine/purego"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
libLoadOnce sync.Once
|
||||
libLoadErr error
|
||||
)
|
||||
|
||||
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||
// bridge without spinning up the gRPC server. Skips the test cleanly if the
|
||||
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||
func ensureLibLoaded(t *testing.T) {
|
||||
t.Helper()
|
||||
libLoadOnce.Do(func() {
|
||||
libName := os.Getenv("WHISPER_LIBRARY")
|
||||
if libName == "" {
|
||||
libName = "./libgowhisper-fallback.so"
|
||||
}
|
||||
if _, err := os.Stat(libName); err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
libLoadErr = err
|
||||
return
|
||||
}
|
||||
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||
purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens")
|
||||
purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id")
|
||||
purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next")
|
||||
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||
})
|
||||
if libLoadErr != nil {
|
||||
t.Skipf("whisper library not loadable: %v", libLoadErr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudioTranscriptionCancel ensures a context cancel mid-flight aborts
|
||||
// whisper_full and surfaces codes.Canceled. The follow-up call asserts that
|
||||
// the C-side abort flag resets cleanly so the next request still succeeds.
|
||||
//
|
||||
// Skipped unless WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH are set.
|
||||
func TestAudioTranscriptionCancel(t *testing.T) {
|
||||
modelPath := os.Getenv("WHISPER_MODEL_PATH")
|
||||
audioPath := os.Getenv("WHISPER_AUDIO_PATH")
|
||||
if modelPath == "" || audioPath == "" {
|
||||
t.Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this test")
|
||||
}
|
||||
ensureLibLoaded(t)
|
||||
|
||||
w := &Whisper{}
|
||||
if err := w.Load(&pb.ModelOptions{ModelFile: modelPath}); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil (transcription completed in %s — try a longer audio file)", elapsed)
|
||||
}
|
||||
if st, ok := status.FromError(err); !ok || st.Code() != codes.Canceled {
|
||||
t.Fatalf("expected codes.Canceled, got %v", err)
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
t.Fatalf("cancellation took %s, expected <5s", elapsed)
|
||||
}
|
||||
|
||||
// Subsequent transcription must succeed — proves g_abort reset.
|
||||
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||
Dst: audioPath,
|
||||
Threads: 4,
|
||||
Language: "en",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("post-cancel transcription failed: %v", err)
|
||||
}
|
||||
if res.Text == "" {
|
||||
t.Fatalf("post-cancel transcription returned empty text")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user