diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 0d638a909..76dc9fd3f 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -235,6 +235,12 @@ type Model interface { Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) + // TTSStream synthesizes speech incrementally, invoking onAudio with raw PCM + // chunks (and the backend sample rate) as they are produced. + TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error + // TranscribeStream transcribes audio incrementally, invoking onDelta for each + // transcript text fragment and returning the final aggregated result. + TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) PredictConfig() *config.ModelConfig } diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index b9a3adda9..b18439340 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -3,6 +3,7 @@ package openai import ( "context" "crypto/rand" + "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -87,6 +88,14 @@ func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language str return "", nil, fmt.Errorf("TTS not supported in transcript-only mode") } +func (m *transcriptOnlyModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error { + return fmt.Errorf("TTS not supported in transcript-only mode") +} + +func (m *transcriptOnlyModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) { + return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta) +} + func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig { return nil } @@ -321,10 +330,75 @@ func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (s return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig) } +func (m *wrappedModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error { + return ttsStream(ctx, m.modelLoader, m.appConfig, *m.TTSConfig, text, voice, language, onAudio) +} + +func (m *wrappedModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) { + return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta) +} + func (m *wrappedModel) PredictConfig() *config.ModelConfig { return m.LLMConfig } +// wavStreamHeaderBytes is the size of the WAV header that backend.ModelTTSStream +// emits as its first audio callback; the sample rate lives at byte offset 24. +const wavStreamHeaderBytes = 44 + +// ttsStream adapts backend.ModelTTSStream (which emits a WAV stream: a 44-byte +// header carrying the sample rate, then raw PCM) to the realtime onAudio +// callback, which wants raw PCM plus the sample rate. The header is buffered +// until complete, the sample rate is read from it, and subsequent bytes are +// forwarded as PCM. +func ttsStream(ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, ttsConfig config.ModelConfig, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error { + var header []byte + headerDone := false + sampleRate := 0 + return backend.ModelTTSStream(ctx, text, voice, language, "", nil, ml, appConfig, ttsConfig, func(b []byte) error { + if headerDone { + if len(b) == 0 { + return nil + } + return onAudio(b, sampleRate) + } + header = append(header, b...) + if len(header) < wavStreamHeaderBytes { + return nil + } + sampleRate = int(binary.LittleEndian.Uint32(header[24:28])) + headerDone = true + if len(header) > wavStreamHeaderBytes { + return onAudio(header[wavStreamHeaderBytes:], sampleRate) + } + return nil + }) +} + +// transcribeStream adapts backend.ModelTranscriptionStream to the realtime +// onDelta callback, returning the final aggregated transcription result. +func transcribeStream(ctx context.Context, ml *model.ModelLoader, transcriptionConfig config.ModelConfig, appConfig *config.ApplicationConfig, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) { + var final *schema.TranscriptionResult + err := backend.ModelTranscriptionStream(ctx, backend.TranscriptionRequest{ + Audio: audio, + Language: language, + Translate: translate, + Diarize: diarize, + Prompt: prompt, + }, ml, transcriptionConfig, appConfig, func(chunk backend.TranscriptionStreamChunk) { + if chunk.Delta != "" { + onDelta(chunk.Delta) + } + if chunk.Final != nil { + final = chunk.Final + } + }) + if err != nil { + return nil, err + } + return final, nil +} + func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) { cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) if err != nil {