mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
feat(qwen-tts): Support using multiple voices (#8757)
* Add support for multiple voice clones in Qwen TTS Signed-off-by: Andres Smith <andressmithdev@pm.me> * Add voice prompt caching and generation logs to see generation time --------- Signed-off-by: Andres Smith <andressmithdev@pm.me> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Qwen3-TTS
|
||||
"""
|
||||
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
@@ -17,8 +18,13 @@ import torch
|
||||
import soundfile as sf
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
@@ -27,6 +33,7 @@ def is_float(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_int(s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
@@ -35,10 +42,11 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
|
||||
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
@@ -46,9 +54,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
BackendServicer is the class that implements the gRPC service
|
||||
"""
|
||||
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
# Get device
|
||||
if torch.cuda.is_available():
|
||||
@@ -57,7 +66,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
else:
|
||||
print("CUDA is not available", file=sys.stderr)
|
||||
device = "cpu"
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
mps_available = (
|
||||
hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
)
|
||||
if mps_available:
|
||||
device = "mps"
|
||||
if not torch.cuda.is_available() and request.CUDA:
|
||||
@@ -67,7 +78,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if device == "mpx":
|
||||
print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
|
||||
device = "mps"
|
||||
|
||||
|
||||
# Validate mps availability if requested
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
|
||||
@@ -97,6 +108,64 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
value = value.lower() == "true"
|
||||
self.options[key] = value
|
||||
|
||||
# Parse voices configuration from options
|
||||
self.voices = {}
|
||||
if "voices" in self.options:
|
||||
try:
|
||||
voices_data = self.options["voices"]
|
||||
if isinstance(voices_data, str):
|
||||
voices_list = json.loads(voices_data)
|
||||
else:
|
||||
voices_list = voices_data
|
||||
|
||||
# Validate and store voices
|
||||
for voice_entry in voices_list:
|
||||
if not isinstance(voice_entry, dict):
|
||||
print(
|
||||
f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
name = voice_entry.get("name")
|
||||
audio = voice_entry.get("audio")
|
||||
ref_text = voice_entry.get("ref_text")
|
||||
|
||||
if not name or not isinstance(name, str):
|
||||
print(
|
||||
f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
if not audio or not isinstance(audio, str):
|
||||
print(
|
||||
f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
if ref_text is None or not isinstance(ref_text, str):
|
||||
print(
|
||||
f"[WARNING] Voice entry missing required 'ref_text' field: {voice_entry}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
self.voices[name] = {"audio": audio, "ref_text": ref_text}
|
||||
print(
|
||||
f"[INFO] Registered voice '{name}' with audio: {audio}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[ERROR] Error processing voices configuration: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
# Get model path from request
|
||||
model_path = request.Model
|
||||
if not model_path:
|
||||
@@ -118,11 +187,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Cache for voice clone prompts
|
||||
self._voice_clone_cache = {}
|
||||
|
||||
# Pre-load cached voices if disk_cache is enabled
|
||||
self._preload_cached_voices()
|
||||
|
||||
# Store AudioPath, ModelFile, and ModelPath from LoadModel request
|
||||
# These are used later in TTS for VoiceClone mode
|
||||
self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None
|
||||
self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None
|
||||
self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None
|
||||
self.audio_path = (
|
||||
request.AudioPath
|
||||
if hasattr(request, "AudioPath") and request.AudioPath
|
||||
else None
|
||||
)
|
||||
self.model_file = (
|
||||
request.ModelFile
|
||||
if hasattr(request, "ModelFile") and request.ModelFile
|
||||
else None
|
||||
)
|
||||
self.model_path = (
|
||||
request.ModelPath
|
||||
if hasattr(request, "ModelPath") and request.ModelPath
|
||||
else None
|
||||
)
|
||||
|
||||
# Decide dtype & attention implementation
|
||||
if self.device == "mps":
|
||||
@@ -138,7 +222,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
device_map = "cpu"
|
||||
attn_impl_primary = "sdpa"
|
||||
|
||||
print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}", file=sys.stderr)
|
||||
print(
|
||||
f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(f"Loading model from: {model_path}", file=sys.stderr)
|
||||
|
||||
# Load model with device-specific logic
|
||||
@@ -148,7 +235,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"attn_implementation": attn_impl_primary,
|
||||
"trust_remote_code": True, # Required for qwen-tts models
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
if self.device == "mps":
|
||||
load_kwargs["device_map"] = None # load then move
|
||||
@@ -162,35 +249,66 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"[ERROR] Loading model: {type(e).__name__}: {error_msg}", file=sys.stderr)
|
||||
print(
|
||||
f"[ERROR] Loading model: {type(e).__name__}: {error_msg}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
# Check if it's a missing feature extractor/tokenizer error
|
||||
if "speech_tokenizer" in error_msg or "preprocessor_config.json" in error_msg or "feature extractor" in error_msg.lower():
|
||||
print("\n[ERROR] Model files appear to be incomplete. This usually means:", file=sys.stderr)
|
||||
print(" 1. The model download was interrupted or incomplete", file=sys.stderr)
|
||||
if (
|
||||
"speech_tokenizer" in error_msg
|
||||
or "preprocessor_config.json" in error_msg
|
||||
or "feature extractor" in error_msg.lower()
|
||||
):
|
||||
print(
|
||||
"\n[ERROR] Model files appear to be incomplete. This usually means:",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
" 1. The model download was interrupted or incomplete",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(" 2. The model cache is corrupted", file=sys.stderr)
|
||||
print("\nTo fix this, try:", file=sys.stderr)
|
||||
print(f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*", file=sys.stderr)
|
||||
print(
|
||||
f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(" Then re-run to trigger a fresh download.", file=sys.stderr)
|
||||
print("\nAlternatively, try using a different model variant:", file=sys.stderr)
|
||||
print(
|
||||
"\nAlternatively, try using a different model variant:",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr)
|
||||
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr)
|
||||
print(" - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr)
|
||||
|
||||
if attn_impl_primary == 'flash_attention_2':
|
||||
print("\nTrying to use SDPA instead of flash_attention_2...", file=sys.stderr)
|
||||
load_kwargs["attn_implementation"] = 'sdpa'
|
||||
|
||||
if attn_impl_primary == "flash_attention_2":
|
||||
print(
|
||||
"\nTrying to use SDPA instead of flash_attention_2...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
load_kwargs["attn_implementation"] = "sdpa"
|
||||
try:
|
||||
if self.device == "mps":
|
||||
load_kwargs["device_map"] = None
|
||||
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path, **load_kwargs
|
||||
)
|
||||
self.model.to("mps")
|
||||
else:
|
||||
load_kwargs["device_map"] = (self.device if self.device in ("cuda", "cpu") else None)
|
||||
self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
|
||||
load_kwargs["device_map"] = (
|
||||
self.device if self.device in ("cuda", "cpu") else None
|
||||
)
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path, **load_kwargs
|
||||
)
|
||||
except Exception as e2:
|
||||
print(f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}", file=sys.stderr)
|
||||
print(
|
||||
f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
raise e2
|
||||
else:
|
||||
@@ -212,81 +330,338 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if self.model_type == "VoiceDesign":
|
||||
return "VoiceDesign"
|
||||
|
||||
# VoiceClone: AudioPath is provided (from LoadModel, stored in self.audio_path)
|
||||
if self.audio_path:
|
||||
# VoiceClone: AudioPath is provided OR voices dict is populated
|
||||
if self.audio_path or self.voices:
|
||||
return "VoiceClone"
|
||||
|
||||
|
||||
# VoiceDesign: instruct option is provided
|
||||
if "instruct" in self.options and self.options["instruct"]:
|
||||
return "VoiceDesign"
|
||||
|
||||
|
||||
# Default to CustomVoice
|
||||
return "CustomVoice"
|
||||
|
||||
def _get_ref_audio_path(self, request):
|
||||
"""Get reference audio path from stored AudioPath (from LoadModel)."""
|
||||
def _get_ref_audio_path(self, request, voice_name=None):
|
||||
"""Get reference audio path from stored AudioPath or from voices dict."""
|
||||
# If voice_name is provided and exists in voices dict, use that
|
||||
if voice_name and voice_name in self.voices:
|
||||
audio_path = self.voices[voice_name]["audio"]
|
||||
|
||||
# If absolute path, use as-is
|
||||
if os.path.isabs(audio_path):
|
||||
return audio_path
|
||||
|
||||
# Try relative to ModelFile
|
||||
if self.model_file:
|
||||
model_file_base = os.path.dirname(self.model_file)
|
||||
ref_path = os.path.join(model_file_base, audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
|
||||
# Try relative to ModelPath
|
||||
if self.model_path:
|
||||
ref_path = os.path.join(self.model_path, audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
|
||||
# Return as-is (might be URL or base64)
|
||||
return audio_path
|
||||
|
||||
# Fall back to legacy single-voice mode using self.audio_path
|
||||
if not self.audio_path:
|
||||
return None
|
||||
|
||||
|
||||
# If absolute path, use as-is
|
||||
if os.path.isabs(self.audio_path):
|
||||
return self.audio_path
|
||||
|
||||
|
||||
# Try relative to ModelFile
|
||||
if self.model_file:
|
||||
model_file_base = os.path.dirname(self.model_file)
|
||||
ref_path = os.path.join(model_file_base, self.audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
|
||||
|
||||
# Try relative to ModelPath
|
||||
if self.model_path:
|
||||
ref_path = os.path.join(self.model_path, self.audio_path)
|
||||
if os.path.exists(ref_path):
|
||||
return ref_path
|
||||
|
||||
|
||||
# Return as-is (might be URL or base64)
|
||||
return self.audio_path
|
||||
|
||||
def _get_voice_clone_prompt(self, request, ref_audio, ref_text):
|
||||
"""Get or create voice clone prompt, with caching."""
|
||||
cache_key = f"{ref_audio}:{ref_text}"
|
||||
|
||||
"""Get or create voice clone prompt, with in-memory and disk caching."""
|
||||
cache_key = self._get_voice_cache_key(ref_audio, ref_text)
|
||||
|
||||
if cache_key not in self._voice_clone_cache:
|
||||
print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr)
|
||||
try:
|
||||
prompt_items = self.model.create_voice_clone_prompt(
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
x_vector_only_mode=self.options.get("x_vector_only_mode", False),
|
||||
)
|
||||
self._voice_clone_cache[cache_key] = prompt_items
|
||||
except Exception as e:
|
||||
print(f"Error creating voice clone prompt: {e}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
# Check disk cache first (if enabled)
|
||||
disk_cached = self._get_cached_voice_clone_prompt_from_disk(
|
||||
ref_audio, ref_text
|
||||
)
|
||||
if disk_cached is not None:
|
||||
self._voice_clone_cache[cache_key] = disk_cached
|
||||
else:
|
||||
# Create new prompt
|
||||
print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr)
|
||||
try:
|
||||
prompt_items = self.model.create_voice_clone_prompt(
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
x_vector_only_mode=self.options.get(
|
||||
"x_vector_only_mode", False
|
||||
),
|
||||
)
|
||||
self._voice_clone_cache[cache_key] = prompt_items
|
||||
# Save to disk cache if enabled
|
||||
self._save_voice_clone_prompt_to_disk(
|
||||
ref_audio, ref_text, prompt_items
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error creating voice clone prompt: {e}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
return self._voice_clone_cache[cache_key]
|
||||
|
||||
def _is_text_file_path(self, text):
|
||||
"""Check if the text is a file path to a text file."""
|
||||
if not text or not isinstance(text, str):
|
||||
return False
|
||||
# Check if it looks like a file path (contains / or \ and ends with common text file extensions)
|
||||
text_extensions = [".txt", ".md", ".rst", ".text"]
|
||||
has_path_separator = "/" in text or "\\" in text
|
||||
ends_with_text_ext = any(text.lower().endswith(ext) for ext in text_extensions)
|
||||
return has_path_separator and ends_with_text_ext
|
||||
|
||||
def _read_text_file(self, file_path):
|
||||
"""Read text content from a file path, resolving relative paths."""
|
||||
try:
|
||||
# If absolute path, use as-is
|
||||
if os.path.isabs(file_path):
|
||||
resolved_path = file_path
|
||||
else:
|
||||
# Try relative to ModelFile
|
||||
if self.model_file:
|
||||
model_file_base = os.path.dirname(self.model_file)
|
||||
candidate_path = os.path.join(model_file_base, file_path)
|
||||
if os.path.exists(candidate_path):
|
||||
resolved_path = candidate_path
|
||||
else:
|
||||
resolved_path = file_path
|
||||
else:
|
||||
resolved_path = file_path
|
||||
|
||||
# Try relative to ModelPath
|
||||
if not os.path.exists(resolved_path) and self.model_path:
|
||||
candidate_path = os.path.join(self.model_path, file_path)
|
||||
if os.path.exists(candidate_path):
|
||||
resolved_path = candidate_path
|
||||
|
||||
# Check if file exists and is readable
|
||||
if not os.path.exists(resolved_path):
|
||||
print(
|
||||
f"[ERROR] ref_text file not found: {resolved_path}", file=sys.stderr
|
||||
)
|
||||
return None
|
||||
|
||||
if not os.path.isfile(resolved_path):
|
||||
print(
|
||||
f"[ERROR] ref_text path is not a file: {resolved_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
# Read and return file contents
|
||||
with open(resolved_path, "r", encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
|
||||
print(
|
||||
f"[INFO] Successfully read ref_text from file: {resolved_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[ERROR] Failed to read ref_text file {file_path}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
def _compute_file_hash(self, file_path):
|
||||
"""Compute SHA256 hash of file content."""
|
||||
try:
|
||||
sha256 = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[ERROR] Failed to compute hash for {file_path}: {e}", file=sys.stderr
|
||||
)
|
||||
return None
|
||||
|
||||
def _compute_string_hash(self, text):
|
||||
"""Compute SHA256 hash of string."""
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
def _get_cached_voice_clone_prompt_from_disk(self, ref_audio, ref_text_content):
|
||||
"""Load cached prompt from disk if available and valid."""
|
||||
if not self.options.get("disk_cache", False):
|
||||
return None
|
||||
|
||||
cache_file = f"{ref_audio}.voice_cache.pkl"
|
||||
|
||||
if not os.path.exists(cache_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(cache_file, "rb") as f:
|
||||
cached = pickle.load(f)
|
||||
|
||||
# Validate checksums
|
||||
current_audio_hash = self._compute_file_hash(ref_audio)
|
||||
current_text_hash = self._compute_string_hash(ref_text_content)
|
||||
|
||||
if current_audio_hash is None or cached["audio_hash"] != current_audio_hash:
|
||||
print("[INFO] Cache invalidation: audio file changed", file=sys.stderr)
|
||||
os.remove(cache_file)
|
||||
return None
|
||||
|
||||
if cached["ref_text_hash"] != current_text_hash:
|
||||
print(
|
||||
"[INFO] Cache invalidation: ref_text content changed",
|
||||
file=sys.stderr,
|
||||
)
|
||||
os.remove(cache_file)
|
||||
return None
|
||||
|
||||
print(
|
||||
f"[INFO] Loaded voice clone prompt from disk cache: {cache_file}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return cached["prompt_items"]
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[WARNING] Failed to load disk cache {cache_file}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_voice_clone_prompt_to_disk(
|
||||
self, ref_audio, ref_text_content, prompt_items
|
||||
):
|
||||
"""Save prompt to disk cache alongside audio file."""
|
||||
if not self.options.get("disk_cache", False):
|
||||
return
|
||||
|
||||
cache_file = f"{ref_audio}.voice_cache.pkl"
|
||||
|
||||
try:
|
||||
cache_data = {
|
||||
"audio_hash": self._compute_file_hash(ref_audio),
|
||||
"ref_text_hash": self._compute_string_hash(ref_text_content),
|
||||
"prompt_items": prompt_items,
|
||||
}
|
||||
|
||||
with open(cache_file, "wb") as f:
|
||||
pickle.dump(cache_data, f)
|
||||
|
||||
print(
|
||||
f"[INFO] Saved voice clone prompt to disk cache: {cache_file}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[WARNING] Failed to save disk cache {cache_file}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def _get_voice_cache_key(self, ref_audio, ref_text):
|
||||
"""Get the cache key for a voice."""
|
||||
return f"{ref_audio}:{ref_text}"
|
||||
|
||||
def _preload_cached_voices(self):
|
||||
"""Pre-load cached voice prompts at model startup."""
|
||||
if not self.voices or not self.options.get("disk_cache", False):
|
||||
return
|
||||
|
||||
print(
|
||||
f"[INFO] Pre-loading {len(self.voices)} cached voice(s)...", file=sys.stderr
|
||||
)
|
||||
loaded_count = 0
|
||||
missing_count = 0
|
||||
invalid_count = 0
|
||||
|
||||
for voice_name, voice_config in self.voices.items():
|
||||
audio_path = voice_config["audio"]
|
||||
ref_text_path = voice_config["ref_text"]
|
||||
|
||||
# Check for cache file
|
||||
cache_file = f"{audio_path}.voice_cache.pkl"
|
||||
if os.path.exists(cache_file):
|
||||
# Read ref_text content for validation
|
||||
ref_text_content = self._read_text_file(ref_text_path)
|
||||
if ref_text_content is None:
|
||||
invalid_count += 1
|
||||
print(
|
||||
f"[INFO] Cannot read ref_text for {voice_name} (will recreate on first use)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
cached_prompt = self._get_cached_voice_clone_prompt_from_disk(
|
||||
audio_path, ref_text_content
|
||||
)
|
||||
if cached_prompt:
|
||||
# Pre-populate memory cache with content-based key
|
||||
cache_key = self._get_voice_cache_key(audio_path, ref_text_content)
|
||||
self._voice_clone_cache[cache_key] = cached_prompt
|
||||
loaded_count += 1
|
||||
print(f"[INFO] Pre-loaded voice: {voice_name}", file=sys.stderr)
|
||||
else:
|
||||
invalid_count += 1
|
||||
print(
|
||||
f"[INFO] Cache invalid for {voice_name} (will recreate on first use)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
missing_count += 1
|
||||
print(
|
||||
f"[INFO] No cache found for {voice_name} (will create on first use)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Summary line
|
||||
print(
|
||||
f"[INFO] Pre-loaded {loaded_count}/{len(self.voices)} voices ({missing_count} missing, {invalid_count} invalid)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def TTS(self, request, context):
|
||||
try:
|
||||
# Check if dst is provided
|
||||
if not request.dst:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="dst (output path) is required"
|
||||
success=False, message="dst (output path) is required"
|
||||
)
|
||||
|
||||
|
||||
# Prepare text
|
||||
text = request.text.strip()
|
||||
if not text:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="Text is empty"
|
||||
)
|
||||
return backend_pb2.Result(success=False, message="Text is empty")
|
||||
|
||||
# Get language (auto-detect if not provided)
|
||||
language = request.language if hasattr(request, 'language') and request.language else None
|
||||
language = (
|
||||
request.language
|
||||
if hasattr(request, "language") and request.language
|
||||
else None
|
||||
)
|
||||
if not language or language == "":
|
||||
language = "Auto" # Auto-detect language
|
||||
|
||||
@@ -318,55 +693,123 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# Generate audio based on mode
|
||||
if mode == "VoiceClone":
|
||||
# VoiceClone mode
|
||||
ref_audio = self._get_ref_audio_path(request)
|
||||
|
||||
# Check if multi-voice mode is active (voices dict is populated)
|
||||
voice_name = None
|
||||
if self.voices:
|
||||
# Get voice from request (priority) or options
|
||||
voice_name = request.voice if request.voice else None
|
||||
if not voice_name:
|
||||
voice_name = self.options.get("voice", None)
|
||||
|
||||
# Validate voice exists
|
||||
if voice_name and voice_name not in self.voices:
|
||||
available_voices = ", ".join(sorted(self.voices.keys()))
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"Voice '{voice_name}' not found. Available voices: {available_voices}",
|
||||
)
|
||||
|
||||
# Get reference audio path (with voice-specific lookup if in multi-voice mode)
|
||||
ref_audio = self._get_ref_audio_path(request, voice_name)
|
||||
if not ref_audio:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="AudioPath is required for VoiceClone mode"
|
||||
)
|
||||
|
||||
ref_text = self.options.get("ref_text", None)
|
||||
if not ref_text:
|
||||
# Try to get from request if available
|
||||
if hasattr(request, 'ref_text') and request.ref_text:
|
||||
ref_text = request.ref_text
|
||||
if voice_name:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"Audio path for voice '{voice_name}' could not be resolved",
|
||||
)
|
||||
else:
|
||||
# x_vector_only_mode doesn't require ref_text
|
||||
if not self.options.get("x_vector_only_mode", False):
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)"
|
||||
)
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="AudioPath is required for VoiceClone mode",
|
||||
)
|
||||
|
||||
# Get reference text (from voice config if multi-voice, else from options/request)
|
||||
if voice_name and voice_name in self.voices:
|
||||
ref_text_source = self.voices[voice_name]["ref_text"]
|
||||
else:
|
||||
ref_text_source = self.options.get("ref_text", None)
|
||||
if not ref_text_source:
|
||||
# Try to get from request if available
|
||||
if hasattr(request, "ref_text") and request.ref_text:
|
||||
ref_text_source = request.ref_text
|
||||
|
||||
if not ref_text_source:
|
||||
# x_vector_only_mode doesn't require ref_text
|
||||
if not self.options.get("x_vector_only_mode", False):
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)",
|
||||
)
|
||||
|
||||
# Determine if ref_text_source is a file path
|
||||
ref_text_is_file = ref_text_source and self._is_text_file_path(
|
||||
ref_text_source
|
||||
)
|
||||
|
||||
if ref_text_is_file:
|
||||
ref_text_content = self._read_text_file(ref_text_source)
|
||||
if ref_text_content is None:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message=f"Failed to read ref_text from file: {ref_text_source}",
|
||||
)
|
||||
ref_text_source = ref_text_content
|
||||
print(
|
||||
f"[INFO] Loaded ref_text from file: {ref_text_content[:100]}...",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# For caching: use the content as the key (since we've read the file if it was one)
|
||||
ref_text_for_cache = ref_text_source
|
||||
|
||||
# Check if we should use cached prompt
|
||||
use_cached_prompt = self.options.get("use_cached_prompt", True)
|
||||
voice_clone_prompt = None
|
||||
|
||||
|
||||
if use_cached_prompt:
|
||||
voice_clone_prompt = self._get_voice_clone_prompt(request, ref_audio, ref_text)
|
||||
if voice_clone_prompt is None:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="Failed to create voice clone prompt"
|
||||
)
|
||||
voice_clone_prompt = self._get_voice_clone_prompt(
|
||||
request, ref_audio, ref_text_for_cache
|
||||
)
|
||||
|
||||
if voice_clone_prompt is None:
|
||||
return backend_pb2.Result(
|
||||
success=False, message="Failed to create voice clone prompt"
|
||||
)
|
||||
|
||||
if voice_clone_prompt:
|
||||
# Use cached prompt
|
||||
start_time = time.time()
|
||||
wavs, sr = self.model.generate_voice_clone(
|
||||
text=text,
|
||||
language=language,
|
||||
voice_clone_prompt=voice_clone_prompt,
|
||||
**generation_kwargs
|
||||
**generation_kwargs,
|
||||
)
|
||||
generation_duration = time.time() - start_time
|
||||
print(
|
||||
f"[INFO] Voice clone generation completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
# Create prompt on-the-fly
|
||||
# Create prompt on-the-fly (only for non-file ref_text that wasn't cached)
|
||||
start_time = time.time()
|
||||
wavs, sr = self.model.generate_voice_clone(
|
||||
text=text,
|
||||
language=language,
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
x_vector_only_mode=self.options.get("x_vector_only_mode", False),
|
||||
**generation_kwargs
|
||||
ref_text=ref_text_source,
|
||||
x_vector_only_mode=self.options.get(
|
||||
"x_vector_only_mode", False
|
||||
),
|
||||
**generation_kwargs,
|
||||
)
|
||||
generation_duration = time.time() - start_time
|
||||
print(
|
||||
f"[INFO] Voice clone generation (on-the-fly) completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
elif mode == "VoiceDesign":
|
||||
@@ -374,14 +817,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if not instruct:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="instruct option is required for VoiceDesign mode"
|
||||
message="instruct option is required for VoiceDesign mode",
|
||||
)
|
||||
|
||||
wavs, sr = self.model.generate_voice_design(
|
||||
text=text,
|
||||
language=language,
|
||||
instruct=instruct,
|
||||
**generation_kwargs
|
||||
text=text, language=language, instruct=instruct, **generation_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -393,57 +833,74 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if not speaker:
|
||||
# Use default speaker
|
||||
speaker = "Vivian"
|
||||
print(f"No speaker specified, using default: {speaker}", file=sys.stderr)
|
||||
print(
|
||||
f"No speaker specified, using default: {speaker}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Validate speaker if model supports it
|
||||
if hasattr(self.model, 'get_supported_speakers'):
|
||||
if hasattr(self.model, "get_supported_speakers"):
|
||||
try:
|
||||
supported_speakers = self.model.get_supported_speakers()
|
||||
if speaker not in supported_speakers:
|
||||
print(f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}", file=sys.stderr)
|
||||
print(
|
||||
f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Try to find a close match (case-insensitive)
|
||||
speaker_lower = speaker.lower()
|
||||
for sup_speaker in supported_speakers:
|
||||
if sup_speaker.lower() == speaker_lower:
|
||||
speaker = sup_speaker
|
||||
print(f"Using matched speaker: {speaker}", file=sys.stderr)
|
||||
print(
|
||||
f"Using matched speaker: {speaker}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not get supported speakers: {e}", file=sys.stderr)
|
||||
print(
|
||||
f"Warning: Could not get supported speakers: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
wavs, sr = self.model.generate_custom_voice(
|
||||
text=text,
|
||||
language=language,
|
||||
speaker=speaker,
|
||||
**generation_kwargs
|
||||
text=text, language=language, speaker=speaker, **generation_kwargs
|
||||
)
|
||||
|
||||
# Save output
|
||||
if wavs is not None and len(wavs) > 0:
|
||||
# wavs is a list, take first element
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
audio_duration = len(audio_data) / sr if sr > 0 else 0
|
||||
sf.write(request.dst, audio_data, sr)
|
||||
print(f"Saved output to {request.dst}", file=sys.stderr)
|
||||
print(
|
||||
f"Saved {audio_duration:.2f}s audio to {request.dst}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
else:
|
||||
return backend_pb2.Result(
|
||||
success=False,
|
||||
message="No audio output generated"
|
||||
success=False, message="No audio output generated"
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error in TTS: {err}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
return backend_pb2.Result(
|
||||
success=False, message=f"Unexpected {err=}, {type(err)=}"
|
||||
)
|
||||
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
("grpc.max_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
||||
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
||||
],
|
||||
)
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
server.add_insecure_port(address)
|
||||
server.start()
|
||||
@@ -465,6 +922,7 @@ def serve(address):
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -364,7 +364,54 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
## Using config files
|
||||
#### Multi-Voice Clone Mode
|
||||
|
||||
Qwen3-TTS also supports loading multiple voices for voice cloning, allowing you to select different voices at request time. Configure multiple voices using the `voices` option:
|
||||
|
||||
```yaml
|
||||
name: qwen-tts-multi-voice
|
||||
backend: qwen-tts
|
||||
parameters:
|
||||
model: Qwen/Qwen3-TTS-12Hz-1.7B-Base
|
||||
options:
|
||||
- voices:[{"name":"jane","audio":"voices/jane.wav","ref_text":"voices/jane-ref.txt"},{"name":"john","audio":"voices/john.wav","ref_text":"voices/john-ref.txt"}]
|
||||
```
|
||||
|
||||
The `voices` option accepts a JSON array where each voice entry must have:
|
||||
- `name`: The voice identifier (used in API requests)
|
||||
- `audio`: Path to the reference audio file (relative to model directory or absolute)
|
||||
- `ref_text`: Path to the reference text file for the audio it is paired with
|
||||
|
||||
Then use the model with voice selection:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen-tts-multi-voice",
|
||||
"input":"Hello world, this is Jane speaking.",
|
||||
"voice": "jane"
|
||||
}' | aplay
|
||||
|
||||
# Switch to a different voice
|
||||
curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
"model": "qwen-tts-multi-voice",
|
||||
"input":"Hello world, this is John speaking.",
|
||||
"voice": "john"
|
||||
}' | aplay
|
||||
```
|
||||
|
||||
**Voice Selection Priority:**
|
||||
1. `voice` parameter in the API request (highest priority)
|
||||
2. `voice` option in the model configuration
|
||||
3. Error if voice is not found among configured voices
|
||||
|
||||
**Error Handling:**
|
||||
If you request a voice that doesn't exist in the voices list, the API will return an error with a list of available voices:
|
||||
```json
|
||||
{"error": "Voice 'unknown' not found. Available voices: jane, john"}
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The multi-voice mode is backward compatible with existing single-voice configurations. Models using `audio_path` in the `tts` section will continue to work as before.
|
||||
|
||||
You can also use a `config-file` to specify TTS models and their parameters.
|
||||
|
||||
@@ -408,4 +455,4 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
If a `response_format` is added in the query (other than `wav`) and ffmpeg is not available, the call will fail.
|
||||
If a `response_format` is added in the query (other than `wav`) and ffmpeg is not available, the call will fail.
|
||||
|
||||
Reference in New Issue
Block a user