mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
refactor(transcription): propagate request ctx through ModelTranscription*
Replaces context.Background() with the HTTP request ctx so that a client disconnect cancels the in-flight gRPC call. No backend-side abort wiring yet — that comes in a later commit. Pure plumbing.

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
|
||||
return transcriptionModel, nil
|
||||
}
|
||||
|
||||
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
return ModelTranscriptionWithOptions(TranscriptionRequest{
|
||||
func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
|
||||
Audio: audio,
|
||||
Language: language,
|
||||
Translate: translate,
|
||||
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
||||
}, ml, modelConfig, appConfig)
|
||||
}
|
||||
|
||||
func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
|
||||
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||
}
|
||||
|
||||
r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads)))
|
||||
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
||||
if err != nil {
|
||||
if appConfig.EnableTracing {
|
||||
errData := map[string]any{
|
||||
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
|
||||
// invokes onChunk for each event the backend produces. Backends that don't
|
||||
// support real streaming should still emit one terminal event with Final set,
|
||||
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||
func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
|
||||
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
||||
pbReq.Stream = true
|
||||
|
||||
return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
||||
return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
||||
if chunk == nil {
|
||||
return
|
||||
}
|
||||
@@ -187,12 +187,12 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
||||
}
|
||||
var words []schema.TranscriptionWord
|
||||
for _, w := range s.Words {
|
||||
var word = schema.TranscriptionWord {
|
||||
var word = schema.TranscriptionWord{
|
||||
Start: time.Duration(w.Start),
|
||||
End: time.Duration(w.End),
|
||||
Text: w.Text,
|
||||
}
|
||||
words = append(words, word)
|
||||
words = append(words, word)
|
||||
tr.Words = append(tr.Words, word)
|
||||
}
|
||||
tr.Segments = append(tr.Segments,
|
||||
|
||||
@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||
|
||||
@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
return streamTranscription(c, req, ml, *config, appConfig)
|
||||
}
|
||||
|
||||
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig)
|
||||
tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
// Log before returning so the underlying error survives. Echo's
|
||||
// error handler turns this into a 500 with a generic body, which
|
||||
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
Words: []schema.TranscriptionWordSeconds{},
|
||||
Segments: []schema.TranscriptionSegmentSeconds{},
|
||||
}
|
||||
for _, word := range(tr.Words) {
|
||||
for _, word := range tr.Words {
|
||||
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
Text: word.Text,
|
||||
})
|
||||
}
|
||||
for _, seg := range(tr.Segments) {
|
||||
for _, seg := range tr.Segments {
|
||||
segWords := []schema.TranscriptionWordSeconds{}
|
||||
for _, word := range(seg.Words) {
|
||||
for _, word := range seg.Words {
|
||||
segWords = append(segWords, schema.TranscriptionWordSeconds{
|
||||
Start: word.Start.Seconds(),
|
||||
End: word.End.Seconds(),
|
||||
@@ -174,7 +174,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
})
|
||||
}
|
||||
trs.Segments = append(trs.Segments, schema.TranscriptionSegmentSeconds{
|
||||
Id: seg.Id,
|
||||
Id: seg.Id,
|
||||
Start: seg.Start.Seconds(),
|
||||
End: seg.End.Seconds(),
|
||||
Text: seg.Text,
|
||||
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
|
||||
var assembled strings.Builder
|
||||
var finalResult *schema.TranscriptionResult
|
||||
|
||||
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
if chunk.Delta != "" {
|
||||
assembled.WriteString(chunk.Delta)
|
||||
_ = writeEvent(map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user