mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-03 13:56:46 -04:00
refactor(audio): propagate request ctx into TTS, sound-gen, audio-transform
Apply the same ctx-plumbing pattern to the rest of the audio path. CLI callers use context.Background() since there is no request scope; HTTP callers use c.Request().Context().

Assisted-by: Claude:claude-haiku-4-5
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -40,6 +40,7 @@ type AudioTransformOutputs struct {
|
||||
// required; `referencePath` is optional (empty => backend zero-fills the
|
||||
// reference channel).
|
||||
func ModelAudioTransform(
|
||||
ctx context.Context,
|
||||
audioPath, referencePath string,
|
||||
opts AudioTransformOptions,
|
||||
loader *model.ModelLoader,
|
||||
@@ -81,7 +82,7 @@ func ModelAudioTransform(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := transformModel.AudioTransform(context.Background(), &proto.AudioTransformRequest{
|
||||
res, err := transformModel.AudioTransform(ctx, &proto.AudioTransformRequest{
|
||||
AudioPath: audioPath,
|
||||
ReferencePath: referencePath,
|
||||
Dst: dst,
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func SoundGeneration(
|
||||
ctx context.Context,
|
||||
text string,
|
||||
duration *float32,
|
||||
temperature *float32,
|
||||
@@ -101,7 +102,7 @@ func SoundGeneration(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := soundGenModel.SoundGeneration(context.Background(), req)
|
||||
res, err := soundGenModel.SoundGeneration(ctx, req)
|
||||
|
||||
if appConfig.EnableTracing {
|
||||
errStr := ""
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
)
|
||||
|
||||
func ModelTTS(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
@@ -70,7 +71,7 @@ func ModelTTS(
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
|
||||
res, err := ttsModel.TTS(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
@@ -121,6 +122,7 @@ func ModelTTS(
|
||||
}
|
||||
|
||||
func ModelTTSStream(
|
||||
ctx context.Context,
|
||||
text,
|
||||
voice,
|
||||
language string,
|
||||
@@ -172,7 +174,7 @@ func ModelTTSStream(
|
||||
var totalPCMBytes int
|
||||
snippetCapped := false
|
||||
|
||||
err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
|
||||
err = ttsModel.TTSStream(ctx, &proto.TTSRequest{
|
||||
Text: text,
|
||||
Model: modelPath,
|
||||
Voice: voice,
|
||||
|
||||
@@ -97,7 +97,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
|
||||
inputFile = &t.InputFile
|
||||
}
|
||||
|
||||
filePath, _, err := backend.SoundGeneration(text,
|
||||
filePath, _, err := backend.SoundGeneration(context.Background(), text,
|
||||
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
|
||||
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor),
|
||||
nil, "", "", nil, "", "", "", nil,
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
|
||||
options.Backend = t.Backend
|
||||
options.Model = t.Model
|
||||
|
||||
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
|
||||
filePath, _, err := backend.ModelTTS(context.Background(), text, t.Voice, t.Language, ml, opts, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
|
||||
bpm = &b
|
||||
}
|
||||
filePath, _, err := backend.SoundGeneration(
|
||||
c.Request().Context(),
|
||||
input.Text, input.Duration, input.Temperature, input.DoSample,
|
||||
nil, nil,
|
||||
input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale,
|
||||
|
||||
@@ -37,7 +37,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
|
||||
xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID)
|
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func AudioTransformEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
|
||||
}
|
||||
}
|
||||
|
||||
out, _, err := backend.ModelAudioTransform(audioPath, referencePath, backend.AudioTransformOptions{
|
||||
out, _, err := backend.ModelAudioTransform(c.Request().Context(), audioPath, referencePath, backend.AudioTransformOptions{
|
||||
Params: params,
|
||||
}, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -59,7 +59,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Stream audio chunks as they're generated
|
||||
err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
err := backend.ModelTTSStream(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
|
||||
_, writeErr := c.Response().Write(audioChunk)
|
||||
if writeErr != nil {
|
||||
return writeErr
|
||||
@@ -75,7 +75,7 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
|
||||
}
|
||||
|
||||
// Non-streaming TTS (existing behavior)
|
||||
filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
filePath, _, err := backend.ModelTTS(c.Request().Context(), input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -241,7 +241,7 @@ func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, im
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) {
|
||||
return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
return backend.ModelTTS(ctx, text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
|
||||
Reference in New Issue
Block a user