mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-06 07:46:15 -04:00
feat(realtime): route response audio through emitSpeech (streaming TTS)
Replace the inline unary TTS block in the response handler with emitSpeech, which streams a response.output_audio.delta per backend PCM chunk when pipeline.streaming.tts is set and otherwise preserves the single-delta unary behaviour. emitSpeech returns the accumulated base64 audio, stored on the conversation item as before. Transcript and audio-done events stay in the handler so later per-segment streaming can reuse emitSpeech.

Assisted-by: Claude:claude-opus-4-8 go vet
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -1719,64 +1719,7 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
return
|
||||
}
|
||||
|
||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("TTS cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("TTS failed", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
if !res.Success {
|
||||
xlog.Error("TTS failed", "message", res.Message)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
defer func() { _ = os.Remove(audioFilePath) }()
|
||||
|
||||
audioBytes, err := os.ReadFile(audioFilePath)
|
||||
if err != nil {
|
||||
xlog.Error("failed to read TTS file", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
|
||||
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
|
||||
if ttsSampleRate == 0 {
|
||||
ttsSampleRate = localSampleRate
|
||||
}
|
||||
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
|
||||
|
||||
// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
|
||||
// Opus encoder, which resamples to 48kHz internally. This avoids a
|
||||
// lossy intermediate resample through 16kHz.
|
||||
// XXX: This is a noop in websocket mode; it's included in the JSON instead
|
||||
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("Audio playback cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("failed to send audio via transport", "error", err)
|
||||
}
|
||||
|
||||
// For WebSocket clients, resample to the session's output rate and
|
||||
// deliver audio as base64 in JSON events. WebRTC clients already
|
||||
// received audio over the RTP track, so skip the base64 payload.
|
||||
if !isWebRTC {
|
||||
wsPCM := pcmData
|
||||
if ttsSampleRate != session.OutputSampleRate {
|
||||
samples := sound.BytesToInt16sLE(pcmData)
|
||||
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
|
||||
wsPCM = sound.Int16toBytesLE(resampled)
|
||||
}
|
||||
audioString = base64.StdEncoding.EncodeToString(wsPCM)
|
||||
}
|
||||
|
||||
// Transcript of the spoken reply (the audio's text).
|
||||
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
@@ -1794,15 +1737,24 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
Transcript: finalSpeech,
|
||||
})
|
||||
|
||||
// Synthesize and send the audio. With pipeline.streaming.tts enabled
|
||||
// emitSpeech forwards a response.output_audio.delta per backend PCM
|
||||
// chunk as it's produced; otherwise it sends the whole utterance as a
|
||||
// single delta. The returned base64 audio is stored on the item below.
|
||||
var err error
|
||||
audioString, err = emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("TTS cancelled (barge-in)")
|
||||
sendCancelledResponse()
|
||||
return
|
||||
}
|
||||
xlog.Error("TTS failed", "error", err)
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
|
||||
if !isWebRTC {
|
||||
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
ItemID: item.Assistant.ID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: audioString,
|
||||
})
|
||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
|
||||
@@ -19,13 +19,20 @@ import (
|
||||
// It deliberately does NOT emit transcript or audio-done events: the caller owns
|
||||
// those so a streamed reply can be split into several spoken segments that share
|
||||
// one response/item.
|
||||
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) error {
|
||||
//
|
||||
// It returns the base64-encoded audio (at the session output rate) accumulated
|
||||
// across all chunks, which the caller stores on the conversation item. For
|
||||
// WebRTC the audio goes over the RTP track instead, so the returned string is
|
||||
// empty.
|
||||
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) (string, error) {
|
||||
if text == "" {
|
||||
return nil
|
||||
return "", nil
|
||||
}
|
||||
|
||||
_, isWebRTC := t.(*WebRTCTransport)
|
||||
|
||||
var wsAudio []byte // PCM at the session output rate, accumulated for the item record
|
||||
|
||||
// sendChunk hands one PCM buffer to the transport: WebRTC consumes the raw
|
||||
// PCM directly (it resamples internally); WebSocket gets base64 PCM at the
|
||||
// session output rate via a JSON delta event.
|
||||
@@ -45,6 +52,7 @@ func emitSpeech(ctx context.Context, t Transport, session *Session, responseID,
|
||||
resampled := sound.ResampleInt16(samples, sampleRate, session.OutputSampleRate)
|
||||
wsPCM = sound.Int16toBytesLE(resampled)
|
||||
}
|
||||
wsAudio = append(wsAudio, wsPCM...)
|
||||
return t.SendEvent(types.ResponseOutputAudioDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: responseID,
|
||||
@@ -61,26 +69,32 @@ func emitSpeech(ctx context.Context, t Transport, session *Session, responseID,
|
||||
}
|
||||
|
||||
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() {
|
||||
return session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk)
|
||||
if err := session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(wsAudio), nil
|
||||
}
|
||||
|
||||
// Unary fallback: synthesize the whole utterance to a file, then emit once.
|
||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
if res != nil && !res.Success {
|
||||
return fmt.Errorf("tts generation failed: %s", res.Message)
|
||||
return "", fmt.Errorf("tts generation failed: %s", res.Message)
|
||||
}
|
||||
defer func() { _ = os.Remove(audioFilePath) }()
|
||||
|
||||
audioBytes, err := os.ReadFile(audioFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tts audio: %w", err)
|
||||
return "", fmt.Errorf("read tts audio: %w", err)
|
||||
}
|
||||
pcm, sampleRate := laudio.ParseWAV(audioBytes)
|
||||
if sampleRate == 0 {
|
||||
sampleRate = session.OutputSampleRate
|
||||
}
|
||||
return sendChunk(pcm, sampleRate)
|
||||
if err := sendChunk(pcm, sampleRate); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(wsAudio), nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -35,10 +36,12 @@ var _ = Describe("emitSpeech", func() {
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.")
|
||||
audio, err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3))
|
||||
// The returned audio is the base64 of all chunks concatenated.
|
||||
Expect(audio).To(Equal(base64.StdEncoding.EncodeToString([]byte{1, 2, 3, 4, 5, 6})))
|
||||
})
|
||||
|
||||
It("sends a single output_audio.delta in unary mode", func() {
|
||||
@@ -60,7 +63,7 @@ var _ = Describe("emitSpeech", func() {
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
|
||||
err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.")
|
||||
_, err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
|
||||
|
||||
Reference in New Issue
Block a user