mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-30 01:02:37 -05:00
feat: add VoxCPM tts backend (#8109)
* feat: add VoxCPM tts backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Disable voxcpm on arm64 cpu Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
cba8ef4e38
commit
9b973b79f6
23
backend/python/voxcpm/Makefile
Normal file
23
backend/python/voxcpm/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
.PHONY: voxcpm
|
||||
voxcpm:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: voxcpm
|
||||
@echo "Running voxcpm..."
|
||||
bash run.sh
|
||||
@echo "voxcpm run."
|
||||
|
||||
.PHONY: test
|
||||
test: voxcpm
|
||||
@echo "Testing voxcpm..."
|
||||
bash test.sh
|
||||
@echo "voxcpm tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
245
backend/python/voxcpm/backend.py
Normal file
245
backend/python/voxcpm/backend.py
Normal file
@@ -0,0 +1,245 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for VoxCPM
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import traceback
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from voxcpm import VoxCPM
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import torch
|
||||
|
||||
import grpc
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def is_int(s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
# Get device
|
||||
if torch.cuda.is_available():
|
||||
print("CUDA is available", file=sys.stderr)
|
||||
device = "cuda"
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
if mps_available:
|
||||
device = "mps"
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||
|
||||
# Normalize potential 'mpx' typo to 'mps'
|
||||
if device == "mpx":
|
||||
print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
|
||||
device = "mps"
|
||||
|
||||
# Validate mps availability if requested
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
|
||||
device = "cpu"
|
||||
|
||||
self.device = device
|
||||
|
||||
options = request.Options
|
||||
|
||||
# empty dict
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We are storing all the options in a dict so we can use it later when
|
||||
# generating the audio
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1) # Split only on first colon
|
||||
# if value is a number, convert it to the appropriate type
|
||||
if is_float(value):
|
||||
value = float(value)
|
||||
elif is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
self.options[key] = value
|
||||
|
||||
# Get model path from request
|
||||
model_path = request.Model
|
||||
if not model_path:
|
||||
model_path = "openbmb/VoxCPM1.5"
|
||||
|
||||
try:
|
||||
print(f"Loading model from {model_path}", file=sys.stderr)
|
||||
self.model = VoxCPM.from_pretrained(model_path)
|
||||
print(f"Model loaded successfully on device: {self.device}", file=sys.stderr)
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
# Get generation parameters from options with defaults
|
||||
cfg_value = self.options.get("cfg_value", 2.0)
|
||||
inference_timesteps = self.options.get("inference_timesteps", 10)
|
||||
normalize = self.options.get("normalize", False)
|
||||
denoise = self.options.get("denoise", False)
|
||||
retry_badcase = self.options.get("retry_badcase", True)
|
||||
retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
|
||||
retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)
|
||||
use_streaming = self.options.get("streaming", False)
|
||||
|
||||
# Handle voice cloning via prompt_wav_path and prompt_text
|
||||
prompt_wav_path = None
|
||||
prompt_text = None
|
||||
|
||||
# Priority: request.voice > AudioPath > options
|
||||
if hasattr(request, 'voice') and request.voice:
|
||||
# If voice is provided, try to use it as a path
|
||||
if os.path.exists(request.voice):
|
||||
prompt_wav_path = request.voice
|
||||
elif hasattr(request, 'ModelFile') and request.ModelFile:
|
||||
model_file_base = os.path.dirname(request.ModelFile)
|
||||
potential_path = os.path.join(model_file_base, request.voice)
|
||||
if os.path.exists(potential_path):
|
||||
prompt_wav_path = potential_path
|
||||
elif hasattr(request, 'ModelPath') and request.ModelPath:
|
||||
potential_path = os.path.join(request.ModelPath, request.voice)
|
||||
if os.path.exists(potential_path):
|
||||
prompt_wav_path = potential_path
|
||||
|
||||
if hasattr(request, 'AudioPath') and request.AudioPath:
|
||||
if os.path.isabs(request.AudioPath):
|
||||
prompt_wav_path = request.AudioPath
|
||||
elif hasattr(request, 'ModelFile') and request.ModelFile:
|
||||
model_file_base = os.path.dirname(request.ModelFile)
|
||||
prompt_wav_path = os.path.join(model_file_base, request.AudioPath)
|
||||
elif hasattr(request, 'ModelPath') and request.ModelPath:
|
||||
prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
|
||||
else:
|
||||
prompt_wav_path = request.AudioPath
|
||||
|
||||
# Get prompt_text from options if available
|
||||
if "prompt_text" in self.options:
|
||||
prompt_text = self.options["prompt_text"]
|
||||
|
||||
# Prepare text
|
||||
text = request.text.strip()
|
||||
|
||||
print(f"Generating audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, streaming: {use_streaming}", file=sys.stderr)
|
||||
|
||||
# Generate audio
|
||||
if use_streaming:
|
||||
# Streaming generation
|
||||
chunks = []
|
||||
for chunk in self.model.generate_streaming(
|
||||
text=text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=cfg_value,
|
||||
inference_timesteps=inference_timesteps,
|
||||
normalize=normalize,
|
||||
denoise=denoise,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
):
|
||||
chunks.append(chunk)
|
||||
wav = np.concatenate(chunks)
|
||||
else:
|
||||
# Non-streaming generation
|
||||
wav = self.model.generate(
|
||||
text=text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=cfg_value,
|
||||
inference_timesteps=inference_timesteps,
|
||||
normalize=normalize,
|
||||
denoise=denoise,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
|
||||
# Get sample rate from model
|
||||
sample_rate = self.model.tts_model.sample_rate
|
||||
|
||||
# Save output
|
||||
sf.write(request.dst, wav, sample_rate)
|
||||
print(f"Saved output to {request.dst}", file=sys.stderr)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in TTS: {err}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||
|
||||
# Define the signal handler function
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Shutting down...")
|
||||
server.stop(0)
|
||||
sys.exit(0)
|
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
30
backend/python/voxcpm/install.sh
Executable file
30
backend/python/voxcpm/install.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
|
||||
# Apply patch to fix PyTorch compatibility issue in voxcpm
|
||||
# This fixes the "Dimension out of range" error in scaled_dot_product_attention
|
||||
# by changing .contiguous() to .unsqueeze(0) in the attention module
|
||||
# The patch is needed because voxcpm's initialization test generation fails with
|
||||
# certain PyTorch versions due to a bug in scaled_dot_product_attention
|
||||
# https://github.com/OpenBMB/VoxCPM/issues/71#issuecomment-3441789452
|
||||
VOXCPM_PATH=$(python -c "import voxcpm; import os; print(os.path.dirname(voxcpm.__file__))" 2>/dev/null || echo "")
|
||||
if [ -n "$VOXCPM_PATH" ] && [ -f "$VOXCPM_PATH/modules/minicpm4/model.py" ]; then
|
||||
echo "Applying patch to voxcpm at $VOXCPM_PATH/modules/minicpm4/model.py"
|
||||
# Replace .contiguous() with .unsqueeze(0) for the three lines in the attention forward_step method
|
||||
# This fixes the dimension error in scaled_dot_product_attention
|
||||
sed -i 's/query_states = query_states\.contiguous()/query_states = query_states.unsqueeze(0)/g' "$VOXCPM_PATH/modules/minicpm4/model.py"
|
||||
sed -i 's/key_cache = key_cache\.contiguous()/key_cache = key_cache.unsqueeze(0)/g' "$VOXCPM_PATH/modules/minicpm4/model.py"
|
||||
sed -i 's/value_cache = value_cache\.contiguous()/value_cache = value_cache.unsqueeze(0)/g' "$VOXCPM_PATH/modules/minicpm4/model.py"
|
||||
echo "Patch applied successfully"
|
||||
else
|
||||
echo "Warning: Could not find voxcpm installation to apply patch (path: ${VOXCPM_PATH:-not found})"
|
||||
fi
|
||||
11
backend/python/voxcpm/protogen.sh
Executable file
11
backend/python/voxcpm/protogen.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runProtogen
|
||||
6
backend/python/voxcpm/requirements-cpu.txt
Normal file
6
backend/python/voxcpm/requirements-cpu.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
torchcodec
|
||||
5
backend/python/voxcpm/requirements-cublas12.txt
Normal file
5
backend/python/voxcpm/requirements-cublas12.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
5
backend/python/voxcpm/requirements-cublas13.txt
Normal file
5
backend/python/voxcpm/requirements-cublas13.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
5
backend/python/voxcpm/requirements-hipblas.txt
Normal file
5
backend/python/voxcpm/requirements-hipblas.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
6
backend/python/voxcpm/requirements-intel.txt
Normal file
6
backend/python/voxcpm/requirements-intel.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
setuptools
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
5
backend/python/voxcpm/requirements-l4t12.txt
Normal file
5
backend/python/voxcpm/requirements-l4t12.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
5
backend/python/voxcpm/requirements-l4t13.txt
Normal file
5
backend/python/voxcpm/requirements-l4t13.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
4
backend/python/voxcpm/requirements-mps.txt
Normal file
4
backend/python/voxcpm/requirements-mps.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
torch
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
7
backend/python/voxcpm/requirements.txt
Normal file
7
backend/python/voxcpm/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
grpcio==1.76.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
numpy
|
||||
voxcpm
|
||||
9
backend/python/voxcpm/run.sh
Executable file
9
backend/python/voxcpm/run.sh
Executable file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
51
backend/python/voxcpm/test.py
Normal file
51
backend/python/voxcpm/test.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
This method sets up the gRPC service by starting the server
|
||||
"""
|
||||
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
|
||||
time.sleep(30)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""
|
||||
This method tears down the gRPC service by terminating the server
|
||||
"""
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_load_model(self):
|
||||
"""
|
||||
This method tests if the model is loaded successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
print("Starting test_load_model")
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
|
||||
print(response)
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest.", dst="test.wav")
|
||||
tts_response = stub.TTS(tts_request)
|
||||
self.assertIsNotNone(tts_response)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
11
backend/python/voxcpm/test.sh
Executable file
11
backend/python/voxcpm/test.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
Reference in New Issue
Block a user