#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Qwen3-TTS
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import copy
import traceback
from pathlib import Path

import backend_pb2
import backend_pb2_grpc

import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel
import json
import hashlib
import pickle

import grpc


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    """

    def Health(self, request, context):
        """Health check endpoint: always reports OK."""
        return backend_pb2.Reply(message=bytes("OK", "utf-8"))

    def LoadModel(self, request, context):
        """Load a Qwen3-TTS model and parse backend options.

        Selects the compute device (cuda/mps/cpu), parses the `Options`
        key:value pairs (including an optional JSON `voices` list for
        multi-voice cloning), infers the model type from the model path,
        and loads the model with device-appropriate dtype/attention
        settings, falling back from flash_attention_2 to SDPA on failure.
        """
        # Get device
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        mps_available = (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        )
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # Normalize potential 'mpx' typo to 'mps'
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
            device = "mps"

        # Validate mps availability if requested
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
            device = "cpu"

        self.device = device
        self._torch_device = torch.device(device)

        options = request.Options  # repeated "key:value" strings
        self.options = {}

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)  # Split only on first colon
            # if value is a number, convert it to the appropriate type
            if is_float(value):
                value = float(value)
            elif is_int(value):
                value = int(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        # Parse voices configuration from options
        self.voices = {}
        if "voices" in self.options:
            try:
                voices_data = self.options["voices"]
                if isinstance(voices_data, str):
                    voices_list = json.loads(voices_data)
                else:
                    voices_list = voices_data

                # Validate and store voices
                for voice_entry in voices_list:
                    if not isinstance(voice_entry, dict):
                        print(
                            f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
                            file=sys.stderr,
                        )
                        continue

                    name = voice_entry.get("name")
                    audio = voice_entry.get("audio")
                    ref_text = voice_entry.get("ref_text")

                    if not name or not isinstance(name, str):
                        print(
                            f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
                            file=sys.stderr,
                        )
                        continue
                    if not audio or not isinstance(audio, str):
                        print(
                            f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
                            file=sys.stderr,
                        )
                        continue
                    if ref_text is None or not isinstance(ref_text, str):
                        print(
                            f"[WARNING] Voice entry missing required 'ref_text' field: {voice_entry}",
                            file=sys.stderr,
                        )
                        continue

                    self.voices[name] = {"audio": audio, "ref_text": ref_text}
                    print(
                        f"[INFO] Registered voice '{name}' with audio: {audio}",
                        file=sys.stderr,
                    )

                print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr)
            except json.JSONDecodeError as e:
                print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
            except Exception as e:
                print(
                    f"[ERROR] Error processing voices configuration: {e}",
                    file=sys.stderr,
                )
                print(traceback.format_exc(), file=sys.stderr)

        # Get model path from request
        model_path = request.Model
        if not model_path:
            model_path = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

        # Determine model type from model path or options
        self.model_type = self.options.get("model_type", None)
        if not self.model_type:
            if "CustomVoice" in model_path:
                self.model_type = "CustomVoice"
            elif "VoiceDesign" in model_path:
                self.model_type = "VoiceDesign"
            elif "Base" in model_path or "0.6B" in model_path or "1.7B" in model_path:
                self.model_type = "Base"  # VoiceClone model
            else:
                # Default to CustomVoice
                self.model_type = "CustomVoice"

        # Cache for voice clone prompts
        self._voice_clone_cache = {}

        # Pre-load cached voices if disk_cache is enabled
        self._preload_cached_voices()

        # Store AudioPath, ModelFile, and ModelPath from LoadModel request
        # These are used later in TTS for VoiceClone mode
        self.audio_path = (
            request.AudioPath
            if hasattr(request, "AudioPath") and request.AudioPath
            else None
        )
        self.model_file = (
            request.ModelFile
            if hasattr(request, "ModelFile") and request.ModelFile
            else None
        )
        self.model_path = (
            request.ModelPath
            if hasattr(request, "ModelPath") and request.ModelPath
            else None
        )

        # Decide dtype & attention implementation
        if self.device == "mps":
            load_dtype = torch.float32  # MPS requires float32
            device_map = None
            attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:  # cpu
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"

        print(
            f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}",
            file=sys.stderr,
        )
        print(f"Loading model from: {model_path}", file=sys.stderr)

        # Load model with device-specific logic
        # Common parameters for all devices
        load_kwargs = {
            "dtype": load_dtype,
            "attn_implementation": attn_impl_primary,
            "trust_remote_code": True,  # Required for qwen-tts models
        }

        try:
            if self.device == "mps":
                load_kwargs["device_map"] = None  # load then move
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
                self.model.to("mps")
            elif self.device == "cuda":
                load_kwargs["device_map"] = device_map
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
            else:  # cpu
                load_kwargs["device_map"] = device_map
                self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs)
        except Exception as e:
            error_msg = str(e)
            print(
                f"[ERROR] Loading model: {type(e).__name__}: {error_msg}",
                file=sys.stderr,
            )
            print(traceback.format_exc(), file=sys.stderr)

            # Check if it's a missing feature extractor/tokenizer error
            if (
                "speech_tokenizer" in error_msg
                or "preprocessor_config.json" in error_msg
                or "feature extractor" in error_msg.lower()
            ):
                print(
                    "\n[ERROR] Model files appear to be incomplete. This usually means:",
                    file=sys.stderr,
                )
                print(
                    "  1. The model download was interrupted or incomplete",
                    file=sys.stderr,
                )
                print("  2. The model cache is corrupted", file=sys.stderr)
                print("\nTo fix this, try:", file=sys.stderr)
                print(
                    "  rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*",
                    file=sys.stderr,
                )
                print("  Then re-run to trigger a fresh download.", file=sys.stderr)
                print(
                    "\nAlternatively, try using a different model variant:",
                    file=sys.stderr,
                )
                print("  - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr)
                print("  - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr)
                print("  - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr)

            if attn_impl_primary == "flash_attention_2":
                print(
                    "\nTrying to use SDPA instead of flash_attention_2...",
                    file=sys.stderr,
                )
                load_kwargs["attn_implementation"] = "sdpa"
                try:
                    if self.device == "mps":
                        load_kwargs["device_map"] = None
                        self.model = Qwen3TTSModel.from_pretrained(
                            model_path, **load_kwargs
                        )
                        self.model.to("mps")
                    else:
                        load_kwargs["device_map"] = (
                            self.device if self.device in ("cuda", "cpu") else None
                        )
                        self.model = Qwen3TTSModel.from_pretrained(
                            model_path, **load_kwargs
                        )
                except Exception as e2:
                    print(
                        f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}",
                        file=sys.stderr,
                    )
                    print(traceback.format_exc(), file=sys.stderr)
                    raise e2
            else:
                raise e

        print(f"Model loaded successfully: {model_path}", file=sys.stderr)
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def _detect_mode(self, request):
        """Detect which mode to use based on request parameters."""
        # Priority: VoiceClone > VoiceDesign > CustomVoice

        # model_type explicitly set
        if self.model_type == "CustomVoice":
            return "CustomVoice"
        if self.model_type == "VoiceClone":
            return "VoiceClone"
        if self.model_type == "VoiceDesign":
            return "VoiceDesign"

        # VoiceClone: AudioPath is provided OR voices dict is populated
        if self.audio_path or self.voices:
            return "VoiceClone"

        # VoiceDesign: instruct option is provided
        if "instruct" in self.options and self.options["instruct"]:
            return "VoiceDesign"

        # Default to CustomVoice
        return "CustomVoice"

    def _get_ref_audio_path(self, request, voice_name=None):
        """Get reference audio path from stored AudioPath or from voices dict."""
        # If voice_name is provided and exists in voices dict, use that
        if voice_name and voice_name in self.voices:
            audio_path = self.voices[voice_name]["audio"]
            # If absolute path, use as-is
            if os.path.isabs(audio_path):
                return audio_path
            # Try relative to ModelFile
            if self.model_file:
                model_file_base = os.path.dirname(self.model_file)
                ref_path = os.path.join(model_file_base, audio_path)
                if os.path.exists(ref_path):
                    return ref_path
            # Try relative to ModelPath
            if self.model_path:
                ref_path = os.path.join(self.model_path, audio_path)
                if os.path.exists(ref_path):
                    return ref_path
            # Return as-is (might be URL or base64)
            return audio_path

        # Fall back to legacy single-voice mode using self.audio_path
        if not self.audio_path:
            return None

        # If absolute path, use as-is
        if os.path.isabs(self.audio_path):
            return self.audio_path
        # Try relative to ModelFile
        if self.model_file:
            model_file_base = os.path.dirname(self.model_file)
            ref_path = os.path.join(model_file_base, self.audio_path)
            if os.path.exists(ref_path):
                return ref_path
        # Try relative to ModelPath
        if self.model_path:
            ref_path = os.path.join(self.model_path, self.audio_path)
            if os.path.exists(ref_path):
                return ref_path
        # Return as-is (might be URL or base64)
        return self.audio_path

    def _get_voice_clone_prompt(self, request, ref_audio, ref_text):
        """Get or create voice clone prompt, with in-memory and disk caching."""
        cache_key = self._get_voice_cache_key(ref_audio, ref_text)
        if cache_key not in self._voice_clone_cache:
            # Check disk cache first (if enabled)
            disk_cached = self._get_cached_voice_clone_prompt_from_disk(
                ref_audio, ref_text
            )
            if disk_cached is not None:
                self._voice_clone_cache[cache_key] = disk_cached
            else:
                # Create new prompt
                print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr)
                try:
                    prompt_items = self.model.create_voice_clone_prompt(
                        ref_audio=ref_audio,
                        ref_text=ref_text,
                        x_vector_only_mode=self.options.get(
                            "x_vector_only_mode", False
                        ),
                    )
                    self._voice_clone_cache[cache_key] = prompt_items
                    # Save to disk cache if enabled
                    self._save_voice_clone_prompt_to_disk(
                        ref_audio, ref_text, prompt_items
                    )
                except Exception as e:
                    print(f"Error creating voice clone prompt: {e}", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    return None
        return self._voice_clone_cache[cache_key]

    def _is_text_file_path(self, text):
        """Check if the text is a file path to a text file."""
        if not text or not isinstance(text, str):
            return False
        # Check if it looks like a file path (contains / or \ and ends with common text file extensions)
        text_extensions = [".txt", ".md", ".rst", ".text"]
        has_path_separator = "/" in text or "\\" in text
        ends_with_text_ext = any(text.lower().endswith(ext) for ext in text_extensions)
        return has_path_separator and ends_with_text_ext

    def _read_text_file(self, file_path):
        """Read text content from a file path, resolving relative paths."""
        try:
            # If absolute path, use as-is
            if os.path.isabs(file_path):
                resolved_path = file_path
            else:
                # Try relative to ModelFile
                if self.model_file:
                    model_file_base = os.path.dirname(self.model_file)
                    candidate_path = os.path.join(model_file_base, file_path)
                    if os.path.exists(candidate_path):
                        resolved_path = candidate_path
                    else:
                        resolved_path = file_path
                else:
                    resolved_path = file_path

                # Try relative to ModelPath
                if not os.path.exists(resolved_path) and self.model_path:
                    candidate_path = os.path.join(self.model_path, file_path)
                    if os.path.exists(candidate_path):
                        resolved_path = candidate_path

            # Check if file exists and is readable
            if not os.path.exists(resolved_path):
                print(
                    f"[ERROR] ref_text file not found: {resolved_path}", file=sys.stderr
                )
                return None
            if not os.path.isfile(resolved_path):
                print(
                    f"[ERROR] ref_text path is not a file: {resolved_path}",
                    file=sys.stderr,
                )
                return None

            # Read and return file contents
            with open(resolved_path, "r", encoding="utf-8") as f:
                content = f.read().strip()
            print(
                f"[INFO] Successfully read ref_text from file: {resolved_path}",
                file=sys.stderr,
            )
            return content
        except Exception as e:
            print(
                f"[ERROR] Failed to read ref_text file {file_path}: {e}",
                file=sys.stderr,
            )
            print(traceback.format_exc(), file=sys.stderr)
            return None

    def _compute_file_hash(self, file_path):
        """Compute SHA256 hash of file content."""
        try:
            sha256 = hashlib.sha256()
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(4096), b""):
                    sha256.update(chunk)
            return sha256.hexdigest()
        except Exception as e:
            print(
                f"[ERROR] Failed to compute hash for {file_path}: {e}", file=sys.stderr
            )
            return None

    def _compute_string_hash(self, text):
        """Compute SHA256 hash of string."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def _get_cached_voice_clone_prompt_from_disk(self, ref_audio, ref_text_content):
        """Load cached prompt from disk if available and valid."""
        if not self.options.get("disk_cache", False):
            return None

        cache_file = f"{ref_audio}.voice_cache.pkl"
        if not os.path.exists(cache_file):
            return None

        try:
            # NOTE(security): pickle.load runs arbitrary code from the cache
            # file; only enable disk_cache in trusted model directories.
            with open(cache_file, "rb") as f:
                cached = pickle.load(f)

            # Validate checksums
            current_audio_hash = self._compute_file_hash(ref_audio)
            current_text_hash = self._compute_string_hash(ref_text_content)

            if current_audio_hash is None or cached["audio_hash"] != current_audio_hash:
                print("[INFO] Cache invalidation: audio file changed", file=sys.stderr)
                os.remove(cache_file)
                return None
            if cached["ref_text_hash"] != current_text_hash:
                print(
                    "[INFO] Cache invalidation: ref_text content changed",
                    file=sys.stderr,
                )
                os.remove(cache_file)
                return None

            print(
                f"[INFO] Loaded voice clone prompt from disk cache: {cache_file}",
                file=sys.stderr,
            )
            return cached["prompt_items"]
        except Exception as e:
            print(
                f"[WARNING] Failed to load disk cache {cache_file}: {e}",
                file=sys.stderr,
            )
            return None

    def _save_voice_clone_prompt_to_disk(
        self, ref_audio, ref_text_content, prompt_items
    ):
        """Save prompt to disk cache alongside audio file."""
        if not self.options.get("disk_cache", False):
            return

        cache_file = f"{ref_audio}.voice_cache.pkl"
        try:
            cache_data = {
                "audio_hash": self._compute_file_hash(ref_audio),
                "ref_text_hash": self._compute_string_hash(ref_text_content),
                "prompt_items": prompt_items,
            }
            with open(cache_file, "wb") as f:
                pickle.dump(cache_data, f)
            print(
                f"[INFO] Saved voice clone prompt to disk cache: {cache_file}",
                file=sys.stderr,
            )
        except Exception as e:
            print(
                f"[WARNING] Failed to save disk cache {cache_file}: {e}",
                file=sys.stderr,
            )

    def _get_voice_cache_key(self, ref_audio, ref_text):
        """Get the cache key for a voice."""
        return f"{ref_audio}:{ref_text}"

    def _preload_cached_voices(self):
        """Pre-load cached voice prompts at model startup."""
        if not self.voices or not self.options.get("disk_cache", False):
            return

        print(
            f"[INFO] Pre-loading {len(self.voices)} cached voice(s)...", file=sys.stderr
        )
        loaded_count = 0
        missing_count = 0
        invalid_count = 0

        for voice_name, voice_config in self.voices.items():
            audio_path = voice_config["audio"]
            ref_text_path = voice_config["ref_text"]

            # Check for cache file
            cache_file = f"{audio_path}.voice_cache.pkl"
            if os.path.exists(cache_file):
                # Read ref_text content for validation
                ref_text_content = self._read_text_file(ref_text_path)
                if ref_text_content is None:
                    invalid_count += 1
                    print(
                        f"[INFO] Cannot read ref_text for {voice_name} (will recreate on first use)",
                        file=sys.stderr,
                    )
                    continue

                cached_prompt = self._get_cached_voice_clone_prompt_from_disk(
                    audio_path, ref_text_content
                )
                if cached_prompt:
                    # Pre-populate memory cache with content-based key
                    cache_key = self._get_voice_cache_key(audio_path, ref_text_content)
                    self._voice_clone_cache[cache_key] = cached_prompt
                    loaded_count += 1
                    print(f"[INFO] Pre-loaded voice: {voice_name}", file=sys.stderr)
                else:
                    invalid_count += 1
                    print(
                        f"[INFO] Cache invalid for {voice_name} (will recreate on first use)",
                        file=sys.stderr,
                    )
            else:
                missing_count += 1
                print(
                    f"[INFO] No cache found for {voice_name} (will create on first use)",
                    file=sys.stderr,
                )

        # Summary line
        print(
            f"[INFO] Pre-loaded {loaded_count}/{len(self.voices)} voices ({missing_count} missing, {invalid_count} invalid)",
            file=sys.stderr,
        )

    def TTS(self, request, context):
        """Generate speech for `request.text` and write it to `request.dst`.

        Dispatches to VoiceClone / VoiceDesign / CustomVoice generation
        based on `_detect_mode`, applying any generation options parsed
        during LoadModel.
        """
        try:
            # Check if dst is provided
            if not request.dst:
                return backend_pb2.Result(
                    success=False, message="dst (output path) is required"
                )

            # Prepare text
            text = request.text.strip()
            if not text:
                return backend_pb2.Result(success=False, message="Text is empty")

            # Get language (auto-detect if not provided)
            language = (
                request.language
                if hasattr(request, "language") and request.language
                else None
            )
            if not language or language == "":
                language = "Auto"  # Auto-detect language

            # Detect mode
            mode = self._detect_mode(request)
            print(f"Detected mode: {mode}", file=sys.stderr)

            # Get generation parameters from options
            max_new_tokens = self.options.get("max_new_tokens", None)
            top_p = self.options.get("top_p", None)
            temperature = self.options.get("temperature", None)
            do_sample = self.options.get("do_sample", None)

            # Prepare generation kwargs
            generation_kwargs = {}
            if max_new_tokens is not None:
                generation_kwargs["max_new_tokens"] = max_new_tokens
            if top_p is not None:
                generation_kwargs["top_p"] = top_p
            if temperature is not None:
                generation_kwargs["temperature"] = temperature
            if do_sample is not None:
                generation_kwargs["do_sample"] = do_sample
            instruct = self.options.get("instruct", "")
            if instruct is not None and instruct != "":
                generation_kwargs["instruct"] = instruct

            # Generate audio based on mode
            if mode == "VoiceClone":
                # VoiceClone mode
                # Check if multi-voice mode is active (voices dict is populated)
                voice_name = None
                if self.voices:
                    # Get voice from request (priority) or options
                    voice_name = request.voice if request.voice else None
                    if not voice_name:
                        voice_name = self.options.get("voice", None)

                    # Validate voice exists
                    if voice_name and voice_name not in self.voices:
                        available_voices = ", ".join(sorted(self.voices.keys()))
                        return backend_pb2.Result(
                            success=False,
                            message=f"Voice '{voice_name}' not found. Available voices: {available_voices}",
                        )

                # Get reference audio path (with voice-specific lookup if in multi-voice mode)
                ref_audio = self._get_ref_audio_path(request, voice_name)
                if not ref_audio:
                    if voice_name:
                        return backend_pb2.Result(
                            success=False,
                            message=f"Audio path for voice '{voice_name}' could not be resolved",
                        )
                    else:
                        return backend_pb2.Result(
                            success=False,
                            message="AudioPath is required for VoiceClone mode",
                        )

                # Get reference text (from voice config if multi-voice, else from options/request)
                if voice_name and voice_name in self.voices:
                    ref_text_source = self.voices[voice_name]["ref_text"]
                else:
                    ref_text_source = self.options.get("ref_text", None)
                    if not ref_text_source:
                        # Try to get from request if available
                        if hasattr(request, "ref_text") and request.ref_text:
                            ref_text_source = request.ref_text

                if not ref_text_source:
                    # x_vector_only_mode doesn't require ref_text
                    if not self.options.get("x_vector_only_mode", False):
                        return backend_pb2.Result(
                            success=False,
                            message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)",
                        )

                # Determine if ref_text_source is a file path
                ref_text_is_file = ref_text_source and self._is_text_file_path(
                    ref_text_source
                )
                if ref_text_is_file:
                    ref_text_content = self._read_text_file(ref_text_source)
                    if ref_text_content is None:
                        return backend_pb2.Result(
                            success=False,
                            message=f"Failed to read ref_text from file: {ref_text_source}",
                        )
                    ref_text_source = ref_text_content
                    print(
                        f"[INFO] Loaded ref_text from file: {ref_text_content[:100]}...",
                        file=sys.stderr,
                    )

                # For caching: use the content as the key (since we've read the file if it was one)
                ref_text_for_cache = ref_text_source

                # Check if we should use cached prompt
                use_cached_prompt = self.options.get("use_cached_prompt", True)
                voice_clone_prompt = None
                if use_cached_prompt:
                    voice_clone_prompt = self._get_voice_clone_prompt(
                        request, ref_audio, ref_text_for_cache
                    )
                    if voice_clone_prompt is None:
                        return backend_pb2.Result(
                            success=False, message="Failed to create voice clone prompt"
                        )

                if voice_clone_prompt:
                    # Use cached prompt
                    start_time = time.time()
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        voice_clone_prompt=voice_clone_prompt,
                        **generation_kwargs,
                    )
                    generation_duration = time.time() - start_time
                    print(
                        f"[INFO] Voice clone generation completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
                        file=sys.stderr,
                        flush=True,
                    )
                else:
                    # Create prompt on-the-fly (only for non-file ref_text that wasn't cached)
                    start_time = time.time()
                    wavs, sr = self.model.generate_voice_clone(
                        text=text,
                        language=language,
                        ref_audio=ref_audio,
                        ref_text=ref_text_source,
                        x_vector_only_mode=self.options.get(
                            "x_vector_only_mode", False
                        ),
                        **generation_kwargs,
                    )
                    generation_duration = time.time() - start_time
                    print(
                        f"[INFO] Voice clone generation (on-the-fly) completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}",
                        file=sys.stderr,
                        flush=True,
                    )
            elif mode == "VoiceDesign":
                # VoiceDesign mode
                if not instruct:
                    return backend_pb2.Result(
                        success=False,
                        message="instruct option is required for VoiceDesign mode",
                    )
                wavs, sr = self.model.generate_voice_design(
                    text=text, language=language, **generation_kwargs
                )
            else:
                # CustomVoice mode (default)
                speaker = request.voice if request.voice else None
                if not speaker:
                    # Try to get from options
                    speaker = self.options.get("speaker", None)
                if not speaker:
                    # Use default speaker
                    speaker = "Vivian"
                    print(
                        f"No speaker specified, using default: {speaker}",
                        file=sys.stderr,
                    )

                # Validate speaker if model supports it
                if hasattr(self.model, "get_supported_speakers"):
                    try:
                        supported_speakers = self.model.get_supported_speakers()
                        if speaker not in supported_speakers:
                            print(
                                f"Warning: Speaker '{speaker}' not in supported list. Available: {supported_speakers}",
                                file=sys.stderr,
                            )
                            # Try to find a close match (case-insensitive)
                            speaker_lower = speaker.lower()
                            for sup_speaker in supported_speakers:
                                if sup_speaker.lower() == speaker_lower:
                                    speaker = sup_speaker
                                    print(
                                        f"Using matched speaker: {speaker}",
                                        file=sys.stderr,
                                    )
                                    break
                    except Exception as e:
                        print(
                            f"Warning: Could not get supported speakers: {e}",
                            file=sys.stderr,
                        )

                wavs, sr = self.model.generate_custom_voice(
                    text=text, language=language, speaker=speaker, **generation_kwargs
                )

            # Save output
            if wavs is not None and len(wavs) > 0:
                # wavs is a list, take first element
                audio_data = wavs[0] if isinstance(wavs, list) else wavs
                audio_duration = len(audio_data) / sr if sr > 0 else 0
                sf.write(request.dst, audio_data, sr)
                print(
                    f"Saved {audio_duration:.2f}s audio to {request.dst}",
                    file=sys.stderr,
                )
            else:
                return backend_pb2.Result(
                    success=False, message="No audio output generated"
                )
        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Unexpected {err=}, {type(err)=}"
            )

        return backend_pb2.Result(success=True)


def serve(address):
    """Start the gRPC server on `address` and block until terminated."""
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ("grpc.max_message_length", 50 * 1024 * 1024),  # 50MB
            ("grpc.max_send_message_length", 50 * 1024 * 1024),  # 50MB
            ("grpc.max_receive_message_length", 50 * 1024 * 1024),  # 50MB
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        # Log to stderr for consistency with the rest of this backend
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)