refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform

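Same ctx-plumbing pattern applied to the rest of the audio path:
ModelTTS, ModelTTSStream, SoundGeneration and ModelAudioTransform now
take a context.Context as their first parameter and pass it to the
backend call instead of context.Background(). CLI callers use
context.Background() since there is no request scope; HTTP callers use
c.Request().Context(); wrappedModel.TTS forwards the ctx it receives.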

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-05-07 16:00:42 +00:00
parent be5a9b90cd
commit 047a8c57a7
10 changed files with 16 additions and 11 deletions

View File

@@ -40,6 +40,7 @@ type AudioTransformOutputs struct {
// required; `referencePath` is optional (empty => backend zero-fills the
// reference channel).
func ModelAudioTransform(
ctx context.Context,
audioPath, referencePath string,
opts AudioTransformOptions,
loader *model.ModelLoader,
@@ -81,7 +82,7 @@ func ModelAudioTransform(
startTime = time.Now()
}
res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{
res, err := transformModel.AudioTransform(ctx, &proto.AudioTransformRequest{
AudioPath: audioPath,
ReferencePath: referencePath,
Dst: dst,

View File

@@ -15,6 +15,7 @@ import (
)
func SoundGeneration(
ctx context.Context,
text string,
duration *float32,
temperature *float32,
@@ -101,7 +102,7 @@ func SoundGeneration(
startTime = time.Now()
}
res, err := soundGenModel.SoundGeneration(context.Background(), req)
res, err := soundGenModel.SoundGeneration(ctx, req)
if appConfig.EnableTracing {
errStr := ""

View File

@@ -21,6 +21,7 @@ import (
)
func ModelTTS(
ctx context.Context,
text,
voice,
language string,
@@ -70,7 +71,7 @@ func ModelTTS(
startTime = time.Now()
}
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
@@ -121,6 +122,7 @@ func ModelTTS(
}
func ModelTTSStream(
ctx context.Context,
text,
voice,
language string,
@@ -172,7 +174,7 @@ func ModelTTSStream(
var totalPCMBytes int
snippetCapped := false
err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,

View File

@@ -97,7 +97,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
inputFile = &t.InputFile
}
filePath, _, err := backend.SoundGeneration(text,
filePath, _, err := backend.SoundGeneration(context.Background(), text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor),
nil, "", "", nil, "", "", "", nil,

View File

@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
options.Backend = t.Backend
options.Model = t.Model
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}

View File

@@ -44,6 +44,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
bpm = &b
}
filePath, _, err := backend.SoundGeneration(
c.Request().Context(),
input.Text, input.Duration, input.Temperature, input.DoSample,
nil, nil,
input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale,

View File

@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -109,7 +109,7 @@ func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
}
}
out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{
out, _, err := backend.ModelAudioTransform(c.Request().Context(), audioPath, referencePath, backend.AudioTransformOptions{
Params: params,
}, ml, appConfig, *cfg)
if err != nil {

View File

@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
c.Response().Header().Set("Connection", "keep-alive")
// Stream audio chunks as they're generated
err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
_, writeErr := c.Response().Write(audioChunk)
if writeErr != nil {
return writeErr
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
}
// Non-streaming TTS (existing behavior)
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
if err != nil {
return err
}

View File

@@ -241,7 +241,7 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
}
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
}
func (m *wrappedModel) PredictConfig() *config.ModelConfig {