LocalAI/core/http/endpoints/openai/transcription.go
LocalAI [bot] 2be07f61da feat(whisper): honor client cancellation via ggml abort_callback (#9710)
* refactor(transcription): propagate request ctx through ModelTranscription*

Replaces context.Background() with the HTTP request ctx so client
disconnects start cancelling the gRPC call. No backend-side abort wiring
yet — that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

Follow-up to e65d3e1f which threaded ctx through ModelTranscription
but missed the CLI caller. CLI commands have no request-scoped ctx,
so context.Background() is correct here.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

Same ctx-plumbing pattern applied to the rest of the audio path. CLI
callers use context.Background() since there is no request scope; HTTP
callers use c.Request().Context().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

Replaces remaining context.Background() sites in core/backend with the
caller's ctx. After this commit, every core/backend/*.go entry point
threads the request ctx end-to-end to the gRPC client.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

Adds context.Context as first parameter to the AIModel interface methods
that wrap whisper-style transcription. Server-side gRPC handler now
forwards the per-RPC ctx (server-streaming uses stream.Context()).
Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter;
none uses it yet — the actual cancellation primitive lands in the next
commit so this is pure plumbing.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

Installs a std::atomic<int> flag, wires it into
whisper_full_params.abort_callback, and exposes a set_abort(int) C
symbol so Go can flip the flag from a goroutine watching the request
context. transcribe() now distinguishes abort (return 2) from real
whisper_full failure (return 1).
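
The real hook lives in the C++ bridge; the following is only a Go analogue of the same pattern, written to make the control flow visible. `abortFlag`, `setAbort`, `transcribe`, and the `onStep` hook are hypothetical names, and the "real failure" path (return 1) is not modelled.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// abortFlag plays the role of the bridge's atomic flag (hypothetical Go analogue).
var abortFlag atomic.Int32

// setAbort plays the role of the exported set_abort(int) symbol.
func setAbort(v int32) { abortFlag.Store(v) }

// transcribe simulates `steps` units of work and polls the flag before each
// one, the way an abort callback is polled between compute steps.
// It returns 0 on success and 2 when aborted.
// onStep is a hypothetical hook that lets the caller act between steps.
func transcribe(steps int, onStep func(i int)) int {
	abortFlag.Store(0) // reset at the start of every call
	for i := 0; i < steps; i++ {
		if abortFlag.Load() != 0 {
			return 2 // aborted
		}
		if onStep != nil {
			onStep(i)
		}
	}
	return 0
}

func main() {
	fmt.Println("clean run:", transcribe(5, nil))
	fmt.Println("aborted run:", transcribe(5, func(i int) {
		if i == 2 {
			setAbort(1) // what a watcher would do on cancellation
		}
	}))
	fmt.Println("next run after reset:", transcribe(5, nil))
}
```

Keeping "aborted" distinct from "failed" is what later lets the Go side translate the result into a cancellation status instead of a generic error.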

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

Adds the Go-side binding for the new C export so the next commit can
call CppSetAbort(1) from a watcher goroutine on ctx.Done().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): honor ctx cancellation and return codes.Canceled

A watcher goroutine waits on ctx.Done() during AudioTranscription and
calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return
true at the next compute graph step and returns non-zero; the bridge
then returns 2, which AudioTranscription maps to codes.Canceled.

Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH)
that asserts cancellation latency under 5s and proves the abort flag
resets cleanly so the next transcription succeeds.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

Follow-up to 85edf9d2. The previous commit used `defer close(done)` and
called the watcher "joined synchronously" — but close() only signals,
it does not block until the goroutine exits. That left a window where
a late CppSetAbort(1) from a cancelled call could land on the next
call, after its C-side g_abort reset but before whisper_full() began
polling the abort callback, corrupting the second transcription.

Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher
has actually returned from its select.

Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

If ctx is already Done() at entry, return codes.Canceled immediately
instead of running the full transcription. The C-side g_abort reset
happens at the start of transcribe() and would otherwise overwrite a
watcher-set abort flag from an already-cancelled ctx, producing a
spurious successful transcription on a request the client has already
abandoned.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

Phase B (93c48e19) added context.Context to AIModel.AudioTranscription
but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint
caught it: *testLLM did not implement grpc.AIModel because the method
signature lacked the ctx parameter, which broke the distributed test
suite compilation and cascaded through every backend-build job that
runs `go build ./...`.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

Project policy (.agents/coding-style.md, enforced by golangci-lint
forbidigo) is that all Go tests must use Ginkgo v2 + Gomega — no
stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the
cancellation test to a Describe/It block with Skip(...) for env
gating and Expect/HaveOccurred for assertions.

Same coverage: cancel mid-flight returns codes.Canceled within 5s and
a follow-up transcription succeeds, proving the C-side g_abort flag
resets cleanly.

Assisted-by: Claude:claude-opus-4-7
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
2026-05-08 01:44:47 +02:00


package openai

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/xlog"
)

// TranscriptEndpoint is the OpenAI Whisper API endpoint https://platform.openai.com/docs/api-reference/audio/create
// @Summary Transcribes audio into the input language.
// @Tags audio
// @accept multipart/form-data
// @Param model formData string true "model"
// @Param file formData file true "file"
// @Param temperature formData number false "sampling temperature"
// @Param timestamp_granularities formData []string false "timestamp granularities (word, segment)"
// @Param stream formData boolean false "stream partial results as SSE"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || config == nil {
			return echo.ErrBadRequest
		}

		diarize := c.FormValue("diarize") != "false"
		prompt := c.FormValue("prompt")
		responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))

		// OpenAI accepts `temperature` as a string in multipart form. Tolerate
		// missing/invalid values rather than failing the whole request.
		var temperature float32
		if v := c.FormValue("temperature"); v != "" {
			if t, err := strconv.ParseFloat(v, 32); err == nil {
				temperature = float32(t)
			}
		}

		// timestamp_granularities[] is a multi-value form field per the OpenAI spec.
		// Echo exposes all values for a key via FormParams.
		var timestampGranularities []string
		if form, err := c.FormParams(); err == nil {
			for _, key := range []string{"timestamp_granularities[]", "timestamp_granularities"} {
				if vals, ok := form[key]; ok {
					for _, v := range vals {
						v = strings.TrimSpace(v)
						if v != "" {
							timestampGranularities = append(timestampGranularities, v)
						}
					}
				}
			}
		}

		stream := false
		if v := c.FormValue("stream"); v != "" {
			if b, err := strconv.ParseBool(v); err == nil {
				stream = b
			}
		}

		// retrieve the file data from the request
		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer f.Close()

		dir, err := os.MkdirTemp("", "whisper")
		if err != nil {
			return err
		}
		defer os.RemoveAll(dir)

		dst := filepath.Join(dir, path.Base(file.Filename))
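		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}

		if _, err := io.Copy(dstFile, f); err != nil {
			dstFile.Close()
			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
			return err
		}
		// Close explicitly instead of leaking the descriptor: the backend is
		// handed the path (req.Audio), not this handle.
		if err := dstFile.Close(); err != nil {
			return err
		}
		xlog.Debug("Audio file copied", "dst", dst)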
		req := backend.TranscriptionRequest{
			Audio:                  dst,
			Language:               input.Language,
			Translate:              input.Translate,
			Diarize:                diarize,
			Prompt:                 prompt,
			Temperature:            temperature,
			TimestampGranularities: timestampGranularities,
		}

		if stream {
			return streamTranscription(c, req, ml, *config, appConfig)
		}

		tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
		if err != nil {
			// Log before returning so the underlying error survives. Echo's
			// error handler turns this into a 500 with a generic body, which
			// otherwise leaves operators chasing a silent failure. See e.g.
			// distributed transcription, where the gRPC error from a remote
			// node is the only signal of what actually went wrong.
			xlog.Error("Transcription failed",
				"model", config.Name,
				"audio", dst,
				"error", err)
			return err
		}

		xlog.Debug("Transcribed", "transcription", tr)

		switch responseFormat {
		case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt:
			return c.String(http.StatusOK, schema.TranscriptionResponse(tr, responseFormat))
		case schema.TranscriptionResponseFormatJson:
			tr.Segments = nil
			tr.Words = nil
			fallthrough
		case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility
			trs := schema.TranscriptionResultSeconds{
				Text:     tr.Text,
				Language: tr.Language,
				Duration: tr.Duration,
				Words:    []schema.TranscriptionWordSeconds{},
				Segments: []schema.TranscriptionSegmentSeconds{},
			}
			for _, word := range tr.Words {
				trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
					Start: word.Start.Seconds(),
					End:   word.End.Seconds(),
					Text:  word.Text,
				})
			}
			for _, seg := range tr.Segments {
				segWords := []schema.TranscriptionWordSeconds{}
				for _, word := range seg.Words {
					segWords = append(segWords, schema.TranscriptionWordSeconds{
						Start: word.Start.Seconds(),
						End:   word.End.Seconds(),
						Text:  word.Text,
					})
				}
				trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
					Id:      seg.Id,
					Start:   seg.Start.Seconds(),
					End:     seg.End.Seconds(),
					Text:    seg.Text,
					Tokens:  seg.Tokens,
					Speaker: seg.Speaker,
					Words:   segWords,
				})
			}
			return c.JSON(http.StatusOK, trs)
		default:
			return errors.New("invalid response_format")
		}
	}
}

// streamTranscription emits OpenAI-format SSE events for a transcription
// request: one `transcript.text.delta` per backend chunk, a final
// `transcript.text.done` with the assembled text, and `[DONE]`. Backends that
// can't truly stream still produce a single Final event, which we surface as
// one delta + done.
func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *model.ModelLoader, config config.ModelConfig, appConfig *config.ApplicationConfig) error {
	c.Response().Header().Set("Content-Type", "text/event-stream")
	c.Response().Header().Set("Cache-Control", "no-cache")
	c.Response().Header().Set("Connection", "keep-alive")
	c.Response().WriteHeader(http.StatusOK)

	writeEvent := func(payload any) error {
		data, err := json.Marshal(payload)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(c.Response().Writer, "data: %s\n\n", data); err != nil {
			return err
		}
		c.Response().Flush()
		return nil
	}

	var assembled strings.Builder
	var finalResult *schema.TranscriptionResult

	err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
		if chunk.Delta != "" {
			assembled.WriteString(chunk.Delta)
			_ = writeEvent(map[string]any{
				"type":  "transcript.text.delta",
				"delta": chunk.Delta,
			})
		}
		if chunk.Final != nil {
			finalResult = chunk.Final
		}
	})
	if err != nil {
		errPayload := map[string]any{
			"type": "error",
			"error": map[string]any{
				"message": err.Error(),
				"type":    "server_error",
			},
		}
		_ = writeEvent(errPayload)
		_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
		c.Response().Flush()
		return nil
	}

	// Build the final event. Prefer the backend-provided final result; if the
	// backend only emitted deltas, synthesize the result from what we collected.
	if finalResult == nil {
		finalResult = &schema.TranscriptionResult{Text: assembled.String()}
	} else if finalResult.Text == "" && assembled.Len() > 0 {
		finalResult.Text = assembled.String()
	}

	// If the backend never produced a delta but did return a final text, emit
	// it as a single delta so clients always see at least one delta event.
	if assembled.Len() == 0 && finalResult.Text != "" {
		_ = writeEvent(map[string]any{
			"type":  "transcript.text.delta",
			"delta": finalResult.Text,
		})
	}

	// done carries the assembled text plus, when the backend produced them,
	// per-segment timings, audio duration, and detected language. The OpenAI
	// streaming spec only specifies `text`; the extra fields are an additive
	// extension so streaming clients (e.g. notetaker) can build the same
	// TranscriptionResultSeconds shape they get from the JSON response path
	// without us forcing them off SSE just to recover segments. Spec-compliant
	// clients ignore unknown fields.
	doneEvent := map[string]any{
		"type": "transcript.text.done",
		"text": finalResult.Text,
	}
	if finalResult.Language != "" {
		doneEvent["language"] = finalResult.Language
	}
	if finalResult.Duration > 0 {
		doneEvent["duration"] = finalResult.Duration
	}
	if len(finalResult.Segments) > 0 {
		segs := make([]map[string]any, 0, len(finalResult.Segments))
		for _, seg := range finalResult.Segments {
			segs = append(segs, map[string]any{
				"id":    seg.Id,
				"start": seg.Start.Seconds(),
				"end":   seg.End.Seconds(),
				"text":  seg.Text,
			})
		}
		doneEvent["segments"] = segs
	}

	_ = writeEvent(doneEvent)
	_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
	c.Response().Flush()
	return nil
}