mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-06 07:46:15 -04:00
feat(realtime): streaming TTS/transcription methods on Model interface
Add TTSStream and TranscribeStream to the realtime Model interface and implement them on wrappedModel (delegating to backend.ModelTTSStream / ModelTranscriptionStream) and transcriptOnlyModel. ttsStream adapts the backend's WAV-framed stream (44-byte header carrying the sample rate, then PCM) into raw PCM + sample rate for the realtime transports. Handler wiring that consumes these (flag-gated) follows. Assisted-by: Claude:claude-opus-4-8 go vet Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -235,6 +235,12 @@ type Model interface {
|
||||
Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error)
|
||||
Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error)
|
||||
TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error)
|
||||
// TTSStream synthesizes speech incrementally, invoking onAudio with raw PCM
|
||||
// chunks (and the backend sample rate) as they are produced.
|
||||
TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error
|
||||
// TranscribeStream transcribes audio incrementally, invoking onDelta for each
|
||||
// transcript text fragment and returning the final aggregated result.
|
||||
TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error)
|
||||
PredictConfig() *config.ModelConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -87,6 +88,14 @@ func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language str
|
||||
return "", nil, fmt.Errorf("TTS not supported in transcript-only mode")
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
return fmt.Errorf("TTS not supported in transcript-only mode")
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
|
||||
}
|
||||
|
||||
func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig {
|
||||
return nil
|
||||
}
|
||||
@@ -321,10 +330,75 @@ func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (s
|
||||
return backend.ModelTTS(ctx, text, voice, language, "", nil, m.modelLoader, m.appConfig, *m.TTSConfig)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TTSStream(ctx context.Context, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
return ttsStream(ctx, m.modelLoader, m.appConfig, *m.TTSConfig, text, voice, language, onAudio)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) TranscribeStream(ctx context.Context, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
return transcribeStream(ctx, m.modelLoader, *m.TranscriptionConfig, m.appConfig, audio, language, translate, diarize, prompt, onDelta)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictConfig() *config.ModelConfig {
|
||||
return m.LLMConfig
|
||||
}
|
||||
|
||||
// wavStreamHeaderBytes is the size of the WAV header that backend.ModelTTSStream
|
||||
// emits as its first audio callback; the sample rate lives at byte offset 24.
|
||||
const wavStreamHeaderBytes = 44
|
||||
|
||||
// ttsStream adapts backend.ModelTTSStream (which emits a WAV stream: a 44-byte
|
||||
// header carrying the sample rate, then raw PCM) to the realtime onAudio
|
||||
// callback, which wants raw PCM plus the sample rate. The header is buffered
|
||||
// until complete, the sample rate is read from it, and subsequent bytes are
|
||||
// forwarded as PCM.
|
||||
func ttsStream(ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, ttsConfig config.ModelConfig, text, voice, language string, onAudio func(pcm []byte, sampleRate int) error) error {
|
||||
var header []byte
|
||||
headerDone := false
|
||||
sampleRate := 0
|
||||
return backend.ModelTTSStream(ctx, text, voice, language, "", nil, ml, appConfig, ttsConfig, func(b []byte) error {
|
||||
if headerDone {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
return onAudio(b, sampleRate)
|
||||
}
|
||||
header = append(header, b...)
|
||||
if len(header) < wavStreamHeaderBytes {
|
||||
return nil
|
||||
}
|
||||
sampleRate = int(binary.LittleEndian.Uint32(header[24:28]))
|
||||
headerDone = true
|
||||
if len(header) > wavStreamHeaderBytes {
|
||||
return onAudio(header[wavStreamHeaderBytes:], sampleRate)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// transcribeStream adapts backend.ModelTranscriptionStream to the realtime
|
||||
// onDelta callback, returning the final aggregated transcription result.
|
||||
func transcribeStream(ctx context.Context, ml *model.ModelLoader, transcriptionConfig config.ModelConfig, appConfig *config.ApplicationConfig, audio, language string, translate, diarize bool, prompt string, onDelta func(text string)) (*schema.TranscriptionResult, error) {
|
||||
var final *schema.TranscriptionResult
|
||||
err := backend.ModelTranscriptionStream(ctx, backend.TranscriptionRequest{
|
||||
Audio: audio,
|
||||
Language: language,
|
||||
Translate: translate,
|
||||
Diarize: diarize,
|
||||
Prompt: prompt,
|
||||
}, ml, transcriptionConfig, appConfig, func(chunk backend.TranscriptionStreamChunk) {
|
||||
if chunk.Delta != "" {
|
||||
onDelta(chunk.Delta)
|
||||
}
|
||||
if chunk.Final != nil {
|
||||
final = chunk.Final
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return final, nil
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user