refactor(transcription): propagate request ctx through ModelTranscription*

Replaces context.Background() with the HTTP request ctx so client
disconnects start cancelling the gRPC call. No backend-side abort wiring
yet — that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 15:35:48 +00:00
parent 3234e6d6ba
commit e65d3e1fd0
3 changed files with 16 additions and 16 deletions

View File

@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
return transcriptionModel, nil return transcriptionModel, nil
} }
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
return ModelTranscriptionWithOptions(TranscriptionRequest{ return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
Audio: audio, Audio: audio,
Language: language, Language: language,
Translate: translate, Translate: translate,
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
}, ml, modelConfig, appConfig) }, ml, modelConfig, appConfig)
} }
func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
audioSnippet = trace.AudioSnippet(req.Audio) audioSnippet = trace.AudioSnippet(req.Audio)
} }
r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads))) r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
if err != nil { if err != nil {
if appConfig.EnableTracing { if appConfig.EnableTracing {
errData := map[string]any{ errData := map[string]any{
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
// invokes onChunk for each event the backend produces. Backends that don't // invokes onChunk for each event the backend produces. Backends that don't
// support real streaming should still emit one terminal event with Final set, // support real streaming should still emit one terminal event with Final set,
// which the HTTP layer turns into a single delta + done SSE pair. // which the HTTP layer turns into a single delta + done SSE pair.
func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error { func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig) transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
if err != nil { if err != nil {
return err return err
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
pbReq := req.toProto(uint32(*modelConfig.Threads)) pbReq := req.toProto(uint32(*modelConfig.Threads))
pbReq.Stream = true pbReq.Stream = true
return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) { return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
if chunk == nil { if chunk == nil {
return return
} }
@@ -187,12 +187,12 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
} }
var words []schema.TranscriptionWord var words []schema.TranscriptionWord
for _, w := range s.Words { for _, w := range s.Words {
var word = schema.TranscriptionWord { var word = schema.TranscriptionWord{
Start: time.Duration(w.Start), Start: time.Duration(w.Start),
End: time.Duration(w.End), End: time.Duration(w.End),
Text: w.Text, Text: w.Text,
} }
words = append(words, word) words = append(words, word)
tr.Words = append(tr.Words, word) tr.Words = append(tr.Words, word)
} }
tr.Segments = append(tr.Segments, tr.Segments = append(tr.Segments,

View File

@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
} }
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
} }
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
} }
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
} }
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {

View File

@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
return streamTranscription(c, req, ml, *config, appConfig) return streamTranscription(c, req, ml, *config, appConfig)
} }
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig) tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
if err != nil { if err != nil {
// Log before returning so the underlying error survives. Echo's // Log before returning so the underlying error survives. Echo's
// error handler turns this into a 500 with a generic body, which // error handler turns this into a 500 with a generic body, which
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
Words: []schema.TranscriptionWordSeconds{}, Words: []schema.TranscriptionWordSeconds{},
Segments: []schema.TranscriptionSegmentSeconds{}, Segments: []schema.TranscriptionSegmentSeconds{},
} }
for _, word := range(tr.Words) { for _, word := range tr.Words {
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{ trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
Start: word.Start.Seconds(), Start: word.Start.Seconds(),
End: word.End.Seconds(), End: word.End.Seconds(),
Text: word.Text, Text: word.Text,
}) })
} }
for _, seg := range(tr.Segments) { for _, seg := range tr.Segments {
segWords := []schema.TranscriptionWordSeconds{} segWords := []schema.TranscriptionWordSeconds{}
for _, word := range(seg.Words) { for _, word := range seg.Words {
segWords = append(segWords, schema.TranscriptionWordSeconds{ segWords = append(segWords, schema.TranscriptionWordSeconds{
Start: word.Start.Seconds(), Start: word.Start.Seconds(),
End: word.End.Seconds(), End: word.End.Seconds(),
@@ -174,7 +174,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
}) })
} }
trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{ trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
Id: seg.Id, Id: seg.Id,
Start: seg.Start.Seconds(), Start: seg.Start.Seconds(),
End: seg.End.Seconds(), End: seg.End.Seconds(),
Text: seg.Text, Text: seg.Text,
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
var assembled strings.Builder var assembled strings.Builder
var finalResult *schema.TranscriptionResult var finalResult *schema.TranscriptionResult
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) { err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
if chunk.Delta != "" { if chunk.Delta != "" {
assembled.WriteString(chunk.Delta) assembled.WriteString(chunk.Delta)
_ = writeEvent(map[string]any{ _ = writeEvent(map[string]any{