Mirror of https://github.com/mudler/LocalAI.git, synced 2026-02-01 02:03:04 -05:00.
feat(tts): add support for streaming mode (#8291)
* feat(tts): add support for streaming mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Send first audio, make sure it's 16

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2c44b06a67
commit
68dd9765a0
@@ -207,6 +207,90 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def TTSStream(self, request, context):
    """Stream synthesized speech for ``request.text`` as raw 16-bit PCM.

    The first reply carries a JSON payload in the ``message`` field
    (e.g. ``{"sample_rate": 16000}``) so the client can construct a WAV
    header; each subsequent reply carries a chunk of little-endian int16
    PCM bytes in the ``audio`` field. On failure, a final reply with an
    ``Error: ...`` message is yielded instead of raising, so the stream
    always terminates cleanly.

    Args:
        request: TTSRequest with ``text`` and optional ``voice`` /
            ``AudioPath`` reference audio for voice cloning.
        context: gRPC servicer context (unused).

    Yields:
        backend_pb2.Reply messages as described above.
    """
    import json

    try:
        # Generation parameters, read from backend options with defaults.
        cfg_value = self.options.get("cfg_value", 2.0)
        inference_timesteps = self.options.get("inference_timesteps", 10)
        normalize = self.options.get("normalize", False)
        denoise = self.options.get("denoise", False)
        retry_badcase = self.options.get("retry_badcase", True)
        retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
        retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)

        def _locate_voice_file(name):
            # Resolve a (possibly relative) reference-audio name: as-is
            # first, then against the model file's directory, else against
            # the model path. Returns the first existing candidate or None.
            if os.path.exists(name):
                return name
            if hasattr(request, 'ModelFile') and request.ModelFile:
                candidate = os.path.join(os.path.dirname(request.ModelFile), name)
                if os.path.exists(candidate):
                    return candidate
            elif hasattr(request, 'ModelPath') and request.ModelPath:
                candidate = os.path.join(request.ModelPath, name)
                if os.path.exists(candidate):
                    return candidate
            return None

        # Reference audio for voice cloning. NOTE: request.AudioPath takes
        # effective precedence — when set, it unconditionally overwrites
        # whatever the voice lookup resolved.
        prompt_wav_path = None
        if hasattr(request, 'voice') and request.voice:
            prompt_wav_path = _locate_voice_file(request.voice)

        if hasattr(request, 'AudioPath') and request.AudioPath:
            if os.path.isabs(request.AudioPath):
                prompt_wav_path = request.AudioPath
            elif hasattr(request, 'ModelFile') and request.ModelFile:
                # Relative paths are resolved without an existence check,
                # unlike the voice lookup above.
                prompt_wav_path = os.path.join(os.path.dirname(request.ModelFile), request.AudioPath)
            elif hasattr(request, 'ModelPath') and request.ModelPath:
                prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
            else:
                prompt_wav_path = request.AudioPath

        # Transcript of the reference audio, if configured.
        prompt_text = self.options.get("prompt_text")

        text = request.text.strip()

        # The model's sample rate is needed by the client for the WAV header.
        sample_rate = self.model.tts_model.sample_rate

        print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)

        # First message announces the sample rate as JSON so the client can
        # parse it, e.g. {"sample_rate": 16000}.
        sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
        yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))

        # Stream audio chunks as they are produced by the model.
        for chunk in self.model.generate_streaming(
            text=text,
            prompt_wav_path=prompt_wav_path,
            prompt_text=prompt_text,
            cfg_value=cfg_value,
            inference_timesteps=inference_timesteps,
            normalize=normalize,
            denoise=denoise,
            retry_badcase=retry_badcase,
            retry_badcase_max_times=retry_badcase_max_times,
            retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
        ):
            # Scale float samples (presumably in [-1, 1] — see clip bounds)
            # to int16 PCM, clamping to avoid integer wraparound.
            chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
            yield backend_pb2.Reply(audio=chunk_int16.tobytes())

    except Exception as err:
        print(f"Error in TTSStream: {err}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        # Report the failure in-band as a final message reply.
        yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
|
||||
@@ -49,3 +49,39 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts_stream(self):
    """Exercise the TTSStream RPC end-to-end and validate the audio it yields."""
    try:
        self.setUp()
        print("Starting test_tts_stream")
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)

            # The model must be loaded before streaming can be requested.
            load_reply = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
            print(load_reply)
            self.assertTrue(load_reply.success)
            self.assertEqual(load_reply.message, "Model loaded successfully")

            # Drive the streaming endpoint and count what comes back.
            tts_request = backend_pb2.TTSRequest(
                text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.",
                dst="test_stream.wav",
            )
            chunks_received = 0
            total_audio_bytes = 0
            for reply in stub.TTSStream(tts_request):
                if not reply.audio:
                    # Non-audio replies (e.g. the sample-rate message) are skipped.
                    continue
                chunks_received += 1
                total_audio_bytes += len(reply.audio)
                self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty")

            # The stream must have produced actual audio data.
            self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk")
            self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0")
            print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes")
    except Exception as err:
        print(err)
        self.fail("TTSStream service failed")
    finally:
        self.tearDown()
|
||||
|
||||
Reference in New Issue
Block a user