mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-19 14:17:21 -04:00
refactor(transcription): propagate request ctx through ModelTranscription*
Replaces context.Background() with the HTTP request ctx so client disconnects start cancelling the gRPC call. No backend-side abort wiring yet — that comes in a later commit. Pure plumbing. Assisted-by: Claude:claude-haiku-4-5 Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -57,8 +57,8 @@ func loadTranscriptionModel(ml *model.ModelLoader, modelConfig config.ModelConfi
|
|||||||
return transcriptionModel, nil
|
return transcriptionModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
func ModelTranscription(ctx context.Context, audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
return ModelTranscriptionWithOptions(TranscriptionRequest{
|
return ModelTranscriptionWithOptions(ctx, TranscriptionRequest{
|
||||||
Audio: audio,
|
Audio: audio,
|
||||||
Language: language,
|
Language: language,
|
||||||
Translate: translate,
|
Translate: translate,
|
||||||
@@ -67,7 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt
|
|||||||
}, ml, modelConfig, appConfig)
|
}, ml, modelConfig, appConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
func ModelTranscriptionWithOptions(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -82,7 +82,7 @@ func ModelTranscriptionWithOptions(req TranscriptionRequest, ml *model.ModelLoad
|
|||||||
audioSnippet = trace.AudioSnippet(req.Audio)
|
audioSnippet = trace.AudioSnippet(req.Audio)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := transcriptionModel.AudioTranscription(context.Background(), req.toProto(uint32(*modelConfig.Threads)))
|
r, err := transcriptionModel.AudioTranscription(ctx, req.toProto(uint32(*modelConfig.Threads)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if appConfig.EnableTracing {
|
if appConfig.EnableTracing {
|
||||||
errData := map[string]any{
|
errData := map[string]any{
|
||||||
@@ -149,7 +149,7 @@ type TranscriptionStreamChunk struct {
|
|||||||
// invokes onChunk for each event the backend produces. Backends that don't
|
// invokes onChunk for each event the backend produces. Backends that don't
|
||||||
// support real streaming should still emit one terminal event with Final set,
|
// support real streaming should still emit one terminal event with Final set,
|
||||||
// which the HTTP layer turns into a single delta + done SSE pair.
|
// which the HTTP layer turns into a single delta + done SSE pair.
|
||||||
func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
func ModelTranscriptionStream(ctx context.Context, req TranscriptionRequest, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, onChunk func(TranscriptionStreamChunk)) error {
|
||||||
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
transcriptionModel, err := loadTranscriptionModel(ml, modelConfig, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -158,7 +158,7 @@ func ModelTranscriptionStream(req TranscriptionRequest, ml *model.ModelLoader, m
|
|||||||
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
pbReq := req.toProto(uint32(*modelConfig.Threads))
|
||||||
pbReq.Stream = true
|
pbReq.Stream = true
|
||||||
|
|
||||||
return transcriptionModel.AudioTranscriptionStream(context.Background(), pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
return transcriptionModel.AudioTranscriptionStream(ctx, pbReq, func(chunk *proto.TranscriptStreamResponse) {
|
||||||
if chunk == nil {
|
if chunk == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -187,7 +187,7 @@ func transcriptResultFromProto(r *proto.TranscriptResult) *schema.TranscriptionR
|
|||||||
}
|
}
|
||||||
var words []schema.TranscriptionWord
|
var words []schema.TranscriptionWord
|
||||||
for _, w := range s.Words {
|
for _, w := range s.Words {
|
||||||
var word = schema.TranscriptionWord {
|
var word = schema.TranscriptionWord{
|
||||||
Start: time.Duration(w.Start),
|
Start: time.Duration(w.Start),
|
||||||
End: time.Duration(w.End),
|
End: time.Duration(w.End),
|
||||||
Text: w.Text,
|
Text: w.Text,
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADReques
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||||
@@ -82,7 +82,7 @@ func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*sc
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) {
|
||||||
return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
return backend.ModelTranscription(ctx, audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) {
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
|||||||
return streamTranscription(c, req, ml, *config, appConfig)
|
return streamTranscription(c, req, ml, *config, appConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
tr, err := backend.ModelTranscriptionWithOptions(req, ml, *config, appConfig)
|
tr, err := backend.ModelTranscriptionWithOptions(c.Request().Context(), req, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log before returning so the underlying error survives. Echo's
|
// Log before returning so the underlying error survives. Echo's
|
||||||
// error handler turns this into a 500 with a generic body, which
|
// error handler turns this into a 500 with a generic body, which
|
||||||
@@ -157,16 +157,16 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
|||||||
Words: []schema.TranscriptionWordSeconds{},
|
Words: []schema.TranscriptionWordSeconds{},
|
||||||
Segments: []schema.TranscriptionSegmentSeconds{},
|
Segments: []schema.TranscriptionSegmentSeconds{},
|
||||||
}
|
}
|
||||||
for _, word := range(tr.Words) {
|
for _, word := range tr.Words {
|
||||||
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
|
trs.Words = append(trs.Words, schema.TranscriptionWordSeconds{
|
||||||
Start: word.Start.Seconds(),
|
Start: word.Start.Seconds(),
|
||||||
End: word.End.Seconds(),
|
End: word.End.Seconds(),
|
||||||
Text: word.Text,
|
Text: word.Text,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, seg := range(tr.Segments) {
|
for _, seg := range tr.Segments {
|
||||||
segWords := []schema.TranscriptionWordSeconds{}
|
segWords := []schema.TranscriptionWordSeconds{}
|
||||||
for _, word := range(seg.Words) {
|
for _, word := range seg.Words {
|
||||||
segWords = append(segWords, schema.TranscriptionWordSeconds{
|
segWords = append(segWords, schema.TranscriptionWordSeconds{
|
||||||
Start: word.Start.Seconds(),
|
Start: word.Start.Seconds(),
|
||||||
End: word.End.Seconds(),
|
End: word.End.Seconds(),
|
||||||
@@ -216,7 +216,7 @@ func streamTranscription(c echo.Context, req backend.TranscriptionRequest, ml *m
|
|||||||
var assembled strings.Builder
|
var assembled strings.Builder
|
||||||
var finalResult *schema.TranscriptionResult
|
var finalResult *schema.TranscriptionResult
|
||||||
|
|
||||||
err := backend.ModelTranscriptionStream(req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
err := backend.ModelTranscriptionStream(c.Request().Context(), req, ml, config, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||||
if chunk.Delta != "" {
|
if chunk.Delta != "" {
|
||||||
assembled.WriteString(chunk.Delta)
|
assembled.WriteString(chunk.Delta)
|
||||||
_ = writeEvent(map[string]any{
|
_ = writeEvent(map[string]any{
|
||||||
|
|||||||
Reference in New Issue
Block a user