mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-22 23:58:25 -04:00
feat(whisper): add abort_callback hook in the C++ bridge
Installs a std::atomic<int> flag, wires it into whisper_full_params.abort_callback, and exposes a set_abort(int) C symbol so Go can flip the flag from a goroutine watching the request context. transcribe() now distinguishes abort (return 2) from real whisper_full failure (return 1). Assisted-by: Claude:claude-haiku-4-5 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -1,12 +1,23 @@
|
||||
#include "gowhisper.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "whisper.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
static struct whisper_vad_context *vctx;
|
||||
static struct whisper_context *ctx;
|
||||
static std::vector<float> flat_segs;
|
||||
|
||||
static std::atomic<int> g_abort{0};
|
||||
|
||||
static bool abort_cb(void * /*user_data*/) {
|
||||
return g_abort.load(std::memory_order_relaxed) != 0;
|
||||
}
|
||||
|
||||
extern "C" void set_abort(int v) {
|
||||
g_abort.store(v, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||
void *data) {
|
||||
const char *level_str;
|
||||
@@ -124,10 +135,20 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||
wparams.tdrz_enable = tdrz;
|
||||
wparams.initial_prompt = prompt;
|
||||
|
||||
// Reset stale abort flag from any prior cancelled call, then install the
|
||||
// ggml abort hook so a subsequent set_abort(1) from Go aborts the next
|
||||
// compute graph step.
|
||||
g_abort.store(0, std::memory_order_relaxed);
|
||||
wparams.abort_callback = abort_cb;
|
||||
wparams.abort_callback_user_data = nullptr;
|
||||
|
||||
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
|
||||
fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
|
||||
if (g_abort.load(std::memory_order_relaxed)) {
|
||||
return 2; // aborted by client
|
||||
}
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -15,4 +15,5 @@ int64_t get_segment_t1(int i);
|
||||
int n_tokens(int i);
|
||||
int32_t get_token_id(int i, int j);
|
||||
bool get_segment_speaker_turn_next(int i);
|
||||
void set_abort(int v);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user