mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-14 19:58:44 -04:00
* refactor(transcription): propagate request ctx through ModelTranscription*

  Replaces context.Background() with the HTTP request ctx so client disconnects start cancelling the gRPC call. No backend-side abort wiring yet; that comes in a later commit. Pure plumbing.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cli): pass ctx to backend.ModelTranscription

  Follow-up to e65d3e1f, which threaded ctx through ModelTranscription but missed the CLI caller. CLI commands have no request-scoped ctx, so context.Background() is correct here.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

  Same ctx-plumbing pattern applied to the rest of the audio path. CLI callers use context.Background() since there is no request scope; HTTP callers use c.Request().Context().

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(backend): propagate request ctx into biometric, detection, rerank, diarization paths

  Replaces remaining context.Background() sites in core/backend with the caller's ctx. After this commit, every core/backend/*.go entry point threads the request ctx end-to-end to the gRPC client.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactor(grpc): plumb ctx through AIModel.AudioTranscription{,Stream}

  Adds context.Context as first parameter to the AIModel interface methods that wrap whisper-style transcription. Server-side gRPC handler now forwards the per-RPC ctx (server-streaming uses stream.Context()). Whisper, Voxtral, vibevoice-cpp, and sherpa-onnx accept the parameter; none uses it yet. The actual cancellation primitive lands in the next commit, so this is pure plumbing.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): add abort_callback hook in the C++ bridge

  Installs a std::atomic<int> flag, wires it into whisper_full_params.abort_callback, and exposes a set_abort(int) C symbol so Go can flip the flag from a goroutine watching the request context. transcribe() now distinguishes abort (return 2) from real whisper_full failure (return 1).

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): register set_abort symbol in the purego loader

  Adds the Go-side binding for the new C export so the next commit can call CppSetAbort(1) from a watcher goroutine on ctx.Done().

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(whisper): honor ctx cancellation and return codes.Canceled

  A watcher goroutine watches ctx.Done() during AudioTranscription and calls CppSetAbort(1) on cancel. whisper_full sees abort_callback return true at the next compute graph step, returns non-zero, and the bridge returns 2 -> AudioTranscription maps that to codes.Canceled.

  Adds an opt-in test (gated on WHISPER_MODEL_PATH / WHISPER_AUDIO_PATH) that asserts cancellation latency under 5s and proves the abort flag resets cleanly so the next transcription succeeds.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): join the cancel watcher goroutine before returning

  Follow-up to 85edf9d2. The previous commit used `defer close(done)` and called the watcher "joined synchronously", but close() only signals; it does not block until the goroutine exits. That left a window where a late CppSetAbort(1) from a cancelled call could land on the next call, after its C-side g_abort reset but before whisper_full() began polling the abort callback, corrupting the second transcription.

  Switch to a sync.WaitGroup join so wg.Wait() blocks until the watcher has actually returned from its select.

  Assisted-by: Claude:claude-sonnet-4-6
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(whisper): short-circuit pre-cancelled ctx in AudioTranscription

  If ctx is already Done() at entry, return codes.Canceled immediately instead of running the full transcription. The C-side g_abort reset happens at the start of transcribe() and would otherwise overwrite a watcher-set abort flag from an already-cancelled ctx, producing a spurious successful transcription on a request the client has already abandoned.

  Assisted-by: Claude:claude-haiku-4-5
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(tests/distributed): update testLLM mock for new AudioTranscription signature

  Phase B (93c48e19) added context.Context to AIModel.AudioTranscription but missed the testLLM mock in tests/e2e/distributed. CI golangci-lint caught it: *testLLM did not implement grpc.AIModel because the method signature lacked the ctx parameter, which broke the distributed test suite compilation and cascaded through every backend-build job that runs `go build ./...`.

  Assisted-by: Claude:claude-opus-4-7
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* test(whisper): port cancellation test to Ginkgo/Gomega

  Project policy (.agents/coding-style.md, enforced by golangci-lint forbidigo) is that all Go tests must use Ginkgo v2 + Gomega, with no stdlib testing patterns (t.Skip, t.Fatalf, etc.). Convert the cancellation test to a Describe/It block with Skip(...) for env gating and Expect/HaveOccurred for assertions. Same coverage: cancel mid-flight returns codes.Canceled within 5s and a follow-up transcription succeeds, proving the C-side g_abort flag resets cleanly.

  Assisted-by: Claude:claude-opus-4-7
  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Co-authored-by: Ettore Di Giacinto <mudler@localai.io>
294 lines
9.3 KiB
Go
package openai

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"

	"github.com/mudler/xlog"
)

// TranscriptEndpoint is the OpenAI Whisper API endpoint https://platform.openai.com/docs/api-reference/audio/create
// @Summary Transcribes audio into the input language.
// @Tags audio
// @accept multipart/form-data
// @Param model formData string true "model"
// @Param file formData file true "file"
// @Param temperature formData number false "sampling temperature"
// @Param timestamp_granularities formData []string false "timestamp granularities (word, segment)"
// @Param stream formData boolean false "stream partial results as SSE"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return echo.ErrBadRequest
		}

		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || config == nil {
			return echo.ErrBadRequest
		}

		diarize := c.FormValue("diarize") != "false"
		prompt := c.FormValue("prompt")
		responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))

		// OpenAI accepts `temperature` as a string in multipart form. Tolerate
		// missing/invalid values rather than failing the whole request.
		var temperature float32
		if v := c.FormValue("temperature"); v != "" {
			if t, err := strconv.ParseFloat(v, 32); err == nil {
				temperature = float32(t)
			}
		}

		// timestamp_granularities[] is a multi-value form field per the OpenAI spec.
		// Echo exposes all values for a key via FormParams.
		var timestampGranularities []string
		if form, err := c.FormParams(); err == nil {
			for _, key := range []string{"timestamp_granularities[]", "timestamp_granularities"} {
				if vals, ok := form[key]; ok {
					for _, v := range vals {
						v = strings.TrimSpace(v)
						if v != "" {
							timestampGranularities = append(timestampGranularities, v)
						}
					}
				}
			}
		}

		stream := false
		if v := c.FormValue("stream"); v != "" {
			if b, err := strconv.ParseBool(v); err == nil {
				stream = b
			}
		}

		// Retrieve the uploaded file from the request.
		file, err := c.FormFile("file")
		if err != nil {
			return err
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		defer f.Close()

		// Copy the upload into a per-request temp directory that is removed
		// when the handler returns.
		dir, err := os.MkdirTemp("", "whisper")
		if err != nil {
			return err
		}
		defer os.RemoveAll(dir)

		dst := filepath.Join(dir, path.Base(file.Filename))
		dstFile, err := os.Create(dst)
		if err != nil {
			return err
		}
		// Close the temp file when the handler returns so the descriptor is not leaked.
		defer dstFile.Close()

		if _, err := io.Copy(dstFile, f); err != nil {
			xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err)
			return err
		}

		xlog.Debug("Audio file copied", "dst", dst)

		req := backend.TranscriptionRequest{
			Audio:                  dst,
			Language:               input.Language,
			Translate:              input.Translate,
			Diarize:                diarize,
			Prompt:                 prompt,
			Temperature:            temperature,
			TimestampGranularities: timestampGranularities,
		}

		if stream {
			return streamTranscription(c, req, ml, *config, appConfig)
		}

		tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
		if err != nil {
			// Log before returning so the underlying error survives. Echo's
			// error handler turns this into a 500 with a generic body, which
			// otherwise leaves operators chasing a silent failure; see e.g.
			// distributed transcription, where the gRPC error from a remote
			// node is the only signal of what actually went wrong.
			xlog.Error("Transcription failed",
				"model", config.Name,
				"audio", dst,
				"error", err)
			return err
		}

		xlog.Debug("Transcribed", "transcription", tr)

		switch responseFormat {
		case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt:
			return c.String(http.StatusOK, schema.TranscriptionResponse(tr, responseFormat))
		case schema.TranscriptionResponseFormatJson:
			tr.Segments = nil
			tr.Words = nil
			fallthrough
		case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility
			trs := schema.TranscriptionResultSeconds{
				Text:     tr.Text,
				Language: tr.Language,
				Duration: tr.Duration,
				Words:    []schema.TranscriptionWordSeconds{},
				Segments: []schema.TranscriptionSegmentSeconds{},
			}
			for _, word := range tr.Words {
				trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
					Start: word.Start.Seconds(),
					End:   word.End.Seconds(),
					Text:  word.Text,
				})
			}
			for _, seg := range tr.Segments {
				segWords := []schema.TranscriptionWordSeconds{}
				for _, word := range seg.Words {
					segWords = append(segWords, schema.TranscriptionWordSeconds{
						Start: word.Start.Seconds(),
						End:   word.End.Seconds(),
						Text:  word.Text,
					})
				}
				trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
					Id:      seg.Id,
					Start:   seg.Start.Seconds(),
					End:     seg.End.Seconds(),
					Text:    seg.Text,
					Tokens:  seg.Tokens,
					Speaker: seg.Speaker,
					Words:   segWords,
				})
			}
			return c.JSON(http.StatusOK, trs)
		default:
			return errors.New("invalid response_format")
		}
	}
}

// streamTranscription emits OpenAI-format SSE events for a transcription
// request: one `transcript.text.delta` per backend chunk, a final
// `transcript.text.done` with the assembled text, and `[DONE]`. Backends that
// can't truly stream still produce a single Final event, which we surface as
// one delta + done.
func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *model.ModelLoader, config config.ModelConfig, appConfig *config.ApplicationConfig) error {
	c.Response().Header().Set("Content-Type", "text/event-stream")
	c.Response().Header().Set("Cache-Control", "no-cache")
	c.Response().Header().Set("Connection", "keep-alive")
	c.Response().WriteHeader(http.StatusOK)

	writeEvent := func(payload any) error {
		data, err := json.Marshal(payload)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(c.Response().Writer, "data: %s\n\n", data); err != nil {
			return err
		}
		c.Response().Flush()
		return nil
	}

	var assembled strings.Builder
	var finalResult *schema.TranscriptionResult

	err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
		if chunk.Delta != "" {
			assembled.WriteString(chunk.Delta)
			_ = writeEvent(map[string]any{
				"type":  "transcript.text.delta",
				"delta": chunk.Delta,
			})
		}
		if chunk.Final != nil {
			finalResult = chunk.Final
		}
	})
	if err != nil {
		errPayload := map[string]any{
			"type": "error",
			"error": map[string]any{
				"message": err.Error(),
				"type":    "server_error",
			},
		}
		_ = writeEvent(errPayload)
		_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
		c.Response().Flush()
		return nil
	}

	// Build the final event. Prefer the backend-provided final result; if the
	// backend only emitted deltas, synthesize the result from what we collected.
	if finalResult == nil {
		finalResult = &schema.TranscriptionResult{Text: assembled.String()}
	} else if finalResult.Text == "" && assembled.Len() > 0 {
		finalResult.Text = assembled.String()
	}
	// If the backend never produced a delta but did return a final text, emit
	// it as a single delta so clients always see at least one delta event.
	if assembled.Len() == 0 && finalResult.Text != "" {
		_ = writeEvent(map[string]any{
			"type":  "transcript.text.delta",
			"delta": finalResult.Text,
		})
	}
	// done carries the assembled text plus, when the backend produced them,
	// per-segment timings, audio duration, and detected language. The OpenAI
	// streaming spec only specifies `text`; the extra fields are an additive
	// extension so streaming clients (e.g. notetaker) can build the same
	// TranscriptionResultSeconds shape they get from the JSON response path
	// without us forcing them off SSE just to recover segments. Spec-compliant
	// clients ignore unknown fields.
	doneEvent := map[string]any{
		"type": "transcript.text.done",
		"text": finalResult.Text,
	}
	if finalResult.Language != "" {
		doneEvent["language"] = finalResult.Language
	}
	if finalResult.Duration > 0 {
		doneEvent["duration"] = finalResult.Duration
	}
	if len(finalResult.Segments) > 0 {
		segs := make([]map[string]any, 0, len(finalResult.Segments))
		for _, seg := range finalResult.Segments {
			segs = append(segs, map[string]any{
				"id":    seg.Id,
				"start": seg.Start.Seconds(),
				"end":   seg.End.Seconds(),
				"text":  seg.Text,
			})
		}
		doneEvent["segments"] = segs
	}
	_ = writeEvent(doneEvent)
	_, _ = fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
	c.Response().Flush()
	return nil
}