mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-17 04:56:52 -04:00
* refactor(transcription): propagate request ctx through ModelTranscription*

Replaces context.Background() with the HTTP request ctx so client disconnects start cancelling the gRPC call. No backend-side abort wiring yet; that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

Follow-up to e65d3e1f, which threaded ctx through ModelTranscription but missed the CLI caller. CLI commands have no request-scoped ctx, so context.Background() is correct here.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

Same ctx-plumbing pattern applied to the rest of the audio path. CLI callers use context.Background() since there is no request scope; HTTP callers use c.Request().Context().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

Replaces remaining context.Background() sites in core/backend with the caller's ctx. After this commit, every core/backend/*.go entry point threads the request ctx end-to-end to the gRPC client.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as first parameter to the AIModel interface methods that wrap whisper-style transcription. Server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none uses it yet. The actual cancellation primitive lands in the next commit, so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

Installs a std::atomic<int> flag, wires it into whisper_full_params.abort_callback, and exposes a set_abort(int) C symbol so Go can flip the flag from a goroutine watching the request context. transcribe() now distinguishes abort (return 2) from real whisper_full failure (return 1).

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

Adds the Go-side binding for the new C export so the next commit can call CppSetAbort(1) from a watcher goroutine on ctx.Done().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): honor ctx cancellation and return codes.Canceled

A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute graph step, returns non-zero, and the bridge returns 2 -> AudioTranscription maps that to codes.Canceled.

Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

Follow-up to 85edf9d2. The previous commit used `defer close(done)` and called the watcher "joined synchronously", but close() only signals; it does not block until the goroutine exits. That left a window where a late CppSetAbort(1) from a cancelled call could land on the next call, after its C-side g_abort reset but before whisper_full() began polling the abort callback, corrupting the second transcription.

Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher has actually returned from its select.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

If ctx is already Done() at entry, return codes.Canceled immediately instead of running the full transcription. The C-side g_abort reset happens at the start of transcribe() and would otherwise overwrite a watcher-set abort flag from an already-cancelled ctx, producing a spurious successful transcription on a request the client has already abandoned.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

Phase B (93c48e19) added context.Context to AIModel.AudioTranscription but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint caught it: *testLLM did not implement grpc.AIModel because the method signature lacked the ctx parameter, which broke the distributed test suite compilation and cascaded through every backend-build job that runs `go build ./...`.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

Project policy (.agents/coding-style.md, enforced by golangci-lint forbidigo) is that all Go tests must use Ginkgo v2 + Gomega, with no stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the cancellation test to a Describe/It block with Skip(...) for env gating and Expect/HaveOccurred for assertions. Same coverage: cancel mid-flight returns codes.Canceled within 5s and a follow-up transcription succeeds, proving the C-side g_abort flag resets cleanly.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
182 lines
6.1 KiB
Go
package openai

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"

	"github.com/mudler/xlog"
)

// DiarizationEndpoint runs offline speaker diarization on an uploaded
// audio file and returns "who spoke when". Backends with a pure
// diarization pipeline (sherpa-onnx + pyannote) emit only segmentation;
// backends that produce diarization as a by-product of ASR (vibevoice.cpp)
// can additionally fill in the per-segment transcript when the caller
// passes `include_text=true`.
//
// Response formats follow transcription's: `json` (default, segments only),
// `verbose_json` (adds speaker summary and per-segment text), and `rttm`
// (NIST RTTM, the standard interchange format used by pyannote/dscore).
//
// @Summary Identify speakers in audio (who spoke when).
// @Tags audio
// @accept multipart/form-data
// @Param model formData string true "model"
// @Param file formData file true "audio file"
// @Param num_speakers formData int false "exact speaker count (>0 forces; 0 = auto)"
// @Param min_speakers formData int false "lower bound when auto-detecting"
// @Param max_speakers formData int false "upper bound when auto-detecting"
// @Param clustering_threshold formData number false "clustering distance threshold when num_speakers is unknown"
// @Param min_duration_on formData number false "discard segments shorter than this (seconds)"
// @Param min_duration_off formData number false "merge gaps shorter than this (seconds)"
// @Param language formData string false "audio language hint (only meaningful for backends that bundle ASR)"
// @Param include_text formData boolean false "include per-segment transcript when the backend supports it"
// @Param response_format formData string false "json (default), verbose_json, or rttm"
// @Success 200 {object} schema.DiarizationResult
// @Router /v1/audio/diarization [post]
func DiarizationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		modelConfig, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || modelConfig == nil {
			return echo.ErrBadRequest
		}

		req := backend.DiarizationRequest{
			Language:    input.Language,
			IncludeText: parseFormBool(c, "include_text", false),
		}
		req.NumSpeakers = int32(parseFormInt(c, "num_speakers", 0))
		req.MinSpeakers = int32(parseFormInt(c, "min_speakers", 0))
		req.MaxSpeakers = int32(parseFormInt(c, "max_speakers", 0))
		req.ClusteringThreshold = float32(parseFormFloat(c, "clustering_threshold", 0))
		req.MinDurationOn = float32(parseFormFloat(c, "min_duration_on", 0))
		req.MinDurationOff = float32(parseFormFloat(c, "min_duration_off", 0))

		responseFormat := schema.DiarizationResponseFormatType(strings.ToLower(c.FormValue("response_format")))
		if responseFormat == "" {
			responseFormat = schema.DiarizationResponseFormatJson
		}

		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer func() { _ = f.Close() }()

		dir, err := os.MkdirTemp("", "diarize")
		if err != nil {
			return err
		}
		defer func() { _ = os.RemoveAll(dir) }()

		dst := filepath.Join(dir, path.Base(file.Filename))
		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}
		if _, err := io.Copy(dstFile, f); err != nil {
			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
			_ = dstFile.Close()
			return err
		}
		_ = dstFile.Close()
		req.Audio = dst

		result, err := backend.ModelDiarization(c.Request().Context(), req, ml, *modelConfig, appConfig)
		if err != nil {
			return err
		}

		switch responseFormat {
		case schema.DiarizationResponseFormatRTTM:
			c.Response().Header().Set(echo.HeaderContentType, "text/plain; charset=utf-8")
			return c.String(http.StatusOK, renderRTTM(result, file.Filename))
		case schema.DiarizationResponseFormatJson:
			// Default JSON: drop the heavy per-speaker summary and any
			// optional per-segment text so simple consumers see a tight
			// payload. verbose_json keeps everything.
			result.Speakers = nil
			for i := range result.Segments {
				result.Segments[i].Text = ""
			}
			return c.JSON(http.StatusOK, result)
		case schema.DiarizationResponseFormatJsonVerbose:
			return c.JSON(http.StatusOK, result)
		default:
			return errors.New("invalid response_format (expected: json, verbose_json, rttm)")
		}
	}
}

// renderRTTM emits NIST RTTM rows. Each row:
//
//	SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker> <NA> <NA>
//
// Field separators are spaces; one row per segment.
func renderRTTM(r *schema.DiarizationResult, sourceFile string) string {
	id := strings.TrimSuffix(filepath.Base(sourceFile), filepath.Ext(sourceFile))
	// filepath.Base("") returns "."; treat both as a missing source name and
	// fall back to a stable placeholder so the RTTM row stays parseable.
	if id == "" || id == "." {
		id = "audio"
	}
	var sb strings.Builder
	for _, seg := range r.Segments {
		dur := seg.End - seg.Start
		if dur < 0 {
			dur = 0
		}
		fmt.Fprintf(&sb, "SPEAKER %s 1 %.3f %.3f <NA> <NA> %s <NA> <NA>\n",
			id, seg.Start, dur, seg.Speaker)
	}
	return sb.String()
}

func parseFormInt(c echo.Context, key string, def int) int {
	if v := c.FormValue(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return def
}

func parseFormFloat(c echo.Context, key string, def float64) float64 {
	if v := c.FormValue(key); v != "" {
		if f, err := strconv.ParseFloat(v, 64); err == nil {
			return f
		}
	}
	return def
}

func parseFormBool(c echo.Context, key string, def bool) bool {
	if v := c.FormValue(key); v != "" {
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
	}
	return def
}