mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-06 07:46:15 -04:00
feat(realtime): speechStreamer for token-streamed LLM->TTS
emitSpeech now returns raw PCM (caller base64-encodes) so streamed segments accumulate correctly. speechStreamer consumes streamed LLM tokens: it strips reasoning via the streaming ReasoningExtractor, emits a transcript delta per content fragment, and sentence-pipes content into emitSpeech so each sentence is synthesized as soon as it's ready. Handler wiring (plain-content turns) follows. Assisted-by: Claude:claude-opus-4-8 go vet Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -1728,9 +1728,8 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
// Synthesize and send the audio. With pipeline.streaming.tts enabled
|
||||
// emitSpeech forwards a response.output_audio.delta per backend PCM
|
||||
// chunk as it's produced; otherwise it sends the whole utterance as a
|
||||
// single delta. The returned base64 audio is stored on the item below.
|
||||
var err error
|
||||
audioString, err = emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech)
|
||||
// single delta. The returned PCM is stored (base64) on the item below.
|
||||
pcmAudio, err := emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
xlog.Debug("TTS cancelled (barge-in)")
|
||||
@@ -1741,6 +1740,9 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
|
||||
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
|
||||
return
|
||||
}
|
||||
if !isWebRTC {
|
||||
audioString = base64.StdEncoding.EncodeToString(pcmAudio)
|
||||
}
|
||||
|
||||
if !isWebRTC {
|
||||
sendEvent(t, types.ResponseOutputAudioDoneEvent{
|
||||
|
||||
@@ -20,13 +20,12 @@ import (
|
||||
// those so a streamed reply can be split into several spoken segments that share
|
||||
// one response/item.
|
||||
//
|
||||
// It returns the base64-encoded audio (at the session output rate) accumulated
|
||||
// across all chunks, which the caller stores on the conversation item. For
|
||||
// WebRTC the audio goes over the RTP track instead, so the returned string is
|
||||
// empty.
|
||||
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) (string, error) {
|
||||
// It returns the PCM audio (at the session output rate) accumulated across all
|
||||
// chunks, which the caller base64-encodes onto the conversation item. For WebRTC
|
||||
// the audio goes over the RTP track instead, so the returned slice is empty.
|
||||
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) ([]byte, error) {
|
||||
if text == "" {
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, isWebRTC := t.(*WebRTCTransport)
|
||||
@@ -70,31 +69,31 @@ func emitSpeech(ctx context.Context, t Transport, session *Session, responseID,
|
||||
|
||||
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() {
|
||||
if err := session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(wsAudio), nil
|
||||
return wsAudio, nil
|
||||
}
|
||||
|
||||
// Unary fallback: synthesize the whole utterance to a file, then emit once.
|
||||
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
if res != nil && !res.Success {
|
||||
return "", fmt.Errorf("tts generation failed: %s", res.Message)
|
||||
return nil, fmt.Errorf("tts generation failed: %s", res.Message)
|
||||
}
|
||||
defer func() { _ = os.Remove(audioFilePath) }()
|
||||
|
||||
audioBytes, err := os.ReadFile(audioFilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read tts audio: %w", err)
|
||||
return nil, fmt.Errorf("read tts audio: %w", err)
|
||||
}
|
||||
pcm, sampleRate := laudio.ParseWAV(audioBytes)
|
||||
if sampleRate == 0 {
|
||||
sampleRate = session.OutputSampleRate
|
||||
}
|
||||
if err := sendChunk(pcm, sampleRate); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(wsAudio), nil
|
||||
return wsAudio, nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -40,8 +39,8 @@ var _ = Describe("emitSpeech", func() {
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3))
|
||||
// The returned audio is the base64 of all chunks concatenated.
|
||||
Expect(audio).To(Equal(base64.StdEncoding.EncodeToString([]byte{1, 2, 3, 4, 5, 6})))
|
||||
// The returned audio is all chunks concatenated (session output rate).
|
||||
Expect(audio).To(Equal([]byte{1, 2, 3, 4, 5, 6}))
|
||||
})
|
||||
|
||||
It("sends a single output_audio.delta in unary mode", func() {
|
||||
|
||||
86
core/http/endpoints/openai/realtime_stream.go
Normal file
86
core/http/endpoints/openai/realtime_stream.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// speechStreamer consumes streamed LLM tokens and drives the realtime output:
|
||||
// it strips reasoning incrementally, emits a transcript text delta for each
|
||||
// content fragment, and — when the pipeline streams TTS — sentence-pipes the
|
||||
// content so each completed sentence is synthesized as soon as it's ready,
|
||||
// overlapping generation, synthesis and playback.
|
||||
//
|
||||
// It is used only for plain-content turns (no tools): tool-call output can't be
|
||||
// safely spoken mid-stream, so those turns keep the buffered path.
|
||||
type speechStreamer struct {
|
||||
ctx context.Context
|
||||
t Transport
|
||||
session *Session
|
||||
responseID string
|
||||
itemID string
|
||||
|
||||
extractor *reasoning.ReasoningExtractor
|
||||
seg streamSegmenter
|
||||
audio []byte
|
||||
streamTTS bool
|
||||
err error
|
||||
}
|
||||
|
||||
func newSpeechStreamer(ctx context.Context, t Transport, session *Session, responseID, itemID, thinkingStartToken string, reasoningCfg reasoning.Config) *speechStreamer {
|
||||
return &speechStreamer{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
session: session,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
extractor: reasoning.NewReasoningExtractor(thinkingStartToken, reasoningCfg),
|
||||
streamTTS: session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS(),
|
||||
}
|
||||
}
|
||||
|
||||
// onToken handles one streamed LLM token. It is shaped to be used directly as
|
||||
// the backend token callback's text sink.
|
||||
func (s *speechStreamer) onToken(token string) {
|
||||
_, content := s.extractor.ProcessToken(token)
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
_ = s.t.SendEvent(types.ResponseOutputAudioTranscriptDeltaEvent{
|
||||
ServerEventBase: types.ServerEventBase{},
|
||||
ResponseID: s.responseID,
|
||||
ItemID: s.itemID,
|
||||
OutputIndex: 0,
|
||||
ContentIndex: 0,
|
||||
Delta: content,
|
||||
})
|
||||
if s.streamTTS {
|
||||
for _, segment := range s.seg.Push(content) {
|
||||
s.speak(segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *speechStreamer) speak(text string) {
|
||||
pcm, err := emitSpeech(s.ctx, s.t, s.session, s.responseID, s.itemID, text)
|
||||
if err != nil {
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
return
|
||||
}
|
||||
s.audio = append(s.audio, pcm...)
|
||||
}
|
||||
|
||||
// finish flushes any buffered sentence to TTS and returns the full cleaned
|
||||
// content, the accumulated PCM audio, and the first error encountered (if any).
|
||||
func (s *speechStreamer) finish() (content string, audio []byte, err error) {
|
||||
if s.streamTTS {
|
||||
if rem := s.seg.Flush(); rem != "" {
|
||||
s.speak(rem)
|
||||
}
|
||||
}
|
||||
return s.extractor.CleanedContent(), s.audio, s.err
|
||||
}
|
||||
65
core/http/endpoints/openai/realtime_stream_test.go
Normal file
65
core/http/endpoints/openai/realtime_stream_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
)
|
||||
|
||||
// speechStreamer consumes streamed LLM tokens: it strips reasoning, emits a
|
||||
// transcript delta per content fragment, and sentence-pipes content into TTS so
|
||||
// audio starts before the full reply is generated.
|
||||
var _ = Describe("speechStreamer", func() {
|
||||
It("emits a transcript delta per token and speaks each completed sentence", func() {
|
||||
on := true
|
||||
m := &fakeModel{ttsStreamChunks: [][]byte{{7}}, ttsStreamRate: 24000}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{
|
||||
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &on}},
|
||||
},
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
s := newSpeechStreamer(context.Background(), t, session, "resp1", "item1", "", reasoning.Config{})
|
||||
|
||||
for _, tok := range []string{"Hello", " world.", " Bye"} {
|
||||
s.onToken(tok)
|
||||
}
|
||||
content, audio, err := s.finish()
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content).To(Equal("Hello world. Bye"))
|
||||
// One transcript delta per (non-empty) token.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
|
||||
// Two sentences spoken: "Hello world." mid-stream + "Bye" on flush; one
|
||||
// chunk each.
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(2))
|
||||
Expect(audio).To(Equal([]byte{7, 7}))
|
||||
})
|
||||
|
||||
It("does not synthesize audio when TTS streaming is disabled", func() {
|
||||
m := &fakeModel{ttsStreamChunks: [][]byte{{7}}, ttsStreamRate: 24000}
|
||||
session := &Session{
|
||||
OutputSampleRate: 24000,
|
||||
ModelInterface: m,
|
||||
ModelConfig: &config.ModelConfig{}, // streaming.tts off
|
||||
}
|
||||
t := &fakeTransport{}
|
||||
s := newSpeechStreamer(context.Background(), t, session, "resp1", "item1", "", reasoning.Config{})
|
||||
|
||||
s.onToken("Hello world.")
|
||||
content, audio, err := s.finish()
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content).To(Equal("Hello world."))
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(1))
|
||||
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(0))
|
||||
Expect(audio).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user