diff --git a/core/http/endpoints/openai/realtime_doubles_test.go b/core/http/endpoints/openai/realtime_doubles_test.go new file mode 100644 index 000000000..2a54f3dbe --- /dev/null +++ b/core/http/endpoints/openai/realtime_doubles_test.go @@ -0,0 +1,101 @@ +package openai + +import ( + "context" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +// fakeTransport records the server events and audio sent to a realtime client +// so streaming behaviour can be asserted without a real WebSocket/WebRTC peer. +// It is not a *WebRTCTransport, so handler code takes the WebSocket path. +type fakeTransport struct { + events []types.ServerEvent + audio []fakeAudioChunk +} + +type fakeAudioChunk struct { + pcm []byte + sampleRate int +} + +func (f *fakeTransport) SendEvent(e types.ServerEvent) error { + f.events = append(f.events, e) + return nil +} + +func (f *fakeTransport) ReadEvent() ([]byte, error) { return nil, nil } + +func (f *fakeTransport) SendAudio(_ context.Context, pcm []byte, sampleRate int) error { + f.audio = append(f.audio, fakeAudioChunk{pcm: pcm, sampleRate: sampleRate}) + return nil +} + +func (f *fakeTransport) Close() error { return nil } + +// countEvents returns how many recorded events have the given type. +func (f *fakeTransport) countEvents(et types.ServerEventType) int { + n := 0 + for _, e := range f.events { + if e.ServerEventType() == et { + n++ + } + } + return n +} + +// fakeModel is a configurable Model double. TTSStream replays ttsStreamChunks +// and TranscribeStream replays transcribeDeltas, so the handler's streaming +// paths can be driven deterministically. +type fakeModel struct { + cfg *config.ModelConfig + + ttsFile string + ttsStreamChunks [][]byte + ttsStreamRate int + ttsStreamErr error + + transcribeDeltas []string + transcribeFinal *schema.TranscriptionResult +} + +func (m *fakeModel) VAD(context.Context, *schema.VADRequest) (*schema.VADResponse, error) { + return nil, nil +} + +func (m *fakeModel) Transcribe(context.Context, string, string, bool, bool, string) (*schema.TranscriptionResult, error) { + return m.transcribeFinal, nil +} + +func (m *fakeModel) Predict(context.Context, schema.Messages, []string, []string, []string, func(string, backend.TokenUsage) bool, []types.ToolUnion, *types.ToolChoiceUnion, *int, *int, map[string]float64) (func() (backend.LLMResponse, error), error) { + return nil, nil +} + +func (m *fakeModel) TTS(context.Context, string, string, string) (string, *proto.Result, error) { + return m.ttsFile, &proto.Result{Success: true}, nil +} + +func (m *fakeModel) TTSStream(_ context.Context, _, _, _ string, onAudio func(pcm []byte, sampleRate int) error) error { + if m.ttsStreamErr != nil { + return m.ttsStreamErr + } + for _, c := range m.ttsStreamChunks { + if err := onAudio(c, m.ttsStreamRate); err != nil { + return err + } + } + return nil +} + +func (m *fakeModel) TranscribeStream(_ context.Context, _, _ string, _, _ bool, _ string, onDelta func(text string)) (*schema.TranscriptionResult, error) { + for _, d := range m.transcribeDeltas { + onDelta(d) + } + return m.transcribeFinal, nil +} + +func (m *fakeModel) PredictConfig() *config.ModelConfig { return m.cfg } diff --git a/core/http/endpoints/openai/realtime_speech.go b/core/http/endpoints/openai/realtime_speech.go new file mode 100644 index 000000000..b8429f7ca --- /dev/null +++ b/core/http/endpoints/openai/realtime_speech.go @@ -0,0 +1,86 @@ +package openai + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + laudio "github.com/mudler/LocalAI/pkg/audio" + "github.com/mudler/LocalAI/pkg/sound" +) + +// emitSpeech synthesizes text and sends the audio to the client. When the +// pipeline opts into TTS streaming it forwards each PCM chunk as its own +// response.output_audio.delta as soon as the backend produces it; otherwise it +// synthesizes the whole utterance and sends it as a single delta. +// +// It deliberately does NOT emit transcript or audio-done events: the caller owns +// those so a streamed reply can be split into several spoken segments that share +// one response/item. +func emitSpeech(ctx context.Context, t Transport, session *Session, responseID, itemID, text string) error { + if text == "" { + return nil + } + + _, isWebRTC := t.(*WebRTCTransport) + + // sendChunk hands one PCM buffer to the transport: WebRTC consumes the raw + // PCM directly (it resamples internally); WebSocket gets base64 PCM at the + // session output rate via a JSON delta event. + sendChunk := func(pcm []byte, sampleRate int) error { + if len(pcm) == 0 { + return nil + } + if err := t.SendAudio(ctx, pcm, sampleRate); err != nil { + return err + } + if isWebRTC { + return nil + } + wsPCM := pcm + if sampleRate != 0 && sampleRate != session.OutputSampleRate { + samples := sound.BytesToInt16sLE(pcm) + resampled := sound.ResampleInt16(samples, sampleRate, session.OutputSampleRate) + wsPCM = sound.Int16toBytesLE(resampled) + } + return t.SendEvent(types.ResponseOutputAudioDeltaEvent{ + ServerEventBase: types.ServerEventBase{}, + ResponseID: responseID, + ItemID: itemID, + OutputIndex: 0, + ContentIndex: 0, + Delta: base64.StdEncoding.EncodeToString(wsPCM), + }) + } + + language := "" + if session.InputAudioTranscription != nil { + language = session.InputAudioTranscription.Language + } + + if session.ModelConfig != nil && session.ModelConfig.Pipeline.StreamTTS() { + return session.ModelInterface.TTSStream(ctx, text, session.Voice, language, sendChunk) + } + + // Unary fallback: synthesize the whole utterance to a file, then emit once. + audioFilePath, res, err := session.ModelInterface.TTS(ctx, text, session.Voice, language) + if err != nil { + return err + } + if res != nil && !res.Success { + return fmt.Errorf("tts generation failed: %s", res.Message) + } + defer func() { _ = os.Remove(audioFilePath) }() + + audioBytes, err := os.ReadFile(audioFilePath) + if err != nil { + return fmt.Errorf("read tts audio: %w", err) + } + pcm, sampleRate := laudio.ParseWAV(audioBytes) + if sampleRate == 0 { + sampleRate = session.OutputSampleRate + } + return sendChunk(pcm, sampleRate) +} diff --git a/core/http/endpoints/openai/realtime_speech_test.go b/core/http/endpoints/openai/realtime_speech_test.go new file mode 100644 index 000000000..1e7da36c7 --- /dev/null +++ b/core/http/endpoints/openai/realtime_speech_test.go @@ -0,0 +1,68 @@ +package openai + +import ( + "context" + "os" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + laudio "github.com/mudler/LocalAI/pkg/audio" +) + +// emitSpeech synthesizes a piece of text and forwards the audio to the client, +// streaming a delta per TTS chunk when the pipeline opts in, or sending the +// whole utterance as one delta otherwise. +var _ = Describe("emitSpeech", func() { + ttsOn := true + + streamingSession := func(m Model) *Session { + return &Session{ + OutputSampleRate: 24000, + ModelInterface: m, + ModelConfig: &config.ModelConfig{ + Pipeline: config.Pipeline{Streaming: config.PipelineStreaming{TTS: &ttsOn}}, + }, + } + } + + It("streams one output_audio.delta per TTS chunk when streaming is enabled", func() { + m := &fakeModel{ + ttsStreamChunks: [][]byte{{1, 2}, {3, 4}, {5, 6}}, + ttsStreamRate: 24000, + } + t := &fakeTransport{} + + err := emitSpeech(context.Background(), t, streamingSession(m), "resp1", "item1", "Hello there.") + + Expect(err).ToNot(HaveOccurred()) + Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(3)) + }) + + It("sends a single output_audio.delta in unary mode", func() { + // A minimal real WAV file for the unary TTS path to read + parse. + f, err := os.CreateTemp("", "emit-*.wav") + Expect(err).ToNot(HaveOccurred()) + defer os.Remove(f.Name()) + pcm := make([]byte, 320) // 160 samples of silence + hdr := laudio.NewWAVHeader(uint32(len(pcm))) + Expect(hdr.Write(f)).To(Succeed()) + _, err = f.Write(pcm) + Expect(err).ToNot(HaveOccurred()) + Expect(f.Close()).To(Succeed()) + + session := &Session{ + OutputSampleRate: 24000, + ModelInterface: &fakeModel{ttsFile: f.Name()}, + ModelConfig: &config.ModelConfig{}, // streaming off + } + t := &fakeTransport{} + + err = emitSpeech(context.Background(), t, session, "resp1", "item1", "Hello there.") + + Expect(err).ToNot(HaveOccurred()) + Expect(t.countEvents(types.ServerEventTypeResponseOutputAudioDelta)).To(Equal(1)) + }) +})