feat(tts): add support for streaming mode (#8291)

* feat(tts): add support for streaming mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Send first audio, make sure it's 16

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-01-30 11:58:01 +01:00
committed by GitHub
parent 2c44b06a67
commit 68dd9765a0
13 changed files with 369 additions and 0 deletions

View File

@@ -17,6 +17,7 @@ service Backend {
  rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
  rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
  rpc TTS(TTSRequest) returns (Result) {}
  rpc TTSStream(TTSRequest) returns (stream Reply) {}
  rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
  rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
  rpc Status(HealthMessage) returns (StatusResponse) {}
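The new RPC is a plain server-streaming call: by this PR's convention, the first `Reply` carries sample-rate metadata in its `message` field and every subsequent `Reply` carries raw PCM in `audio`. A minimal Go consumer sketch against the generated stubs (the address and text are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"io"
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// The address is illustrative; point it at a running TTS backend.
	conn, err := grpc.Dial("localhost:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	stream, err := pb.NewBackendClient(conn).TTSStream(context.Background(),
		&pb.TTSRequest{Text: "hello from the stream"})
	if err != nil {
		log.Fatal(err)
	}
	for {
		reply, err := stream.Recv()
		if err == io.EOF {
			break // server closed the stream: synthesis finished
		}
		if err != nil {
			log.Fatal(err)
		}
		if len(reply.Message) > 0 {
			// First reply: metadata such as {"sample_rate": 16000}
			fmt.Printf("info: %s\n", reply.Message)
			continue
		}
		fmt.Printf("pcm chunk: %d bytes\n", len(reply.Audio))
	}
}
```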

View File

@@ -207,6 +207,90 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        return backend_pb2.Result(success=True)

    def TTSStream(self, request, context):
        try:
            # Get generation parameters from options with defaults
            cfg_value = self.options.get("cfg_value", 2.0)
            inference_timesteps = self.options.get("inference_timesteps", 10)
            normalize = self.options.get("normalize", False)
            denoise = self.options.get("denoise", False)
            retry_badcase = self.options.get("retry_badcase", True)
            retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
            retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)

            # Handle voice cloning via prompt_wav_path and prompt_text
            prompt_wav_path = None
            prompt_text = None

            # Priority: request.voice > AudioPath > options
            if hasattr(request, 'voice') and request.voice:
                # If voice is provided, try to use it as a path
                if os.path.exists(request.voice):
                    prompt_wav_path = request.voice
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    potential_path = os.path.join(model_file_base, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    potential_path = os.path.join(request.ModelPath, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path

            if hasattr(request, 'AudioPath') and request.AudioPath:
                if os.path.isabs(request.AudioPath):
                    prompt_wav_path = request.AudioPath
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    prompt_wav_path = os.path.join(model_file_base, request.AudioPath)
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
                else:
                    prompt_wav_path = request.AudioPath

            # Get prompt_text from options if available
            if "prompt_text" in self.options:
                prompt_text = self.options["prompt_text"]

            # Prepare text
            text = request.text.strip()

            # Get sample rate from model (needed for WAV header)
            sample_rate = self.model.tts_model.sample_rate
            print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)

            # Send the sample rate as the first message, JSON-encoded in the
            # message field, e.g. {"sample_rate": 16000}, so the consumer can parse it
            import json
            sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
            yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))

            # Stream audio chunks
            for chunk in self.model.generate_streaming(
                text=text,
                prompt_wav_path=prompt_wav_path,
                prompt_text=prompt_text,
                cfg_value=cfg_value,
                inference_timesteps=inference_timesteps,
                normalize=normalize,
                denoise=denoise,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            ):
                # Convert the numpy float array to int16 PCM bytes,
                # clipping to keep values inside the int16 range
                chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
                chunk_bytes = chunk_int16.tobytes()
                yield backend_pb2.Reply(audio=chunk_bytes)
        except Exception as err:
            print(f"Error in TTSStream: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Yield an error reply
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
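The clip-and-scale step above fixes the wire format: mono, 16-bit, little-endian PCM at the model's sample rate. For reference, a self-contained Go sketch of the inverse mapping a consumer could apply (`pcm16ToFloat32` is a hypothetical helper, not part of LocalAI):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// pcm16ToFloat32 converts little-endian int16 PCM bytes back to float32
// samples in [-1, 1], undoing the backend's
// np.clip(chunk * 32767, ...).astype(np.int16) step.
func pcm16ToFloat32(b []byte) []float32 {
	out := make([]float32, len(b)/2)
	for i := range out {
		s := int16(binary.LittleEndian.Uint16(b[2*i:]))
		out[i] = float32(s) / 32767.0
	}
	return out
}

func main() {
	// Two samples: full-scale positive and half-scale negative.
	raw := []byte{0xFF, 0x7F, 0x01, 0xC0}
	fmt.Println(pcm16ToFloat32(raw)) // ≈ [1, -0.5]
}
```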

View File

@@ -49,3 +49,39 @@ class TestBackendServicer(unittest.TestCase):
self.fail("LoadModel service failed")
finally:
self.tearDown()
def test_tts_stream(self):
"""
This method tests if TTS streaming works correctly
"""
try:
self.setUp()
print("Starting test_tts_stream")
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
print(response)
self.assertTrue(response.success)
self.assertEqual(response.message, "Model loaded successfully")
# Test TTSStream
tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav")
chunks_received = 0
total_audio_bytes = 0
for reply in stub.TTSStream(tts_request):
# Verify that we receive audio chunks
if reply.audio:
chunks_received += 1
total_audio_bytes += len(reply.audio)
self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty")
# Verify that we received multiple chunks
self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk")
self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0")
print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes")
except Exception as err:
print(err)
self.fail("TTSStream service failed")
finally:
self.tearDown()

View File

@@ -1,12 +1,16 @@
package backend

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/mudler/LocalAI/core/config"
	laudio "github.com/mudler/LocalAI/pkg/audio"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
@@ -74,3 +78,101 @@ func ModelTTS(
	return filePath, res, err
}

func ModelTTSStream(
	text,
	voice,
	language string,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	modelConfig config.ModelConfig,
	audioCallback func([]byte) error,
) error {
	opts := ModelOptions(modelConfig, appConfig)
	ttsModel, err := loader.Load(opts...)
	if err != nil {
		return err
	}

	if ttsModel == nil {
		return fmt.Errorf("could not load tts model %q", modelConfig.Model)
	}

	// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
	// This should be addressed in a follow up PR soon.
	// Copying it over nearly verbatim, as TTS backends are not functional without this.
	modelPath := ""
	// Checking first that it exists and is not outside ModelPath
	// TODO: we should actually first check if the modelFile is looking like
	// a FS path
	mp := filepath.Join(loader.ModelPath, modelConfig.Model)
	if _, err := os.Stat(mp); err == nil {
		if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
			return err
		}
		modelPath = mp
	} else {
		modelPath = modelConfig.Model // skip this step if it fails?????
	}

	var sampleRate uint32 = 16000 // default
	headerSent := false
	var callbackErr error

	err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
		Text:     text,
		Model:    modelPath,
		Voice:    voice,
		Language: &language,
	}, func(reply *proto.Reply) {
		// First message contains sample rate info
		if !headerSent && len(reply.Message) > 0 {
			var info map[string]interface{}
			if json.Unmarshal(reply.Message, &info) == nil {
				if sr, ok := info["sample_rate"].(float64); ok {
					sampleRate = uint32(sr)
				}
			}
			// Send WAV header with placeholder size (0xFFFFFFFF for streaming)
			header := laudio.WAVHeader{
				ChunkID:       [4]byte{'R', 'I', 'F', 'F'},
				ChunkSize:     0xFFFFFFFF, // Unknown size for streaming
				Format:        [4]byte{'W', 'A', 'V', 'E'},
				Subchunk1ID:   [4]byte{'f', 'm', 't', ' '},
				Subchunk1Size: 16,
				AudioFormat:   1, // PCM
				NumChannels:   1, // Mono
				SampleRate:    sampleRate,
				ByteRate:      sampleRate * 2, // SampleRate * BlockAlign
				BlockAlign:    2,              // 16-bit = 2 bytes
				BitsPerSample: 16,
				Subchunk2ID:   [4]byte{'d', 'a', 't', 'a'},
				Subchunk2Size: 0xFFFFFFFF, // Unknown size for streaming
			}
			var buf bytes.Buffer
			if writeErr := binary.Write(&buf, binary.LittleEndian, header); writeErr != nil {
				callbackErr = writeErr
				return
			}
			if writeErr := audioCallback(buf.Bytes()); writeErr != nil {
				callbackErr = writeErr
				return
			}
			headerSent = true
		}

		// Stream audio chunks
		if len(reply.Audio) > 0 {
			if writeErr := audioCallback(reply.Audio); writeErr != nil {
				callbackErr = writeErr
			}
		}
	})
	if callbackErr != nil {
		return callbackErr
	}
	return err
}
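Since the total length is unknown while streaming, both size fields are pinned to `0xFFFFFFFF`. A self-contained sketch reproducing the 44-byte layout; the struct here is a local stand-in for `laudio.WAVHeader`, not the actual package type:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// wavHeader mirrors the field layout used above; it is a local
// stand-in for laudio.WAVHeader, not the actual package type.
type wavHeader struct {
	ChunkID       [4]byte
	ChunkSize     uint32
	Format        [4]byte
	Subchunk1ID   [4]byte
	Subchunk1Size uint32
	AudioFormat   uint16
	NumChannels   uint16
	SampleRate    uint32
	ByteRate      uint32
	BlockAlign    uint16
	BitsPerSample uint16
	Subchunk2ID   [4]byte
	Subchunk2Size uint32
}

func main() {
	h := wavHeader{
		ChunkID:       [4]byte{'R', 'I', 'F', 'F'},
		ChunkSize:     0xFFFFFFFF, // unknown total size while streaming
		Format:        [4]byte{'W', 'A', 'V', 'E'},
		Subchunk1ID:   [4]byte{'f', 'm', 't', ' '},
		Subchunk1Size: 16,
		AudioFormat:   1, // PCM
		NumChannels:   1, // mono
		SampleRate:    16000,
		ByteRate:      16000 * 2, // SampleRate * BlockAlign
		BlockAlign:    2,         // 16-bit mono = 2 bytes per frame
		BitsPerSample: 16,
		Subchunk2ID:   [4]byte{'d', 'a', 't', 'a'},
		Subchunk2Size: 0xFFFFFFFF, // players read audio until EOF
	}
	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, h); err != nil {
		panic(err)
	}
	fmt.Println(buf.Len()) // 44: the canonical PCM WAV header size
}
```

Players such as `aplay` generally treat the placeholder sizes as "unknown length" and read until EOF, which is what makes chunked playback work.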

View File

@@ -50,6 +50,31 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
		cfg.Voice = input.Voice
	}

	// Handle streaming TTS
	if input.Stream {
		// Set headers for streaming audio
		c.Response().Header().Set("Content-Type", "audio/wav")
		c.Response().Header().Set("Transfer-Encoding", "chunked")
		c.Response().Header().Set("Cache-Control", "no-cache")
		c.Response().Header().Set("Connection", "keep-alive")

		// Stream audio chunks as they're generated
		err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
			_, writeErr := c.Response().Write(audioChunk)
			if writeErr != nil {
				return writeErr
			}
			c.Response().Flush()
			return nil
		})
		if err != nil {
			return err
		}
		return nil
	}

	// Non-streaming TTS (existing behavior)
	filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
	if err != nil {
		return err
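Because every chunk is written and flushed immediately, any HTTP client that reads the response body incrementally receives audio as it is generated. A sketch in Go (URL and payload are illustrative, matching the docs below):

```go
package main

import (
	"bytes"
	"io"
	"log"
	"net/http"
	"os"
)

func main() {
	// "stream": true selects the chunked path added above.
	body := bytes.NewBufferString(`{"model":"voxcpm","input":"hi","stream":true}`)
	resp, err := http.Post("http://localhost:8080/tts", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	out, err := os.Create("out.wav")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	// io.Copy consumes the chunked body incrementally, so bytes hit
	// disk as the server flushes them.
	if _, err := io.Copy(out, resp.Body); err != nil {
		log.Fatal(err)
	}
}
```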

View File

@@ -53,6 +53,7 @@ type TTSRequest struct {
	Backend  string `json:"backend" yaml:"backend"`
	Language string `json:"language,omitempty" yaml:"language,omitempty"`               // (optional) language to use with TTS model
	Format   string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
	Stream   bool   `json:"stream,omitempty" yaml:"stream,omitempty"`                   // (optional) enable streaming TTS
}
// @Description VAD request body

View File

@@ -29,6 +29,41 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
Returns an `audio/wav` file.

## Streaming TTS

LocalAI supports streaming TTS generation, allowing audio to be played as it's generated. This is useful for real-time applications and reduces latency.

To enable streaming, add `"stream": true` to your request:

```bash
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
  "input": "Hello world, this is a streaming test",
  "model": "voxcpm",
  "stream": true
}' | aplay
```

The audio will be streamed chunk-by-chunk as it's generated, allowing playback to start before generation completes. This is particularly useful for long texts or when you want to minimize perceived latency.

You can also pipe the streamed audio directly to audio players like `aplay` (Linux) or save it to a file:

```bash
# Stream to aplay (Linux)
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
  "input": "This is a longer text that will be streamed as it is generated",
  "model": "voxcpm",
  "stream": true
}' | aplay

# Stream to a file
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
  "input": "Streaming audio to file",
  "model": "voxcpm",
  "stream": true
}' > output.wav
```

Note: Streaming TTS is currently supported by the `voxcpm` backend. Other backends will fall back to non-streaming mode if streaming is not supported.

## Backends
## Backends

View File

@@ -41,6 +41,7 @@ type Backend interface {
	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
	TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
	Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)

View File

@@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
	return fmt.Errorf("unimplemented")
}

View File

@@ -270,6 +270,56 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
	return client.TTS(ctx, in, opts...)
}

func (c *Client) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	if !c.parallel {
		c.opMutex.Lock()
		defer c.opMutex.Unlock()
	}
	c.setBusy(true)
	defer c.setBusy(false)
	c.wdMark()
	defer c.wdUnMark()
	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(
			grpc.MaxCallRecvMsgSize(50*1024*1024), // 50MB
			grpc.MaxCallSendMsgSize(50*1024*1024), // 50MB
		))
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewBackendClient(conn)

	stream, err := client.TTSStream(ctx, in, opts...)
	if err != nil {
		return err
	}

	for {
		// Check if context is cancelled before receiving
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		reply, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Check if error is due to context cancellation
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
		f(reply)
	}
	return nil
}

func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	if !c.parallel {
		c.opMutex.Lock()
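The pre-`Recv` select means a caller-side cancellation or deadline stops the loop between chunks rather than hanging on the stream. A self-contained sketch of that pattern (`recvFunc` and `drain` are illustrative stand-ins, not LocalAI API):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"time"
)

// recvFunc stands in for stream.Recv in the loop above.
type recvFunc func() ([]byte, error)

// drain mirrors the client loop: check ctx before each receive, map
// cancellation onto ctx.Err(), and stop cleanly on io.EOF.
func drain(ctx context.Context, recv recvFunc, f func([]byte)) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		b, err := recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return err
		}
		f(b)
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	// A slow producer that never reaches EOF on its own.
	slow := func() ([]byte, error) { time.Sleep(5 * time.Millisecond); return []byte{1}, nil }
	err := drain(ctx, slow, func(b []byte) { fmt.Println("chunk", b) })
	fmt.Println(err) // context.DeadlineExceeded, typically after a couple of chunks
}
```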

View File

@@ -55,6 +55,14 @@ func (e *embedBackend) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.
	return e.s.TTS(ctx, in)
}

func (e *embedBackend) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error {
	bs := &embedBackendServerStream{
		ctx: ctx,
		fn:  f,
	}
	return e.s.TTSStream(in, bs)
}

func (e *embedBackend) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error) {
	return e.s.SoundGeneration(ctx, in)
}

View File

@@ -18,6 +18,7 @@ type AIModel interface {
	Detect(*pb.DetectOptions) (pb.DetectResponse, error)
	AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error)
	TTS(*pb.TTSRequest) error
	TTSStream(*pb.TTSRequest, chan []byte) error
	SoundGeneration(*pb.SoundGenerationRequest) error
	TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error)
	Status() (pb.StatusResponse, error)

View File

@@ -99,6 +99,27 @@ func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error)
	return &pb.Result{Message: "TTS audio generated", Success: true}, nil
}

func (s *server) TTSStream(in *pb.TTSRequest, stream pb.Backend_TTSStreamServer) error {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}

	audioChan := make(chan []byte)
	done := make(chan bool)

	go func() {
		for audioChunk := range audioChan {
			stream.Send(&pb.Reply{Audio: audioChunk})
		}
		done <- true
	}()

	err := s.llm.TTSStream(in, audioChan)
	<-done
	return err
}

func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
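This bridge only terminates because the `AIModel` implementation is expected to close `audioChan` when synthesis finishes; `done` then guarantees the last chunk is forwarded before the RPC returns. A self-contained sketch of that contract (`produce` stands in for the model's TTSStream):

```go
package main

import "fmt"

// produce stands in for an AIModel's TTSStream implementation: it must
// close the channel when done, or the server's <-done would block forever.
func produce(audioChan chan<- []byte) error {
	defer close(audioChan)
	for i := 0; i < 3; i++ {
		audioChan <- []byte{byte(i)} // a synthesized chunk
	}
	return nil
}

func main() {
	audioChan := make(chan []byte)
	done := make(chan bool)
	go func() {
		for chunk := range audioChan {
			fmt.Printf("send %v\n", chunk) // stands in for stream.Send
		}
		done <- true
	}()
	err := produce(audioChan)
	<-done // the last chunk is forwarded before we return
	fmt.Println("err:", err)
}
```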