feat(realtime): emitSpeech with flag-gated streaming TTS

emitSpeech synthesizes a piece of text and forwards audio to the client,
streaming one output_audio.delta per backend PCM chunk when the pipeline
sets streaming.tts, or one delta for the whole utterance otherwise. WebRTC
gets raw PCM (it resamples internally); WebSocket gets base64 PCM at the
session rate. It emits no transcript/audio-done events so a streamed reply
can be split into multiple spoken segments sharing one response.

Adds fakeModel/fakeTransport test doubles for the realtime Model/Transport
interfaces, driving streaming assertions deterministically.

Assisted-by: Claude:claude-opus-4-8 go vet
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-04 16:16:53 +00:00
parent 2ba2216ce2
commit 2c6fdd0570
3 changed files with 255 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
package openai
import (
"context"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
)
// fakeTransport records the server events and audio sent to a realtime client
// so streaming behaviour can be asserted without a real WebSocket/WebRTC peer.
// It is not a *WebRTCTransport, so handler code takes the WebSocket path.
type fakeTransport struct {
events []types.ServerEvent
audio []fakeAudioChunk
}
type fakeAudioChunk struct {
pcm []byte
sampleRate int
}
func (f *fakeTransport) SendEvent(e types.ServerEvent) error {
f.events = append(f.events, e)
return nil
}
func (f *fakeTransport) ReadEvent() ([]byte, error) { return nil, nil }
func (f *fakeTransport) SendAudio(_ context.Context, pcm []byte, sampleRate int) error {
f.audio = append(f.audio, fakeAudioChunk{pcm: pcm, sampleRate: sampleRate})
return nil
}
func (f *fakeTransport) Close() error { return nil }
// countEvents returns how many recorded events have the given type.
func (f *fakeTransport) countEvents(et types.ServerEventType) int {
n := 0
for _, e := range f.events {
if e.ServerEventType() == et {
n++
}
}
return n
}
// fakeModel is a configurable Model double. TTSStream replays ttsStreamChunks
// and TranscribeStream replays transcribeDeltas, so the handler's streaming
// paths can be driven deterministically.
type fakeModel struct {
cfg *config.ModelConfig
ttsFile string
ttsStreamChunks [][]byte
ttsStreamRate int
ttsStreamErr error
transcribeDeltas []string
transcribeFinal *schema.TranscriptionResult
}
func (m *fakeModel) VAD(context.Context, *schema.VADRequest) (*schema.VADResponse, error) {
return nil, nil
}
func (m *fakeModel) Transcribe(context.Context, string, string, bool, bool, string) (*schema.TranscriptionResult, error) {
return m.transcribeFinal, nil
}
func (m *fakeModel) Predict(context.Context, schema.Messages, []string, []string, []string, func(string, backend.TokenUsage) bool, []types.ToolUnion, *types.ToolChoiceUnion, *int, *int, map[string]float64) (func() (backend.LLMResponse, error), error) {
return nil, nil
}
func (m *fakeModel) TTS(context.Context, string, string, string) (string, *proto.Result, error) {
return m.ttsFile, &proto.Result{Success: true}, nil
}
func (m *fakeModel) TTSStream(_ context.Context, _, _, _ string, onAudio func(pcm []byte, sampleRate int) error) error {
if m.ttsStreamErr != nil {
return m.ttsStreamErr
}
for _, c := range m.ttsStreamChunks {
if err := onAudio(c, m.ttsStreamRate); err != nil {
return err
}
}
return nil
}
func (m *fakeModel) TranscribeStream(_ context.Context, _, _ string, _, _ bool, _ string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
for _, d := range m.transcribeDeltas {
onDelta(d)
}
return m.transcribeFinal, nil
}
func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg }

View File

@@ -0,0 +1,86 @@
package openai
import (
"context"
"encoding/base64"
"fmt"
"os"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
laudio "github.com/mudler/LocalAI/pkg/audio"
"github.com/mudler/LocalAI/pkg/sound"
)
// emitSpeech synthesizes text and sends the audio to the client. When the
// pipeline opts into TTS streaming it forwards each PCM chunk as its own
// response.output_audio.delta as soon as the backend produces it; otherwise it
// synthesizes the whole utterance and sends it as a single delta.
//
// It deliberately does NOT emit transcript or audio-done events: the caller owns
// those so a streamed reply can be split into several spoken segments that share
// one response/item.
func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) error {
if text == "" {
return nil
}
_, isWebRTC := t.(*WebRTCTransport)
// sendChunk hands one PCM buffer to the transport: WebRTC consumes the raw
// PCM directly (it resamples internally); WebSocket gets base64 PCM at the
// session output rate via a JSON delta event.
sendChunk := func(pcm []byte, sampleRate int) error {
if len(pcm) == 0 {
return nil
}
if err := t.SendAudio(ctx, pcm, sampleRate); err != nil {
return err
}
if isWebRTC {
return nil
}
wsPCM := pcm
if sampleRate != 0 && sampleRate != session.OutputSampleRate {
samples := sound.BytesToInt16sLE(pcm)
resampled := sound.ResampleInt16(samples, sampleRate, session.OutputSampleRate)
wsPCM = sound.Int16toBytesLE(resampled)
}
return t.SendEvent(types.ResponseOutputAudioDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: itemID,
OutputIndex: 0,
ContentIndex: 0,
Delta: base64.StdEncoding.EncodeToString(wsPCM),
})
}
language := ""
if session.InputAudioTranscription != nil {
language = session.InputAudioTranscription.Language
}
if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() {
return session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk)
}
// Unary fallback: synthesize the whole utterance to a file, then emit once.
audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language)
if err != nil {
return err
}
if res != nil && !res.Success {
return fmt.Errorf("tts generation failed: %s", res.Message)
}
defer func() { _ = os.Remove(audioFilePath) }()
audioBytes, err := os.ReadFile(audioFilePath)
if err != nil {
return fmt.Errorf("read tts audio: %w", err)
}
pcm, sampleRate := laudio.ParseWAV(audioBytes)
if sampleRate == 0 {
sampleRate = session.OutputSampleRate
}
return sendChunk(pcm, sampleRate)
}

View File

@@ -0,0 +1,68 @@
package openai
import (
"context"
"os"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
laudio "github.com/mudler/LocalAI/pkg/audio"
)
// emitSpeech synthesizes a piece of text and forwards the audio to the client,
// streaming a delta per TTS chunk when the pipeline opts in, or sending the
// whole utterance as one delta otherwise.
var _ = Describe("emitSpeech", func() {
ttsOn := true
streamingSession := func(m Model) *Session {
return &Session{
OutputSampleRate: 24000,
ModelInterface: m,
ModelConfig: &config.ModelConfig{
Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &ttsOn}},
},
}
}
It("streams one output_audio.delta per TTS chunk when streaming is enabled", func() {
m := &fakeModel{
ttsStreamChunks: [][]byte{{1, 2}, {3, 4}, {5, 6}},
ttsStreamRate: 24000,
}
t := &fakeTransport{}
err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.")
Expect(err).ToNot(HaveOccurred())
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3))
})
It("sends a single output_audio.delta in unary mode", func() {
// A minimal real WAV file for the unary TTS path to read + parse.
f, err := os.CreateTemp("", "emit-*.wav")
Expect(err).ToNot(HaveOccurred())
defer os.Remove(f.Name())
pcm := make([]byte, 320) // 160 samples of silence
hdr := laudio.NewWAVHeader(uint32(len(pcm)))
Expect(hdr.Write(f)).To(Succeed())
_, err = f.Write(pcm)
Expect(err).ToNot(HaveOccurred())
Expect(f.Close()).To(Succeed())
session := &Session{
OutputSampleRate: 24000,
ModelInterface: &fakeModel{ttsFile: f.Name()},
ModelConfig: &config.ModelConfig{}, // streaming off
}
t := &fakeTransport{}
err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.")
Expect(err).ToNot(HaveOccurred())
Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1))
})
})