LocalAI/core/http/endpoints/openai/diarization.go
LocalAI [bot] 2be07f61da feat(whisper): honor client cancellation via ggml abort_callback (#9710)
* refactor(transcription): propagate request ctx through ModelTranscription*

Replaces context.Background() with the HTTP request ctx so client
disconnects start cancelling the gRPC call. No backend-side abort wiring
yet — that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

Follow-up to e65d3e1f which threaded ctx through ModelTranscription
but missed the CLI caller. CLI commands have no request-scoped ctx,
so context.Background() is correct here.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

Same ctx-plumbing pattern applied to the rest of the audio path. CLI
callers use context.Background() since there is no request scope; HTTP
callers use c.Request().Context().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

Replaces remaining context.Background() sites in core/backend with the
caller's ctx. After this commit, every core/backend/*.go entry point
threads the request ctx end-to-end to the gRPC client.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as first parameter to the AIModel interface methods
that wrap whisper-style transcription. Server-side gRPC handler now
forwards the per-RPC ctx (server-streaming uses stream.Context()).
Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter;
none of them uses it yet. The actual cancellation primitive lands in the
next commit, so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

Installs a std::atomic<int> flag, wires it into
whisper_full_params.abort_callback, and exposes a set_abort(int) C
symbol so Go can flip the flag from a goroutine watching the request
context. transcribe() now distinguishes abort (return 2) from real
whisper_full failure (return 1).

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

Adds the Go-side binding for the new C export so the next commit can
call CppSetAbort(1) from a watcher goroutine on ctx.Done().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
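A minimal sketch of the Go side of this binding, with the purego call shown only as a comment (the `libHandle` name is a placeholder) and a stub substituted so the wiring is visible without loading a shared library:

```go
package main

import "fmt"

// CppSetAbort is the Go-side binding for the new C export. In the real
// loader it is filled in by purego, roughly:
//
//	purego.RegisterLibFunc(&CppSetAbort, libHandle, "set_abort")
//
// Here a stub stands in for the C symbol.
var CppSetAbort func(flag int32)

func main() {
	var flag int32
	CppSetAbort = func(v int32) { flag = v } // stub for the C set_abort
	CppSetAbort(1)                           // what the ctx watcher will call
	fmt.Println(flag)                        // 1
}
```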

* feat(whisper): honor ctx cancellation and return codes.Canceled

A watcher goroutine selects on ctx.Done() during AudioTranscription and
calls CppSetAbort(1) on cancellation. whisper_full sees abort_callback
return true at the next compute-graph step and returns non-zero; the
bridge turns that into return code 2, which AudioTranscription maps to
codes.Canceled.

Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH)
that asserts cancellation latency under 5s and proves the abort flag
resets cleanly so the next transcription succeeds.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

Follow-up to 85edf9d2. The previous commit used `defer close(done)` and
called the watcher "joined synchronously" — but close() only signals,
it does not block until the goroutine exits. That left a window where
a late CppSetAbort(1) from a cancelled call could land on the next
call, after its C-side g_abort reset but before whisper_full() began
polling the abort callback, corrupting the second transcription.

Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher
has actually returned from its select.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

If ctx is already Done() at entry, return codes.Canceled immediately
instead of running the full transcription. The C-side g_abort reset
happens at the start of transcribe() and would otherwise overwrite a
watcher-set abort flag from an already-cancelled ctx, producing a
spurious successful transcription on a request the client has already
abandoned.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

Phase B (93c48e19) added context.Context to AIModel.AudioTranscription
but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint
caught it: *testLLM did not implement grpc.AIModel because the method
signature lacked the ctx parameter, which broke the distributed test
suite compilation and cascaded through every backend-build job that
runs `go build ./...`.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

Project policy (.agents/coding-style.md, enforced by golangci-lint
forbidigo) is that all Go tests must use Ginkgo v2 + Gomega — no
stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the
cancellation test to a Describe/It block with Skip(...) for env
gating and Expect/HaveOccurred for assertions.

Same coverage: cancel mid-flight returns codes.Canceled within 5s and
a follow-up transcription succeeds, proving the C-side g_abort flag
resets cleanly.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-08 01:44:47 +02:00


package openai

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/xlog"
)

// DiarizationEndpoint runs offline speaker diarization on an uploaded
// audio file and returns "who spoke when". Backends with a pure
// diarization pipeline (sherpa-onnx + pyannote) emit only segmentation;
// backends that produce diarization as a by-product of ASR (vibevoice.cpp)
// can additionally fill in the per-segment transcript when the caller
// passes `include_text=true`.
//
// Response formats follow transcription's: `json` (default, segments only),
// `verbose_json` (adds speaker summary and per-segment text), and `rttm`
// (NIST RTTM, the standard interchange format used by pyannote/dscore).
//
// @Summary Identify speakers in audio (who spoke when).
// @Tags audio
// @accept multipart/form-data
// @Param model formData string true "model"
// @Param file formData file true "audio file"
// @Param num_speakers formData int false "exact speaker count (>0 forces; 0 = auto)"
// @Param min_speakers formData int false "lower bound when auto-detecting"
// @Param max_speakers formData int false "upper bound when auto-detecting"
// @Param clustering_threshold formData number false "clustering distance threshold when num_speakers is unknown"
// @Param min_duration_on formData number false "discard segments shorter than this (seconds)"
// @Param min_duration_off formData number false "merge gaps shorter than this (seconds)"
// @Param language formData string false "audio language hint (only meaningful for backends that bundle ASR)"
// @Param include_text formData boolean false "include per-segment transcript when the backend supports it"
// @Param response_format formData string false "json (default), verbose_json, or rttm"
// @Success 200 {object} schema.DiarizationResult
// @Router /v1/audio/diarization [post]
func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		modelConfig, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || modelConfig == nil {
			return echo.ErrBadRequest
		}

		req := backend.DiarizationRequest{
			Language:    input.Language,
			IncludeText: parseFormBool(c, "include_text", false),
		}
		req.NumSpeakers = int32(parseFormInt(c, "num_speakers", 0))
		req.MinSpeakers = int32(parseFormInt(c, "min_speakers", 0))
		req.MaxSpeakers = int32(parseFormInt(c, "max_speakers", 0))
		req.ClusteringThreshold = float32(parseFormFloat(c, "clustering_threshold", 0))
		req.MinDurationOn = float32(parseFormFloat(c, "min_duration_on", 0))
		req.MinDurationOff = float32(parseFormFloat(c, "min_duration_off", 0))

		responseFormat := schema.DiarizationResponseFormatType(strings.ToLower(c.FormValue("response_format")))
		if responseFormat == "" {
			responseFormat = schema.DiarizationResponseFormatJson
		}

		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer func() { _ = f.Close() }()

		dir, err := os.MkdirTemp("", "diarize")
		if err != nil {
			return err
		}
		defer func() { _ = os.RemoveAll(dir) }()

		dst := filepath.Join(dir, path.Base(file.Filename))
		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}
		if _, err := io.Copy(dstFile, f); err != nil {
			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
			_ = dstFile.Close()
			return err
		}
		_ = dstFile.Close()
		req.Audio = dst

		result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig)
		if err != nil {
			return err
		}

		switch responseFormat {
		case schema.DiarizationResponseFormatRTTM:
			c.Response().Header().Set(echo.HeaderContentType, "text/plain; charset=utf-8")
			return c.String(http.StatusOK, renderRTTM(result, file.Filename))
		case schema.DiarizationResponseFormatJson:
			// Default JSON: drop the heavy per-speaker summary and any
			// optional per-segment text so simple consumers see a tight
			// payload. verbose_json keeps everything.
			result.Speakers = nil
			for i := range result.Segments {
				result.Segments[i].Text = ""
			}
			return c.JSON(http.StatusOK, result)
		case schema.DiarizationResponseFormatJsonVerbose:
			return c.JSON(http.StatusOK, result)
		default:
			return errors.New("invalid response_format (expected: json, verbose_json, rttm)")
		}
	}
}

// renderRTTM emits NIST RTTM rows. Each row:
//
//	SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>
//
// Field separators are spaces; one row per segment.
func renderRTTM(r *schema.DiarizationResult, sourceFile string) string {
	id := strings.TrimSuffix(filepath.Base(sourceFile), filepath.Ext(sourceFile))
	// filepath.Base("") returns "." — treat both as a missing source name and
	// fall back to a stable placeholder so the RTTM row stays parseable.
	if id == "" || id == "." {
		id = "audio"
	}
	var sb strings.Builder
	for _, seg := range r.Segments {
		dur := seg.End - seg.Start
		if dur < 0 {
			dur = 0
		}
		fmt.Fprintf(&sb, "SPEAKER %s 1 %.3f %.3f <NA> <NA> %s <NA> <NA>\n",
			id, seg.Start, dur, seg.Speaker)
	}
	return sb.String()
}

func parseFormInt(c echo.Context, key string, def int) int {
	if v := c.FormValue(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return def
}

func parseFormFloat(c echo.Context, key string, def float64) float64 {
	if v := c.FormValue(key); v != "" {
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f
		}
	}
	return def
}

func parseFormBool(c echo.Context, key string, def bool) bool {
	if v := c.FormValue(key); v != "" {
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
	}
	return def
}