feat(tts): add support for streaming mode (#8291)

* feat(tts): add support for streaming mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Send first audio, make sure it's 16-bit

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Author: Ettore Di Giacinto
Date: 2026-01-30 11:58:01 +01:00 (committed by GitHub)
Parent: 2c44b06a67
Commit: 68dd9765a0
13 changed files with 369 additions and 0 deletions


@@ -207,6 +207,90 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        return backend_pb2.Result(success=True)

    def TTSStream(self, request, context):
        try:
            # Get generation parameters from options with defaults
            cfg_value = self.options.get("cfg_value", 2.0)
            inference_timesteps = self.options.get("inference_timesteps", 10)
            normalize = self.options.get("normalize", False)
            denoise = self.options.get("denoise", False)
            retry_badcase = self.options.get("retry_badcase", True)
            retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
            retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)

            # Handle voice cloning via prompt_wav_path and prompt_text
            prompt_wav_path = None
            prompt_text = None

            # Priority: request.voice > AudioPath > options
            if hasattr(request, 'voice') and request.voice:
                # If voice is provided, try to use it as a path
                if os.path.exists(request.voice):
                    prompt_wav_path = request.voice
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    potential_path = os.path.join(model_file_base, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    potential_path = os.path.join(request.ModelPath, request.voice)
                    if os.path.exists(potential_path):
                        prompt_wav_path = potential_path

            if hasattr(request, 'AudioPath') and request.AudioPath:
                if os.path.isabs(request.AudioPath):
                    prompt_wav_path = request.AudioPath
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    prompt_wav_path = os.path.join(model_file_base, request.AudioPath)
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
                else:
                    prompt_wav_path = request.AudioPath

            # Get prompt_text from options if available
            if "prompt_text" in self.options:
                prompt_text = self.options["prompt_text"]

            # Prepare text
            text = request.text.strip()

            # Get sample rate from model (the client needs it for the WAV header)
            sample_rate = self.model.tts_model.sample_rate
            print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)

            # Send the sample rate as the first message (JSON in the message field).
            # Format: {"sample_rate": 16000}, so the client can parse it.
            import json
            sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
            yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))

            # Stream audio chunks
            for chunk in self.model.generate_streaming(
                text=text,
                prompt_wav_path=prompt_wav_path,
                prompt_text=prompt_text,
                cfg_value=cfg_value,
                inference_timesteps=inference_timesteps,
                normalize=normalize,
                denoise=denoise,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            ):
                # Convert the float numpy array to int16 PCM bytes,
                # clipping to keep values in the int16 range
                chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
                chunk_bytes = chunk_int16.tobytes()
                yield backend_pb2.Reply(audio=chunk_bytes)
        except Exception as err:
            print(f"Error in TTSStream: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Yield an error reply
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
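
For reference, the framing the new RPC uses is: the first Reply carries a JSON header such as {"sample_rate": 16000} in the message field (error strings also arrive there, prefixed with "Error:"), and every following Reply carries raw int16 PCM in the audio field. A minimal client sketch under those assumptions (the generated backend_pb2/backend_pb2_grpc stubs are taken as given; host, port, output path, and the mono-channel assumption are illustrative, not part of the diff):

import json
import wave

import grpc

import backend_pb2
import backend_pb2_grpc


def stream_tts_to_wav(text, out_path="out.wav", target="localhost:50051"):
    """Consume TTSStream and write the received PCM chunks to a WAV file."""
    sample_rate = None
    pcm = bytearray()
    with grpc.insecure_channel(target) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        for reply in stub.TTSStream(backend_pb2.TTSRequest(text=text)):
            if reply.message:
                msg = reply.message.decode("utf-8")
                if msg.startswith("Error:"):
                    raise RuntimeError(msg)
                # First message: JSON header carrying the sample rate
                sample_rate = json.loads(msg)["sample_rate"]
            elif reply.audio:
                pcm.extend(reply.audio)
    with wave.open(out_path, "wb") as wav:
        wav.setnchannels(1)                     # assumed mono output
        wav.setsampwidth(2)                     # int16 PCM, matching the server conversion
        wav.setframerate(sample_rate or 16000)  # fall back if the header was missed
        wav.writeframes(bytes(pcm))

Sending the rate out of band keeps the audio chunks header-free, so a client can just as well feed them straight to a playback device instead of buffering a whole WAV.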


@@ -49,3 +49,39 @@ class TestBackendServicer(unittest.TestCase):
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts_stream(self):
        """
        This method tests if TTS streaming works correctly
        """
        try:
            self.setUp()
            print("Starting test_tts_stream")
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")

                # Test TTSStream
                tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav")
                chunks_received = 0
                total_audio_bytes = 0
                for reply in stub.TTSStream(tts_request):
                    # Verify that we receive audio chunks
                    if reply.audio:
                        chunks_received += 1
                        total_audio_bytes += len(reply.audio)
                        self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty")

                # Verify that we received at least one non-empty chunk
                self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk")
                self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0")
                print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes")
        except Exception as err:
            print(err)
            self.fail("TTSStream service failed")
        finally:
            self.tearDown()
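
The test above only counts audio chunks; a small extension could also assert the sample-rate header that TTSStream yields first. A sketch of the extra lines, reusing stub and tts_request from the test body (the assertion wording is illustrative):

                # Sketch: validate the JSON header sent as the first message,
                # per the servicer code above: {"sample_rate": <int>}.
                sample_rate = None
                for reply in stub.TTSStream(tts_request):
                    if reply.message and sample_rate is None:
                        header = json.loads(reply.message.decode("utf-8"))
                        sample_rate = header.get("sample_rate")
                self.assertIsInstance(sample_rate, int, "First message should carry the sample rate")
                self.assertGreater(sample_rate, 0, "Sample rate should be positive")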