diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index 6d840c919..52f33ebde 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -43,6 +43,7 @@ jobs: insightface: ${{ steps.detect.outputs.insightface }} speaker-recognition: ${{ steps.detect.outputs.speaker-recognition }} sherpa-onnx: ${{ steps.detect.outputs.sherpa-onnx }} + whisper: ${{ steps.detect.outputs.whisper }} steps: - name: Checkout repository uses: actions/checkout@v6 @@ -583,6 +584,27 @@ jobs: - name: Build sherpa-onnx backend image and run streaming ASR gRPC e2e tests run: | make test-extra-backend-sherpa-onnx-transcription + # End-to-end transcription via the e2e-backends gRPC harness against + # the whisper.cpp backend. Drives AudioTranscription (offline) and + # AudioTranscriptionStream (real, segment-callback-driven deltas) on + # ggml-base.en + the JFK 11s clip. + tests-whisper-grpc-transcription: + needs: detect-changes + if: needs.detect-changes.outputs.whisper == 'true' || needs.detect-changes.outputs.run-all == 'true' + runs-on: ubuntu-latest + timeout-minutes: 90 + steps: + - name: Clone + uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.25.4' + - name: Build whisper backend image and run transcription gRPC e2e tests + run: | + make test-extra-backend-whisper-transcription # VITS TTS via the sherpa-onnx backend. Drives both TTS (file write) and # TTSStream (PCM chunks) on the e2e-backends harness. tests-sherpa-onnx-grpc-tts: diff --git a/Makefile b/Makefile index 8d97d675a..49556077e 100644 --- a/Makefile +++ b/Makefile @@ -897,6 +897,18 @@ test-extra-backend-vibevoice-cpp-transcription: docker-build-vibevoice-cpp BACKEND_TEST_CAPS=health,load,transcription \ $(MAKE) test-extra-backend +## Audio transcription wrapper for the whisper.cpp backend. 
+## Drives the AudioTranscription / AudioTranscriptionStream RPCs against +## ggml-base.en (~145 MB) using the JFK 11s clip. The streaming spec +## asserts len(deltas) >= 1 and concat(deltas) == final.Text - whisper- +## specific multi-segment assertions live in backend/go/whisper/gowhisper_test.go. +test-extra-backend-whisper-transcription: docker-build-whisper + BACKEND_IMAGE=local-ai-backend:whisper \ + BACKEND_TEST_MODEL_URL=https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin \ + BACKEND_TEST_AUDIO_URL=https://github.com/ggml-org/whisper.cpp/raw/master/samples/jfk.wav \ + BACKEND_TEST_CAPS=health,load,transcription \ + $(MAKE) test-extra-backend + ## LocalVQE audio transform (joint AEC + noise suppression + dereverb). ## Exercises the audio_transform capability end-to-end: batch transform ## of a real WAV fixture and bidi streaming of synthetic silent frames. diff --git a/backend/go/whisper/cpp/gowhisper.cpp b/backend/go/whisper/cpp/gowhisper.cpp index 77dbc74f9..aab790557 100644 --- a/backend/go/whisper/cpp/gowhisper.cpp +++ b/backend/go/whisper/cpp/gowhisper.cpp @@ -10,14 +10,38 @@ static std::vector flat_segs; static std::atomic<int> g_abort{0}; +static std::atomic<uintptr_t> g_go_new_segment_cb{0}; +static std::atomic<uintptr_t> g_go_new_segment_user_data{0}; + static bool abort_cb(void * /*user_data*/) { return g_abort.load(std::memory_order_relaxed) != 0; } +static void new_segment_cb(struct whisper_context *cb_ctx, + struct whisper_state * /*state*/, int n_new, + void * /*user_data*/) { + uintptr_t go_cb = g_go_new_segment_cb.load(std::memory_order_relaxed); + if (go_cb == 0) { + return; + } + int total = whisper_full_n_segments(cb_ctx); + int idx_first = total - n_new; + if (idx_first < 0) { + idx_first = 0; + } + uintptr_t ud = g_go_new_segment_user_data.load(std::memory_order_relaxed); + reinterpret_cast<go_new_segment_cb>(go_cb)(idx_first, n_new, ud); +} + extern "C" void set_abort(int v) { g_abort.store(v, std::memory_order_relaxed); } +extern "C" void 
set_new_segment_callback(uintptr_t cb_ptr, uintptr_t user_data) { + g_go_new_segment_cb.store(cb_ptr, std::memory_order_relaxed); + g_go_new_segment_user_data.store(user_data, std::memory_order_relaxed); +} + static void ggml_log_cb(enum ggml_log_level level, const char *log, void *data) { const char *level_str; @@ -139,6 +163,14 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz, // ggml abort hook so a subsequent set_abort(1) from Go aborts the next // compute graph step. g_abort.store(0, std::memory_order_relaxed); + // Only install the new-segment callback when streaming is requested + // (Go side calls set_new_segment_callback before transcribe()). Leaving + // it always-on is harmless but adds a function-pointer dispatch per + // segment for the offline path. + if (g_go_new_segment_cb.load(std::memory_order_relaxed) != 0) { + wparams.new_segment_callback = new_segment_cb; + wparams.new_segment_callback_user_data = nullptr; + } wparams.abort_callback = abort_cb; wparams.abort_callback_user_data = nullptr; diff --git a/backend/go/whisper/cpp/gowhisper.h b/backend/go/whisper/cpp/gowhisper.h index b8c7b6cb6..4bac87eef 100644 --- a/backend/go/whisper/cpp/gowhisper.h +++ b/backend/go/whisper/cpp/gowhisper.h @@ -16,4 +16,15 @@ int n_tokens(int i); int32_t get_token_id(int i, int j); bool get_segment_speaker_turn_next(int i); void set_abort(int v); + +// Function pointer from Go (returned by purego.NewCallback). Invoked once +// per new-segment event during whisper_full(). The callback runs on the +// decode thread - if Go blocks (slow gRPC consumer), the decode blocks +// too. That is the intended backpressure path. +typedef void (*go_new_segment_cb)(int idx_first, int n_new, uintptr_t user_data); + +// Install the callback used by the next transcribe() call. Pass cb=0 to +// clear. user_data is opaque to C; the Go side uses it to look up +// per-call state. 
+void set_new_segment_callback(uintptr_t cb_ptr, uintptr_t user_data); } diff --git a/backend/go/whisper/gowhisper.go b/backend/go/whisper/gowhisper.go index feeb49c3e..94609339f 100644 --- a/backend/go/whisper/gowhisper.go +++ b/backend/go/whisper/gowhisper.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "unsafe" "github.com/go-audio/wav" @@ -29,8 +30,83 @@ var ( CppGetTokenID func(i int, j int) int CppGetSegmentSpeakerTurnNext func(i int) bool CppSetAbort func(v int) + // Set by main.go via purego.RegisterLibFunc. Installs (or clears with cb=0) + // the C-side trampoline that whisper.cpp invokes per new segment. + CppSetNewSegmentCallback func(cbPtr uintptr, userData uintptr) ) +// streamCallStates maps per-AudioTranscriptionStream call IDs to the +// state the Go callback needs to emit deltas. Only one entry is ever +// live today (base.SingleThread), but the map shape mirrors +// sherpa-onnx's TTS callback registry and survives a future SingleThread +// removal without a contract change. +var ( + streamCallStates sync.Map // uint64 -> *streamCallState + streamCallSeq atomic.Uint64 + goNewSegmentCb uintptr // purego.NewCallback(onNewSegment) result; set in main.go at boot +) + +type streamCallState struct { + results chan *pb.TranscriptStreamResponse + diarize bool + // nextIdx tracks how many segments we've already emitted. The C + // trampoline passes idx_first = total - n_new, but we walk from + // nextIdx to (idx_first + n_new) defensively in case whisper.cpp ever + // coalesces multiple commits into a single callback invocation. + nextIdx int + // assembled mirrors the literal concat of every Delta sent on results. + // We reuse it as the final TranscriptResult.Text so the e2e + // invariant `final.Text == concat(deltas)` holds exactly. Written from + // the cgo decode thread inside onNewSegment and read by the streaming + // method after CppTranscribe returns; the cgo boundary provides the + // happens-before edge. 
+ assembled strings.Builder +} + +// onNewSegment is the Go side of the C trampoline declared in +// gowhisper.cpp:new_segment_cb. Whisper.cpp invokes it once per +// new-segment event during whisper_full(). Reads segment text via the +// existing CppGetSegment* getters (safe to call against the singleton +// ctx; whisper.cpp is the only writer and it has already published the +// segments by the time this fires). +// +// Sends deltas synchronously: if the channel is full, this blocks the +// whisper decode thread. That's the intended backpressure path - +// dropping deltas would break the concat(deltas) == final.Text invariant +// the e2e suite asserts. +func onNewSegment(idxFirst int32, nNew int32, userData uintptr) { + v, ok := streamCallStates.Load(uint64(userData)) + if !ok { + return // call already torn down (race with cancel + cb fire) + } + state := v.(*streamCallState) + end := int(idxFirst) + int(nNew) + for i := state.nextIdx; i < end; i++ { + txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "�") + txt = strings.TrimSpace(txt) + if state.diarize && CppGetSegmentSpeakerTurnNext(i) { + txt += " [SPEAKER_TURN]" + } + if txt == "" { + state.nextIdx = i + 1 + continue + } + // Prefix subsequent deltas with a single space so the assembled + // stream reads as one space-joined transcript. The first delta has + // no leading space, otherwise concat(deltas) would not match + // final.Text and the e2e invariant would break. 
+ var delta string + if state.assembled.Len() == 0 { + delta = txt + } else { + delta = " " + txt + } + state.results <- &pb.TranscriptStreamResponse{Delta: delta} + state.assembled.WriteString(delta) + state.nextIdx = i + 1 + } +} + type Whisper struct { base.SingleThread } @@ -200,3 +276,120 @@ func (w *Whisper) AudioTranscription(ctx context.Context, opts *pb.TranscriptReq Duration: duration, }, nil } + +// AudioTranscriptionStream runs whisper_full() and emits deltas via +// whisper.cpp's new_segment_callback as segments are decoded, then a +// final TranscriptResult. The offline AudioTranscription is unchanged; +// both paths share whisper's single-instance ctx and the SingleThread +// concurrency model. +func (w *Whisper) AudioTranscriptionStream(ctx context.Context, opts *pb.TranscriptRequest, results chan *pb.TranscriptStreamResponse) error { + defer close(results) + + if err := ctx.Err(); err != nil { + return status.Error(codes.Canceled, "transcription cancelled") + } + + dir, err := os.MkdirTemp("", "whisper") + if err != nil { + return err + } + defer func() { _ = os.RemoveAll(dir) }() + + convertedPath := filepath.Join(dir, "converted.wav") + if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil { + return err + } + + fh, err := os.Open(convertedPath) + if err != nil { + return err + } + defer func() { _ = fh.Close() }() + + d := wav.NewDecoder(fh) + buf, err := d.FullPCMBuffer() + if err != nil { + return err + } + data := buf.AsFloat32Buffer().Data + var duration float32 + if buf.Format != nil && buf.Format.SampleRate > 0 { + duration = float32(len(data)) / float32(buf.Format.SampleRate) + } + + // Register per-call state and install the C-side callback. defer + // teardown so even a panic clears the C pointer (otherwise a stale + // callback fires on the next AudioTranscription call). 
+ callID := streamCallSeq.Add(1) + state := &streamCallState{ + results: results, + diarize: opts.Diarize, + } + streamCallStates.Store(callID, state) + CppSetNewSegmentCallback(goNewSegmentCb, uintptr(callID)) + defer func() { + CppSetNewSegmentCallback(0, 0) + streamCallStates.Delete(callID) + }() + + // Same abort-watcher pattern as AudioTranscription. Joined synchronously + // so a late CppSetAbort(1) cannot fire after this function returns. + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + CppSetAbort(1) + case <-done: + } + }() + defer func() { + close(done) + wg.Wait() + }() + + segsLen := uintptr(0xdeadbeef) + segsLenPtr := unsafe.Pointer(&segsLen) + ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt) + if ret == 2 { + return status.Error(codes.Canceled, "transcription cancelled") + } + if ret != 0 { + return fmt.Errorf("Failed Transcribe") + } + + // Build the final TranscriptResult. Segments[] mirrors the offline + // path so the SSE done event carries the same per-segment shape. + // final.Text reuses the assembled stream so concat(deltas) == final.Text + // holds exactly, matching the e2e contract. 
+ segments := []*pb.TranscriptSegment{} + for i := range int(segsLen) { + s := CppGetSegmentStart(i) * 10000000 + t := CppGetSegmentEnd(i) * 10000000 + txt := strings.ToValidUTF8(strings.Clone(CppGetSegmentText(i)), "�") + tokens := make([]int32, CppNTokens(i)) + if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) { + txt += " [SPEAKER_TURN]" + } + for j := range tokens { + tokens[j] = int32(CppGetTokenID(i, j)) + } + segments = append(segments, &pb.TranscriptSegment{ + Id: int32(i), + Text: txt, + Start: s, End: t, + Tokens: tokens, + }) + } + + final := &pb.TranscriptResult{ + Segments: segments, + Text: state.assembled.String(), + Language: opts.Language, + Duration: duration, + } + results <- &pb.TranscriptStreamResponse{FinalResult: final} + return nil +} diff --git a/backend/go/whisper/gowhisper_test.go b/backend/go/whisper/gowhisper_test.go index 1b89e6615..9ff53202d 100644 --- a/backend/go/whisper/gowhisper_test.go +++ b/backend/go/whisper/gowhisper_test.go @@ -3,6 +3,7 @@ package main import ( "context" "os" + "strings" "sync" "testing" "time" @@ -52,6 +53,7 @@ func ensureLibLoaded() { purego.RegisterLibFunc(&CppGetTokenID, gosd, "get_token_id") purego.RegisterLibFunc(&CppGetSegmentSpeakerTurnNext, gosd, "get_segment_speaker_turn_next") purego.RegisterLibFunc(&CppSetAbort, gosd, "set_abort") + purego.RegisterLibFunc(&CppSetNewSegmentCallback, gosd, "set_new_segment_callback") }) if libLoadErr != nil { Skip("whisper library not loadable: " + libLoadErr.Error()) @@ -109,4 +111,64 @@ var _ = Describe("Whisper", func() { Expect(res.Text).ToNot(BeEmpty(), "post-cancel transcription returned empty text") }) }) + + Context("AudioTranscriptionStream", func() { + It("emits multiple deltas progressively for a multi-segment clip", func() { + modelPath, audioPath := fixturesOrSkip() + ensureLibLoaded() + + // The streaming method dispatches through the package-level + // goNewSegmentCb. 
main.go normally builds it; in this test + // process main() is never called, so build it here lazily. + // purego.NewCallback returns a stable pointer; calling it once + // per process is correct. + if goNewSegmentCb == 0 { + goNewSegmentCb = purego.NewCallback(onNewSegment) + } + + w := &Whisper{} + Expect(w.Load(&pb.ModelOptions{ModelFile: modelPath})).To(Succeed()) + + results := make(chan *pb.TranscriptStreamResponse, 64) + done := make(chan error, 1) + go func() { + done <- w.AudioTranscriptionStream(context.Background(), &pb.TranscriptRequest{ + Dst: audioPath, + Threads: 4, + Language: "en", + Stream: true, + }, results) + }() + + var deltas []string + var assembled strings.Builder + var finalText string + var finalSegmentCount int + for chunk := range results { + if d := chunk.GetDelta(); d != "" { + deltas = append(deltas, d) + assembled.WriteString(d) + } + if final := chunk.GetFinalResult(); final != nil { + finalText = final.GetText() + finalSegmentCount = len(final.GetSegments()) + } + } + Expect(<-done).ToNot(HaveOccurred()) + + // The whisper-specific bar: real streaming via new_segment_callback + // fires once per decoded segment, so a multi-segment clip MUST + // produce >=2 delta events. A faked-streaming impl (run + // whisper_full to completion, then walk the segment list) would + // also pass len(deltas) >= 1, which is why the generic e2e spec + // is not strict enough. 
+ Expect(len(deltas)).To(BeNumerically(">=", 2), + "expected multiple deltas from a multi-segment clip, got %d (assembled=%q)", + len(deltas), assembled.String()) + Expect(finalSegmentCount).To(BeNumerically(">=", 2), + "expected final to carry multiple segments") + Expect(assembled.String()).To(Equal(finalText), + "concat(deltas) must equal final.Text") + }) + }) }) diff --git a/backend/go/whisper/main.go b/backend/go/whisper/main.go index df35e9792..e48b24519 100644 --- a/backend/go/whisper/main.go +++ b/backend/go/whisper/main.go @@ -42,12 +42,18 @@ func main() { {&CppGetTokenID, "get_token_id"}, {&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"}, {&CppSetAbort, "set_abort"}, + {&CppSetNewSegmentCallback, "set_new_segment_callback"}, } for _, lf := range libFuncs { purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name) } + // Build a stable C-callable function pointer from the Go callback. The + // pointer lives for the lifetime of the process; per-call dispatch is + // keyed by user_data through streamCallStates. + goNewSegmentCb = purego.NewCallback(onNewSegment) + flag.Parse() if err := grpc.StartServer(*addr, &Whisper{}); err != nil {