diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index b14f64b4d..14971ff1f 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -1728,9 +1728,8 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa // Synthesize and send the audio. With pipeline.streaming.tts enabled // emitSpeech forwards a response.output_audio.delta per backend PCM // chunk as it's produced; otherwise it sends the whole utterance as a - // single delta. The returned base64 audio is stored on the item below. - var err error - audioString, err = emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech) + // single delta. The returned PCM is stored (base64) on the item below. + pcmAudio, err := emitSpeech(ctx, t, session, responseID, item.Assistant.ID, finalSpeech) if err != nil { if ctx.Err() != nil { xlog.Debug("TTS cancelled (barge-in)") @@ -1741,6 +1740,9 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) return } + if !isWebRTC { + audioString = base64.StdEncoding.EncodeToString(pcmAudio) + } if !isWebRTC { sendEvent(t, types.ResponseOutputAudioDoneEvent{ diff --git a/core/http/endpoints/openai/realtime_speech.go b/core/http/endpoints/openai/realtime_speech.go index 830777371..2b98b1b4e 100644 --- a/core/http/endpoints/openai/realtime_speech.go +++ b/core/http/endpoints/openai/realtime_speech.go @@ -20,13 +20,12 @@ import ( // those so a streamed reply can be split into several spoken segments that share // one response/item. // -// It returns the base64-encoded audio (at the session output rate) accumulated -// across all chunks, which the caller stores on the conversation item. For -// WebRTC the audio goes over the RTP track instead, so the returned string is -// empty. 
-func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) (string, error) { +// It returns the PCM audio (at the session output rate) accumulated across all +// chunks, which the caller base64-encodes onto the conversation item. For WebRTC +// the audio goes over the RTP track instead, so the returned slice is empty. +func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) ([]byte, error) { if text == "" { - return "", nil + return nil, nil } _, isWebRTC := t.(*WebRTCTransport) @@ -70,31 +69,31 @@ func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() { if err := session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk); err != nil { - return "", err + return nil, err } - return base64.StdEncoding.EncodeToString(wsAudio), nil + return wsAudio, nil } // Unary fallback: synthesize the whole utterance to a file, then emit once. 
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language) if err != nil { - return "", err + return nil, err } if res != nil && !res.Success { - return "", fmt.Errorf("tts generation failed: %s", res.Message) + return nil, fmt.Errorf("tts generation failed: %s", res.Message) } defer func() { _ = os.Remove(audioFilePath) }() audioBytes, err := os.ReadFile(audioFilePath) if err != nil { - return "", fmt.Errorf("read tts audio: %w", err) + return nil, fmt.Errorf("read tts audio: %w", err) } pcm, sampleRate := laudio.ParseWAV(audioBytes) if sampleRate == 0 { sampleRate = session.OutputSampleRate } if err := sendChunk(pcm, sampleRate); err != nil { - return "", err + return nil, err } - return base64.StdEncoding.EncodeToString(wsAudio), nil + return wsAudio, nil } diff --git a/core/http/endpoints/openai/realtime_speech_test.go b/core/http/endpoints/openai/realtime_speech_test.go index 268e6a877..6d09a7217 100644 --- a/core/http/endpoints/openai/realtime_speech_test.go +++ b/core/http/endpoints/openai/realtime_speech_test.go @@ -2,7 +2,6 @@ package openai import ( "context" - "encoding/base64" "os" . "github.com/onsi/ginkgo/v2" @@ -40,8 +39,8 @@ var _ = Describe("emitSpeech", func() { Expect(err).ToNot(HaveOccurred()) Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3)) - // The returned audio is the base64 of all chunks concatenated. - Expect(audio).To(Equal(base64.StdEncoding.EncodeToString([]byte{1, 2, 3, 4, 5, 6}))) + // The returned audio is all chunks concatenated (session output rate). 
+		Expect(audio).To(Equal([]byte{1, 2, 3, 4, 5, 6}))
 	})
 
 	It("sends a single output_audio.delta in unary mode", func() {
diff --git a/core/http/endpoints/openai/realtime_stream.go b/core/http/endpoints/openai/realtime_stream.go
new file mode 100644
index 000000000..aa3f31d7d
--- /dev/null
+++ b/core/http/endpoints/openai/realtime_stream.go
@@ -0,0 +1,102 @@
+package openai
+
+import (
+	"context"
+
+	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
+	"github.com/mudler/LocalAI/pkg/reasoning"
+)
+
+// speechStreamer consumes streamed LLM tokens and drives the realtime output:
+// it strips reasoning incrementally, emits a transcript text delta for each
+// content fragment, and — when the pipeline streams TTS — sentence-pipes the
+// content so each completed sentence is synthesized as soon as it's ready,
+// overlapping generation, synthesis and playback.
+//
+// It is used only for plain-content turns (no tools): tool-call output can't be
+// safely spoken mid-stream, so those turns keep the buffered path.
+type speechStreamer struct {
+	// ctx is stored because onToken, the token callback, takes no context of
+	// its own; it is passed to emitSpeech and checked before each synthesis.
+	ctx        context.Context
+	t          Transport
+	session    *Session
+	responseID string
+	itemID     string
+
+	extractor *reasoning.ReasoningExtractor // splits reasoning from spoken content
+	seg       streamSegmenter               // buffers content until a segment is ready to speak
+	audio     []byte                        // PCM returned by emitSpeech, concatenated in order
+	streamTTS bool                          // whether segments are synthesized as they complete
+	err       error                         // first synthesis error; once set, speak is a no-op
+}
+
+// newSpeechStreamer returns a speechStreamer for one response item. TTS is
+// sentence-piped only when the session's model config enables
+// pipeline.streaming.tts; otherwise only transcript deltas are emitted.
+func newSpeechStreamer(ctx context.Context, t Transport, session *Session, responseID, itemID, thinkingStartToken string, reasoningCfg reasoning.Config) *speechStreamer {
+	return &speechStreamer{
+		ctx:        ctx,
+		t:          t,
+		session:    session,
+		responseID: responseID,
+		itemID:     itemID,
+		extractor:  reasoning.NewReasoningExtractor(thinkingStartToken, reasoningCfg),
+		streamTTS:  session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS(),
+	}
+}
+
+// onToken handles one streamed LLM token. It is shaped to be used directly as
+// the backend token callback's text sink.
+func (s *speechStreamer) onToken(token string) {
+	_, content := s.extractor.ProcessToken(token)
+	if content == "" {
+		return
+	}
+	// Best-effort: onToken has no error return, so a failed transcript send is
+	// dropped here rather than aborting the stream.
+	_ = s.t.SendEvent(types.ResponseOutputAudioTranscriptDeltaEvent{
+		ServerEventBase: types.ServerEventBase{},
+		ResponseID:      s.responseID,
+		ItemID:          s.itemID,
+		OutputIndex:     0,
+		ContentIndex:    0,
+		Delta:           content,
+	})
+	if s.streamTTS {
+		for _, segment := range s.seg.Push(content) {
+			s.speak(segment)
+		}
+	}
+}
+
+// speak synthesizes one segment via emitSpeech and appends the returned PCM to
+// the accumulated audio. It keeps only the first error and becomes a no-op
+// once an error is recorded or the context is done, so no further TTS work is
+// started after a failure or a cancellation (barge-in).
+func (s *speechStreamer) speak(text string) {
+	if s.err != nil {
+		return
+	}
+	if err := s.ctx.Err(); err != nil {
+		s.err = err
+		return
+	}
+	pcm, err := emitSpeech(s.ctx, s.t, s.session, s.responseID, s.itemID, text)
+	if err != nil {
+		s.err = err
+		return
+	}
+	s.audio = append(s.audio, pcm...)
+}
+
+// finish flushes any buffered sentence to TTS and returns the full cleaned
+// content, the accumulated PCM audio, and the first error encountered (if any).
+func (s *speechStreamer) finish() (content string, audio []byte, err error) {
+	if s.streamTTS {
+		if rem := s.seg.Flush(); rem != "" {
+			s.speak(rem)
+		}
+	}
+	return s.extractor.CleanedContent(), s.audio, s.err
+}
diff --git a/core/http/endpoints/openai/realtime_stream_test.go b/core/http/endpoints/openai/realtime_stream_test.go
new file mode 100644
index 000000000..a6d233175
--- /dev/null
+++ b/core/http/endpoints/openai/realtime_stream_test.go
@@ -0,0 +1,91 @@
+package openai
+
+import (
+	"context"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+
+	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
+	"github.com/mudler/LocalAI/pkg/reasoning"
+)
+
+// speechStreamer consumes streamed LLM tokens: it strips reasoning, emits a
+// transcript delta per content fragment, and sentence-pipes content into TTS so
+// audio starts before the full reply is generated.
+var _ = Describe("speechStreamer", func() {
+	It("emits a transcript delta per token and speaks each completed sentence", func() {
+		on := true
+		m := &fakeModel{ttsStreamChunks: [][]byte{{7}}, ttsStreamRate: 24000}
+		session := &Session{
+			OutputSampleRate: 24000,
+			ModelInterface:   m,
+			ModelConfig: &config.ModelConfig{
+				Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &on}},
+			},
+		}
+		t := &fakeTransport{}
+		s := newSpeechStreamer(context.Background(), t, session, "resp1", "item1", "", reasoning.Config{})
+
+		for _, tok := range []string{"Hello", " world.", " Bye"} {
+			s.onToken(tok)
+		}
+		content, audio, err := s.finish()
+
+		Expect(err).ToNot(HaveOccurred())
+		Expect(content).To(Equal("Hello world. Bye"))
+		// One transcript delta per (non-empty) token.
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(3))
+		// Two sentences spoken: "Hello world." mid-stream + "Bye" on flush; one
+		// chunk each.
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(2))
+		Expect(audio).To(Equal([]byte{7, 7}))
+	})
+
+	It("does not synthesize audio when TTS streaming is disabled", func() {
+		m := &fakeModel{ttsStreamChunks: [][]byte{{7}}, ttsStreamRate: 24000}
+		session := &Session{
+			OutputSampleRate: 24000,
+			ModelInterface:   m,
+			ModelConfig:      &config.ModelConfig{}, // streaming.tts off
+		}
+		t := &fakeTransport{}
+		s := newSpeechStreamer(context.Background(), t, session, "resp1", "item1", "", reasoning.Config{})
+
+		s.onToken("Hello world.")
+		content, audio, err := s.finish()
+
+		Expect(err).ToNot(HaveOccurred())
+		Expect(content).To(Equal("Hello world."))
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(1))
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(0))
+		Expect(audio).To(BeEmpty())
+	})
+
+	It("stops synthesizing once the context is cancelled", func() {
+		on := true
+		m := &fakeModel{ttsStreamChunks: [][]byte{{7}}, ttsStreamRate: 24000}
+		session := &Session{
+			OutputSampleRate: 24000,
+			ModelInterface:   m,
+			ModelConfig: &config.ModelConfig{
+				Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &on}},
+			},
+		}
+		t := &fakeTransport{}
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel() // simulate a barge-in before any sentence is spoken
+		s := newSpeechStreamer(ctx, t, session, "resp1", "item1", "", reasoning.Config{})
+
+		s.onToken("Hello world.")
+		content, audio, err := s.finish()
+
+		Expect(err).To(MatchError(context.Canceled))
+		Expect(content).To(Equal("Hello world."))
+		// The transcript delta is still emitted, but no audio is synthesized.
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioTranscriptDelta)).To(Equal(1))
+		Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(0))
+		Expect(audio).To(BeEmpty())
+	})
+})