refactor(transcription): propagate request ctx through ModelTranscription*

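Replaces context.Background() with the caller's ctx — the HTTP request
ctx in TranscriptEndpoint and streamTranscription, and the ctx already
passed to the Transcribe wrappers — so that cancellation (e.g. a client
disconnect) reaches the in-flight gRPC call. No backend-side abort
wiring yet — that comes in a later commit. Pure plumbing, apart from
gofmt-normalizing three `range(...)` loops in TranscriptEndpoint.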

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 15:35:48 +00:00
parent 3234e6d6ba
commit e65d3e1fd0
3 changed files with 16 additions and 16 deletions

View File

@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
return transcriptionModel, nil return transcriptionModel, nil
} }
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
return ModelTranscriptionWithOptions(TranscriptionRequest{ return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
Audio: audio, Audio: audio,
Language: language, Language: language,
Translate: translate, Translate: translate,
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
}, ml, modelConfig, appConfig) }, ml, modelConfig, appConfig)
} }
func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
audioSnippet = trace.AudioSnippet(req.Audio) audioSnippet = trace.AudioSnippet(req.Audio)
} }
r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads))) r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
if err != nil { if err != nil {
if appConfig.EnableTracing { if appConfig.EnableTracing {
errData := map[string]any{ errData := map[string]any{
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
// invokes onChunk for each event the backend produces. Backends that don't // invokes onChunk for each event the backend produces. Backends that don't
// support real streaming should still emit one terminal event with Final set, // support real streaming should still emit one terminal event with Final set,
// which the HTTP layer turns into a single delta + done SSE pair. // which the HTTP layer turns into a single delta + done SSE pair.
func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error { func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil { if err != nil {
return err return err
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
pbReq := req.toProto(uint32(*modelConfig.Threads)) pbReq := req.toProto(uint32(*modelConfig.Threads))
pbReq.Stream = true pbReq.Stream = true
return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) { return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
if chunk == nil { if chunk == nil {
return return
} }

View File

@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
} }
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
} }
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
} }
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
} }
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {

View File

@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
return streamTranscription(c, req, ml, *config, appConfig) return streamTranscription(c, req, ml, *config, appConfig)
} }
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig) tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
if err != nil { if err != nil {
// Log before returning so the underlying error survives. Echo's // Log before returning so the underlying error survives. Echo's
// error handler turns this into a 500 with a generic body, which // error handler turns this into a 500 with a generic body, which
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
Words: []schema.TranscriptionWordSeconds{}, Words: []schema.TranscriptionWordSeconds{},
Segments: []schema.TranscriptionSegmentSeconds{}, Segments: []schema.TranscriptionSegmentSeconds{},
} }
for _, word := range(tr.Words) { for _, word := range tr.Words {
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{ trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
Start: word.Start.Seconds(), Start: word.Start.Seconds(),
End: word.End.Seconds(), End: word.End.Seconds(),
Text: word.Text, Text: word.Text,
}) })
} }
for _, seg := range(tr.Segments) { for _, seg := range tr.Segments {
segWords := []schema.TranscriptionWordSeconds{} segWords := []schema.TranscriptionWordSeconds{}
for _, word := range(seg.Words) { for _, word := range seg.Words {
segWords = append(segWords, schema.TranscriptionWordSeconds{ segWords = append(segWords, schema.TranscriptionWordSeconds{
Start: word.Start.Seconds(), Start: word.Start.Seconds(),
End: word.End.Seconds(), End: word.End.Seconds(),
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
var assembled strings.Builder var assembled strings.Builder
var finalResult *schema.TranscriptionResult var finalResult *schema.TranscriptionResult
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) { err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
if chunk.Delta != "" { if chunk.Delta != "" {
assembled.WriteString(chunk.Delta) assembled.WriteString(chunk.Delta)
_ = writeEvent(map[string]any{ _ = writeEvent(map[string]any{