mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
feat(whisper): honor ctx cancellation and return codes.Canceled
A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute graph step, returns non-zero, and the bridge returns 2 -> AudioTranscription maps that to codes.Canceled. Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds. Assisted-by: Claude:claude-sonnet-4-6 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
"github.com/mudler/LocalAI/pkg/utils"
|
"github.com/mudler/LocalAI/pkg/utils"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -94,7 +96,7 @@ func (w *Whisper) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pb.TranscriptResult{}, err
|
return pb.TranscriptResult{}, err
|
||||||
@@ -107,14 +109,12 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
|||||||
return pb.TranscriptResult{}, err
|
return pb.TranscriptResult{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open samples
|
|
||||||
fh, err := os.Open(convertedPath)
|
fh, err := os.Open(convertedPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pb.TranscriptResult{}, err
|
return pb.TranscriptResult{}, err
|
||||||
}
|
}
|
||||||
defer fh.Close()
|
defer fh.Close()
|
||||||
|
|
||||||
// Read samples
|
|
||||||
d := wav.NewDecoder(fh)
|
d := wav.NewDecoder(fh)
|
||||||
buf, err := d.FullPCMBuffer()
|
buf, err := d.FullPCMBuffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -122,8 +122,6 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := buf.AsFloat32Buffer().Data
|
data := buf.AsFloat32Buffer().Data
|
||||||
// whisper.cpp resamples to 16 kHz internally; this matches buf.Format.SampleRate
|
|
||||||
// for the converted file produced by AudioToWav above.
|
|
||||||
var duration float32
|
var duration float32
|
||||||
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
if buf.Format != nil && buf.Format.SampleRate > 0 {
|
||||||
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
duration = float32(len(data)) / float32(buf.Format.SampleRate)
|
||||||
@@ -131,7 +129,24 @@ func (w *Whisper) AudioTranscription(_ context.Context, opts *pb.TranscriptReque
|
|||||||
segsLen := uintptr(0xdeadbeef)
|
segsLen := uintptr(0xdeadbeef)
|
||||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||||
|
|
||||||
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
|
// Watcher: flips the C-side abort flag when ctx is cancelled. Joined
|
||||||
|
// synchronously via close(done) so a stale watcher cannot poison the
|
||||||
|
// next call.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
CppSetAbort(1)
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt)
|
||||||
|
if ret == 2 {
|
||||||
|
return pb.TranscriptResult{}, status.Error(codes.Canceled, "transcription cancelled")
|
||||||
|
}
|
||||||
|
if ret != 0 {
|
||||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
108
backend/go/whisper/gowhisper_test.go
Normal file
108
backend/go/whisper/gowhisper_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
libLoadOnce sync.Once
|
||||||
|
libLoadErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
// ensureLibLoaded mirrors main.go's bootstrap so a Go test can drive the
|
||||||
|
// bridge without spinning up the gRPC server. Skips the test cleanly if the
|
||||||
|
// shared library isn't present (e.g. running before `make backends/whisper`).
|
||||||
|
func ensureLibLoaded(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
libLoadOnce.Do(func() {
|
||||||
|
libName := os.Getenv("WHISPER_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./libgowhisper-fallback.so"
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(libName); err != nil {
|
||||||
|
libLoadErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
libLoadErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
purego.RegisterLibFunc(&CppLoadModel, gosd, "load_model")
|
||||||
|
purego.RegisterLibFunc(&CppTranscribe, gosd, "transcribe")
|
||||||
|
purego.RegisterLibFunc(&CppGetSegmentText, gosd, "get_segment_text")
|
||||||
|
purego.RegisterLibFunc(&CppGetSegmentStart, gosd, "get_segment_t0")
|
||||||
|
purego.RegisterLibFunc(&CppGetSegmentEnd, gosd, "get_segment_t1")
|
||||||
|
purego.RegisterLibFunc(&CppNTokens, gosd, "n_tokens")
|
||||||
|
purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id")
|
||||||
|
purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next")
|
||||||
|
purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort")
|
||||||
|
})
|
||||||
|
if libLoadErr != nil {
|
||||||
|
t.Skipf("whisper library not loadable: %v", libLoadErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAudioTranscriptionCancel ensures a context cancel mid-flight aborts
|
||||||
|
// whisper_full and surfaces codes.Canceled. The follow-up call asserts that
|
||||||
|
// the C-side abort flag resets cleanly so the next request still succeeds.
|
||||||
|
//
|
||||||
|
// Skipped unless WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH are set.
|
||||||
|
func TestAudioTranscriptionCancel(t *testing.T) {
|
||||||
|
modelPath := os.Getenv("WHISPER_MODEL_PATH")
|
||||||
|
audioPath := os.Getenv("WHISPER_AUDIO_PATH")
|
||||||
|
if modelPath == "" || audioPath == "" {
|
||||||
|
t.Skip("set WHISPER_MODEL_PATH and WHISPER_AUDIO_PATH to run this test")
|
||||||
|
}
|
||||||
|
ensureLibLoaded(t)
|
||||||
|
|
||||||
|
w := &Whisper{}
|
||||||
|
if err := w.Load(&pb.ModelOptions{ModelFile: modelPath}); err != nil {
|
||||||
|
t.Fatalf("Load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := w.AudioTranscription(ctx, &pb.TranscriptRequest{
|
||||||
|
Dst: audioPath,
|
||||||
|
Threads: 4,
|
||||||
|
Language: "en",
|
||||||
|
})
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error, got nil (transcription completed in %s — try a longer audio file)", elapsed)
|
||||||
|
}
|
||||||
|
if st, ok := status.FromError(err); !ok || st.Code() != codes.Canceled {
|
||||||
|
t.Fatalf("expected codes.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
if elapsed > 5*time.Second {
|
||||||
|
t.Fatalf("cancellation took %s, expected <5s", elapsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsequent transcription must succeed — proves g_abort reset.
|
||||||
|
res, err := w.AudioTranscription(context.Background(), &pb.TranscriptRequest{
|
||||||
|
Dst: audioPath,
|
||||||
|
Threads: 4,
|
||||||
|
Language: "en",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("post-cancel transcription failed: %v", err)
|
||||||
|
}
|
||||||
|
if res.Text == "" {
|
||||||
|
t.Fatalf("post-cancel transcription returned empty text")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user