Files
LocalAI/backend/python/qwen-tts/backend.py
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): correctly compute time boundaries, avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood with health checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

940 lines
36 KiB
Python

#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Qwen3-TTS
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import copy
import traceback
from pathlib import Path
import backend_pb2
import backend_pb2_grpc
import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel
import json
import hashlib
import pickle
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
    """Return True if *s* can be converted to ``float``, else False.

    Also returns False for non-string/non-numeric inputs such as ``None``
    (``float(None)`` raises TypeError, which the original handler missed).
    Note that strings like "nan" or "inf" are valid floats and return True.
    """
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        return False
def is_int(s):
    """Return True if *s* can be converted to ``int``, else False.

    Also returns False for non-string/non-numeric inputs such as ``None``
    (``int(None)`` raises TypeError, which the original handler missed).
    Strings with a decimal point (e.g. "3.5") are NOT ints.
    """
    try:
        int(s)
        return True
    except (TypeError, ValueError):
        return False
# Sleep interval for serve()'s keep-alive loop: the gRPC server runs in
# background threads, so the main thread just sleeps a day at a time.
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
BackendServicer is the class that implements the gRPC service
"""
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
def LoadModel(self, request, context):
# Get device
if torch.cuda.is_available():
print("CUDA is available", file=sys.stderr)
device = "cuda"
else:
print("CUDA is not available", file=sys.stderr)
device = "cpu"
mps_available = (
hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
)
if mps_available:
device = "mps"
if not torch.cuda.is_available() and request.CUDA:
return backend_pb2.Result(success=False, message="CUDA is not available")
# Normalize potential 'mpx' typo to 'mps'
if device == "mpx":
print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
device = "mps"
# Validate mps availability if requested
if device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
device = "cpu"
self.device = device
self._torch_device = torch.device(device)
options = request.Options
# empty dict
self.options = {}
# The options are a list of strings in this form optname:optvalue
# We are storing all the options in a dict so we can use it later when
# generating the audio
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1) # Split only on first colon
# if value is a number, convert it to the appropriate type
if is_float(value):
value = float(value)
elif is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
self.options[key] = value
# Parse voices configuration from options
self.voices = {}
if "voices" in self.options:
try:
voices_data = self.options["voices"]
if isinstance(voices_data, str):
voices_list = json.loads(voices_data)
else:
voices_list = voices_data
# Validate and store voices
for voice_entry in voices_list:
if not isinstance(voice_entry, dict):
print(
f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
file=sys.stderr,
)
continue
name = voice_entry.get("name")
audio = voice_entry.get("audio")
ref_text = voice_entry.get("ref_text")
if not name or not isinstance(name, str):
print(
f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
file=sys.stderr,
)
continue
if not audio or not isinstance(audio, str):
print(
f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
file=sys.stderr,
)
continue
if ref_text is None or not isinstance(ref_text, str):
print(
f"[WARNING] Voice entry missing required 'ref_text' field: {voice_entry}",
file=sys.stderr,
)
continue
self.voices[name] = {"audio": audio, "ref_text": ref_text}
print(
f"[INFO] Registered voice '{name}' with audio: {audio}",
file=sys.stderr,
)
print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr)
except json.JSONDecodeError as e:
print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
except Exception as e:
print(
f"[ERROR] Error processing voices configuration: {e}",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
# Get model path from request
model_path = request.Model
if not model_path:
model_path = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
# Determine model type from model path or options
self.model_type = self.options.get("model_type", None)
if not self.model_type:
if "CustomVoice" in model_path:
self.model_type = "CustomVoice"
elif "VoiceDesign" in model_path:
self.model_type = "VoiceDesign"
elif "Base" in model_path or "0.6B" in model_path or "1.7B" in model_path:
self.model_type = "Base" # VoiceClone model
else:
# Default to CustomVoice
self.model_type = "CustomVoice"
# Cache for voice clone prompts
self._voice_clone_cache = {}
# Pre-load cached voices if disk_cache is enabled
self._preload_cached_voices()
# Store AudioPath, ModelFile, and ModelPath from LoadModel request
# These are used later in TTS for VoiceClone mode
self.audio_path = (
request.AudioPath
if hasattr(request, "AudioPath") and request.AudioPath
else None
)
self.model_file = (
request.ModelFile
if hasattr(request, "ModelFile") and request.ModelFile
else None
)
self.model_path = (
request.ModelPath
if hasattr(request, "ModelPath") and request.ModelPath
else None
)
# Decide dtype & attention implementation
if self.device == "mps":
load_dtype = torch.float32 # MPS requires float32
device_map = None
attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS
elif self.device == "cuda":
load_dtype = torch.bfloat16
device_map = "cuda"
attn_impl_primary = "flash_attention_2"
else: # cpu
load_dtype = torch.float32
device_map = "cpu"
attn_impl_primary = "sdpa"
print(
f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}",
file=sys.stderr,
)
print(f"Loading model from: {model_path}", file=sys.stderr)
# Load model with device-specific logic
# Common parameters for all devices
load_kwargs = {
"dtype": load_dtype,
"attn_implementation": attn_impl_primary,
"trust_remote_code": True, # Required for qwen-tts models
}
try:
if self.device == "mps":
load_kwargs["device_map"] = None # load then move
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
self.model.to("mps")
elif self.device == "cuda":
load_kwargs["device_map"] = device_map
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
else: # cpu
load_kwargs["device_map"] = device_map
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
except Exception as e:
error_msg = str(e)
print(
f"[ERROR] Loading model: {type(e).__name__}: {error_msg}",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
# Check if it's a missing feature extractor/tokenizer error
if (
"speech_tokenizer" in error_msg
or "preprocessor_config.json" in error_msg
or "feature extractor" in error_msg.lower()
):
print(
"\n[ERROR] Model files appear to be incomplete. This usually means:",
file=sys.stderr,
)
print(
" 1. The model download was interrupted or incomplete",
file=sys.stderr,
)
print(" 2. The model cache is corrupted", file=sys.stderr)
print("\nTo fix this, try:", file=sys.stderr)
print(
f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*",
file=sys.stderr,
)
print(" Then re-run to trigger a fresh download.", file=sys.stderr)
print(
"\nAlternatively, try using a different model variant:",
file=sys.stderr,
)
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr)
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr)
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr)
if attn_impl_primary == "flash_attention_2":
print(
"\nTrying to use SDPA instead of flash_attention_2...",
file=sys.stderr,
)
load_kwargs["attn_implementation"] = "sdpa"
try:
if self.device == "mps":
load_kwargs["device_map"] = None
self.model = Qwen3TTSModel.from_pretrained(
model_path, **load_kwargs
)
self.model.to("mps")
else:
load_kwargs["device_map"] = (
self.device if self.device in ("cuda", "cpu") else None
)
self.model = Qwen3TTSModel.from_pretrained(
model_path, **load_kwargs
)
except Exception as e2:
print(
f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
raise e2
else:
raise e
print(f"Model loaded successfully: {model_path}", file=sys.stderr)
return backend_pb2.Result(message="Model loaded successfully", success=True)
def _detect_mode(self, request):
"""Detect which mode to use based on request parameters."""
# Priority: VoiceClone > VoiceDesign > CustomVoice
# model_type explicitly set
if self.model_type == "CustomVoice":
return "CustomVoice"
if self.model_type == "VoiceClone":
return "VoiceClone"
if self.model_type == "VoiceDesign":
return "VoiceDesign"
# VoiceClone: AudioPath is provided OR voices dict is populated
if self.audio_path or self.voices:
return "VoiceClone"
# VoiceDesign: instruct option is provided
if "instruct" in self.options and self.options["instruct"]:
return "VoiceDesign"
# Default to CustomVoice
return "CustomVoice"
def _get_ref_audio_path(self, request, voice_name=None):
"""Get reference audio path from stored AudioPath or from voices dict."""
# If voice_name is provided and exists in voices dict, use that
if voice_name and voice_name in self.voices:
audio_path = self.voices[voice_name]["audio"]
# If absolute path, use as-is
if os.path.isabs(audio_path):
return audio_path
# Try relative to ModelFile
if self.model_file:
model_file_base = os.path.dirname(self.model_file)
ref_path = os.path.join(model_file_base, audio_path)
if os.path.exists(ref_path):
return ref_path
# Try relative to ModelPath
if self.model_path:
ref_path = os.path.join(self.model_path, audio_path)
if os.path.exists(ref_path):
return ref_path
# Return as-is (might be URL or base64)
return audio_path
# Fall back to legacy single-voice mode using self.audio_path
if not self.audio_path:
return None
# If absolute path, use as-is
if os.path.isabs(self.audio_path):
return self.audio_path
# Try relative to ModelFile
if self.model_file:
model_file_base = os.path.dirname(self.model_file)
ref_path = os.path.join(model_file_base, self.audio_path)
if os.path.exists(ref_path):
return ref_path
# Try relative to ModelPath
if self.model_path:
ref_path = os.path.join(self.model_path, self.audio_path)
if os.path.exists(ref_path):
return ref_path
# Return as-is (might be URL or base64)
return self.audio_path
def _get_voice_clone_prompt(self, request, ref_audio, ref_text):
"""Get or create voice clone prompt, with in-memory and disk caching."""
cache_key = self._get_voice_cache_key(ref_audio, ref_text)
if cache_key not in self._voice_clone_cache:
# Check disk cache first (if enabled)
disk_cached = self._get_cached_voice_clone_prompt_from_disk(
ref_audio, ref_text
)
if disk_cached is not None:
self._voice_clone_cache[cache_key] = disk_cached
else:
# Create new prompt
print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr)
try:
prompt_items = self.model.create_voice_clone_prompt(
ref_audio=ref_audio,
ref_text=ref_text,
x_vector_only_mode=self.options.get(
"x_vector_only_mode", False
),
)
self._voice_clone_cache[cache_key] = prompt_items
# Save to disk cache if enabled
self._save_voice_clone_prompt_to_disk(
ref_audio, ref_text, prompt_items
)
except Exception as e:
print(f"Error creating voice clone prompt: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
return self._voice_clone_cache[cache_key]
def _is_text_file_path(self, text):
"""Check if the text is a file path to a text file."""
if not text or not isinstance(text, str):
return False
# Check if it looks like a file path (contains / or \ and ends with common text file extensions)
text_extensions = [".txt", ".md", ".rst", ".text"]
has_path_separator = "/" in text or "\\" in text
ends_with_text_ext = any(text.lower().endswith(ext) for ext in text_extensions)
return has_path_separator and ends_with_text_ext
def _read_text_file(self, file_path):
"""Read text content from a file path, resolving relative paths."""
try:
# If absolute path, use as-is
if os.path.isabs(file_path):
resolved_path = file_path
else:
# Try relative to ModelFile
if self.model_file:
model_file_base = os.path.dirname(self.model_file)
candidate_path = os.path.join(model_file_base, file_path)
if os.path.exists(candidate_path):
resolved_path = candidate_path
else:
resolved_path = file_path
else:
resolved_path = file_path
# Try relative to ModelPath
if not os.path.exists(resolved_path) and self.model_path:
candidate_path = os.path.join(self.model_path, file_path)
if os.path.exists(candidate_path):
resolved_path = candidate_path
# Check if file exists and is readable
if not os.path.exists(resolved_path):
print(
f"[ERROR] ref_text file not found: {resolved_path}", file=sys.stderr
)
return None
if not os.path.isfile(resolved_path):
print(
f"[ERROR] ref_text path is not a file: {resolved_path}",
file=sys.stderr,
)
return None
# Read and return file contents
with open(resolved_path, "r", encoding="utf-8") as f:
content = f.read().strip()
print(
f"[INFO] Successfully read ref_text from file: {resolved_path}",
file=sys.stderr,
)
return content
except Exception as e:
print(
f"[ERROR] Failed to read ref_text file {file_path}: {e}",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
return None
def _compute_file_hash(self, file_path):
"""Compute SHA256 hash of file content."""
try:
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
except Exception as e:
print(
f"[ERROR] Failed to compute hash for {file_path}: {e}", file=sys.stderr
)
return None
def _compute_string_hash(self, text):
"""Compute SHA256 hash of string."""
return hashlib.sha256(text.encode("utf-8")).hexdigest()
    def _get_cached_voice_clone_prompt_from_disk(self, ref_audio, ref_text_content):
        """Load cached prompt from disk if available and valid.

        The cache file lives next to the reference audio as
        ``<ref_audio>.voice_cache.pkl``. SHA256 checksums of the audio file
        and the reference text are compared against the stored ones; on any
        mismatch the stale cache file is deleted and None is returned.

        Returns:
            The cached prompt items, or None when disk caching is disabled,
            the cache is missing or stale, or loading fails.
        """
        # Disk caching is opt-in via the 'disk_cache' option.
        if not self.options.get("disk_cache", False):
            return None
        cache_file = f"{ref_audio}.voice_cache.pkl"
        if not os.path.exists(cache_file):
            return None
        try:
            # NOTE(review): pickle.load assumes the cache file is a trusted
            # local artifact written by _save_voice_clone_prompt_to_disk —
            # never point it at untrusted data.
            with open(cache_file, "rb") as f:
                cached = pickle.load(f)
            # Validate checksums
            current_audio_hash = self._compute_file_hash(ref_audio)
            current_text_hash = self._compute_string_hash(ref_text_content)
            # Invalidate (and remove) the cache when the audio changed or
            # its hash could not be computed.
            if current_audio_hash is None or cached["audio_hash"] != current_audio_hash:
                print("[INFO] Cache invalidation: audio file changed", file=sys.stderr)
                os.remove(cache_file)
                return None
            # Invalidate (and remove) the cache when the reference text changed.
            if cached["ref_text_hash"] != current_text_hash:
                print(
                    "[INFO] Cache invalidation: ref_text content changed",
                    file=sys.stderr,
                )
                os.remove(cache_file)
                return None
            print(
                f"[INFO] Loaded voice clone prompt from disk cache: {cache_file}",
                file=sys.stderr,
            )
            return cached["prompt_items"]
        except Exception as e:
            # Best-effort cache: any load failure just means a cache miss.
            print(
                f"[WARNING] Failed to load disk cache {cache_file}: {e}",
                file=sys.stderr,
            )
            return None
def _save_voice_clone_prompt_to_disk(
self, ref_audio, ref_text_content, prompt_items
):
"""Save prompt to disk cache alongside audio file."""
if not self.options.get("disk_cache", False):
return
cache_file = f"{ref_audio}.voice_cache.pkl"
try:
cache_data = {
"audio_hash": self._compute_file_hash(ref_audio),
"ref_text_hash": self._compute_string_hash(ref_text_content),
"prompt_items": prompt_items,
}
with open(cache_file, "wb") as f:
pickle.dump(cache_data, f)
print(
f"[INFO] Saved voice clone prompt to disk cache: {cache_file}",
file=sys.stderr,
)
except Exception as e:
print(
f"[WARNING] Failed to save disk cache {cache_file}: {e}",
file=sys.stderr,
)
def _get_voice_cache_key(self, ref_audio, ref_text):
"""Get the cache key for a voice."""
return f"{ref_audio}:{ref_text}"
    def _preload_cached_voices(self):
        """Pre-load cached voice prompts at model startup.

        For each configured voice with a valid on-disk cache file, the
        prompt is placed in the in-memory cache so the first TTS call does
        not pay the prompt-creation cost. Missing or invalid caches are
        only logged; prompts are (re)created lazily on first use.

        NOTE(review): relative ref_text paths are resolved through
        self.model_file / self.model_path inside _read_text_file, so this
        should run only after those attributes are set — confirm the call
        site in LoadModel.
        """
        # Nothing to do without configured voices or with caching disabled.
        if not self.voices or not self.options.get("disk_cache", False):
            return
        print(
            f"[INFO] Pre-loading {len(self.voices)} cached voice(s)...", file=sys.stderr
        )
        loaded_count = 0
        missing_count = 0
        invalid_count = 0
        for voice_name, voice_config in self.voices.items():
            audio_path = voice_config["audio"]
            ref_text_path = voice_config["ref_text"]
            # Check for cache file
            cache_file = f"{audio_path}.voice_cache.pkl"
            if os.path.exists(cache_file):
                # Read ref_text content for validation
                ref_text_content = self._read_text_file(ref_text_path)
                if ref_text_content is None:
                    invalid_count += 1
                    print(
                        f"[INFO] Cannot read ref_text for {voice_name} (will recreate on first use)",
                        file=sys.stderr,
                    )
                    continue
                # Validates hashes and deletes stale cache files as a side effect.
                cached_prompt = self._get_cached_voice_clone_prompt_from_disk(
                    audio_path, ref_text_content
                )
                if cached_prompt:
                    # Pre-populate memory cache with content-based key
                    cache_key = self._get_voice_cache_key(audio_path, ref_text_content)
                    self._voice_clone_cache[cache_key] = cached_prompt
                    loaded_count += 1
                    print(f"[INFO] Pre-loaded voice: {voice_name}", file=sys.stderr)
                else:
                    invalid_count += 1
                    print(
                        f"[INFO] Cache invalid for {voice_name} (will recreate on first use)",
                        file=sys.stderr,
                    )
            else:
                missing_count += 1
                print(
                    f"[INFO] No cache found for {voice_name} (will create on first use)",
                    file=sys.stderr,
                )
        # Summary line
        print(
            f"[INFO] Pre-loaded {loaded_count}/{len(self.voices)} voices ({missing_count} missing, {invalid_count} invalid)",
            file=sys.stderr,
        )
    def TTS(self, request, context):
        """Synthesize speech for the request and write it to request.dst.

        Dispatches on the detected mode:
          - VoiceClone: clone a reference voice (single AudioPath or a named
            entry from self.voices), optionally using cached prompts.
          - VoiceDesign: synthesize a voice described by the 'instruct' option.
          - CustomVoice: use a named built-in speaker (default "Vivian").

        Returns:
            backend_pb2.Result with success=True on success, or
            success=False and an explanatory message on any failure.
        """
        try:
            # Check if dst is provided
            if not request.dst:
                return backend_pb2.Result(
                    success=False, message="dst (output path) is required"
                )
            # Prepare text
            text = request.text.strip()
            if not text:
                return backend_pb2.Result(success=False, message="Text is empty")
            # Get language (auto-detect if not provided)
            language = (
                request.language
                if hasattr(request, "language") and request.language
                else None
            )
            if not language or language == "":
                language = "Auto"  # Auto-detect language
            # Detect mode
            mode = self._detect_mode(request)
            print(f"Detected mode: {mode}", file=sys.stderr)
            # Get generation parameters from options
            max_new_tokens = self.options.get("max_new_tokens", None)
            top_p = self.options.get("top_p", None)
            temperature = self.options.get("temperature", None)
            do_sample = self.options.get("do_sample", None)
            # Prepare generation kwargs: only forward parameters that were
            # explicitly configured, so model defaults apply otherwise.
            generation_kwargs = {}
            if max_new_tokens is not None:
                generation_kwargs["max_new_tokens"] = max_new_tokens
            if top_p is not None:
                generation_kwargs["top_p"] = top_p
            if temperature is not None:
                generation_kwargs["temperature"] = temperature
            if do_sample is not None:
                generation_kwargs["do_sample"] = do_sample
            # A non-empty 'instruct' option is forwarded to generation in
            # every mode, not only VoiceDesign.
            instruct = self.options.get("instruct", "")
            if instruct is not None and instruct != "":
                generation_kwargs["instruct"] = instruct
            # Generate audio based on mode
            if mode == "VoiceClone":
                # VoiceClone mode
                # Check if multi-voice mode is active (voices dict is populated)
                voice_name = None
                if self.voices:
                    # Get voice from request (priority) or options
                    voice_name = request.voice if request.voice else None
                    if not voice_name:
                        voice_name = self.options.get("voice", None)
                    # Validate voice exists
                    if voice_name and voice_name not in self.voices:
                        available_voices = ", ".join(sorted(self.voices.keys()))
                        return backend_pb2.Result(
                            success=False,
                            message=f"Voice '{voice_name}' not found. Available voices: {available_voices}",
                        )
                # Get reference audio path (with voice-specific lookup if in multi-voice mode)
                ref_audio = self._get_ref_audio_path(request, voice_name)
                if not ref_audio:
                    if voice_name:
                        return backend_pb2.Result(
                            success=False,
                            message=f"Audio path for voice '{voice_name}' could not be resolved",
                        )
                    else:
                        return backend_pb2.Result(
                            success=False,
                            message="AudioPath is required for VoiceClone mode",
                        )
                # Get reference text (from voice config if multi-voice, else from options/request)
                if voice_name and voice_name in self.voices:
                    ref_text_source = self.voices[voice_name]["ref_text"]
                else:
                    ref_text_source = self.options.get("ref_text", None)
                    if not ref_text_source:
                        # Try to get from request if available
                        if hasattr(request, "ref_text") and request.ref_text:
                            ref_text_source = request.ref_text
                    if not ref_text_source:
                        # x_vector_only_mode doesn't require ref_text
                        if not self.options.get("x_vector_only_mode", False):
                            return backend_pb2.Result(
                                success=False,
                                message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)",
                            )
                # Determine if ref_text_source is a file path
                ref_text_is_file = ref_text_source and self._is_text_file_path(
                    ref_text_source
                )
                if ref_text_is_file:
                    # Replace the path with the file's contents before use.
                    ref_text_content = self._read_text_file(ref_text_source)
                    if ref_text_content is None:
                        return backend_pb2.Result(
                            success=False,
                            message=f"Failed to read ref_text from file: {ref_text_source}",
                        )
                    ref_text_source = ref_text_content
                    print(
                        f"[INFO] Loaded ref_text from file: {ref_text_content[:100]}...",
                        file=sys.stderr,
                    )
                # For caching: use the content as the key (since we've read the file if it was one)
                ref_text_for_cache = ref_text_source
                # Check if we should use cached prompt
                use_cached_prompt = self.options.get("use_cached_prompt", True)
                voice_clone_prompt = None
                if use_cached_prompt:
                    voice_clone_prompt = self._get_voice_clone_prompt(
                        request, ref_audio, ref_text_for_cache
                    )
                    if voice_clone_prompt is None:
                        return backend_pb2.Result(
                            success=False, message="Failed to create voice clone prompt"
                        )
                if voice_clone_prompt:
                    # Use cached prompt
                    start_time = time.time()
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        voice_clone_prompt=voice_clone_prompt,
                        **generation_kwargs,
                    )
                    generation_duration = time.time() - start_time
                    print(
                        f"[INFO] Voice clone generation completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
                        file=sys.stderr,
                        flush=True,
                    )
                else:
                    # Create prompt on-the-fly (only for non-file ref_text that wasn't cached)
                    start_time = time.time()
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        ref_audio=ref_audio,
                        ref_text=ref_text_source,
                        x_vector_only_mode=self.options.get(
                            "x_vector_only_mode", False
                        ),
                        **generation_kwargs,
                    )
                    generation_duration = time.time() - start_time
                    print(
                        f"[INFO] Voice clone generation (on-the-fly) completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
                        file=sys.stderr,
                        flush=True,
                    )
            elif mode == "VoiceDesign":
                # VoiceDesign mode
                if not instruct:
                    return backend_pb2.Result(
                        success=False,
                        message="instruct option is required for VoiceDesign mode",
                    )
                # 'instruct' reaches the model via generation_kwargs (set above).
                wavs, sr = self.model.generate_voice_design(
                    text=text, language=language, **generation_kwargs
                )
            else:
                # CustomVoice mode (default)
                speaker = request.voice if request.voice else None
                if not speaker:
                    # Try to get from options
                    speaker = self.options.get("speaker", None)
                if not speaker:
                    # Use default speaker
                    speaker = "Vivian"
                    print(
                        f"No speaker specified, using default: {speaker}",
                        file=sys.stderr,
                    )
                # Validate speaker if model supports it
                if hasattr(self.model, "get_supported_speakers"):
                    try:
                        supported_speakers = self.model.get_supported_speakers()
                        if speaker not in supported_speakers:
                            print(
                                f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}",
                                file=sys.stderr,
                            )
                            # Try to find a close match (case-insensitive)
                            speaker_lower = speaker.lower()
                            for sup_speaker in supported_speakers:
                                if sup_speaker.lower() == speaker_lower:
                                    speaker = sup_speaker
                                    print(
                                        f"Using matched speaker: {speaker}",
                                        file=sys.stderr,
                                    )
                                    break
                    except Exception as e:
                        # Speaker validation is best-effort; generation proceeds.
                        print(
                            f"Warning: Could not get supported speakers: {e}",
                            file=sys.stderr,
                        )
                wavs, sr = self.model.generate_custom_voice(
                    text=text, language=language, speaker=speaker, **generation_kwargs
                )
            # Save output
            if wavs is not None and len(wavs) > 0:
                # wavs is a list, take first element
                audio_data = wavs[0] if isinstance(wavs, list) else wavs
                audio_duration = len(audio_data) / sr if sr > 0 else 0
                sf.write(request.dst, audio_data, sr)
                print(
                    f"Saved {audio_duration:.2f}s audio to {request.dst}",
                    file=sys.stderr,
                )
            else:
                return backend_pb2.Result(
                    success=False, message="No audio output generated"
                )
        except Exception as err:
            # Catch-all so the RPC returns a structured failure instead of crashing.
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Unexpected {err=}, {type(err)=}"
            )
        return backend_pb2.Result(success=True)
def serve(address):
    """Start the gRPC backend server and block until terminated.

    Installs SIGINT/SIGTERM handlers that stop the server and exit, then
    parks the main thread in a sleep loop (the gRPC server itself runs in
    background threads).

    Args:
        address: "host:port" string to bind, e.g. "localhost:50051".
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            # Raise gRPC message-size limits so large audio payloads fit.
            ("grpc.max_message_length", 50 * 1024 * 1024),  # 50MB
            ("grpc.max_send_message_length", 50 * 1024 * 1024),  # 50MB
            ("grpc.max_receive_message_length", 50 * 1024 * 1024),  # 50MB
        ],
        interceptors=get_auth_interceptors(),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        # Log to stderr for consistency with all other output in this backend
        # (the original printed this one message to stdout).
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
if __name__ == "__main__":
    # CLI entry point: parse the bind address and run the gRPC server
    # until it is terminated by a signal.
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()
    serve(args.addr)