From ec1598868b0f627bce2b8c36c71b3e78c596264c Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 27 Jan 2026 20:19:22 +0100 Subject: [PATCH] feat(vibevoice): add ASR support (#8222) * feat(vibevoice): add ASR support Signed-off-by: Ettore Di Giacinto * Add tests Signed-off-by: Ettore Di Giacinto * fixups Signed-off-by: Ettore Di Giacinto * chore(tests): download voice files Signed-off-by: Ettore Di Giacinto * Small fixups Signed-off-by: Ettore Di Giacinto * Small fixups Signed-off-by: Ettore Di Giacinto * Try to run on bigger runner Signed-off-by: Ettore Di Giacinto * Fixups Signed-off-by: Ettore Di Giacinto * Fixups Signed-off-by: Ettore Di Giacinto * Fixups Signed-off-by: Ettore Di Giacinto * debug Signed-off-by: Ettore Di Giacinto * CI can't hold vibevoice Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .github/workflows/test-extra.yml | 29 +- backend/python/vibevoice/Makefile | 39 +- backend/python/vibevoice/backend.py | 451 ++++++++++++++---- backend/python/vibevoice/install.sh | 14 +- backend/python/vibevoice/requirements-cpu.txt | 2 +- .../vibevoice/requirements-cublas12.txt | 2 +- .../vibevoice/requirements-cublas13.txt | 2 +- .../python/vibevoice/requirements-hipblas.txt | 2 +- .../python/vibevoice/requirements-intel.txt | 2 +- .../python/vibevoice/requirements-l4t12.txt | 2 +- .../python/vibevoice/requirements-l4t13.txt | 2 +- backend/python/vibevoice/requirements-mps.txt | 2 +- backend/python/vibevoice/test.py | 146 +++++- 13 files changed, 576 insertions(+), 119 deletions(-) diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml index 9ef7039e6..abb8fa376 100644 --- a/.github/workflows/test-extra.yml +++ b/.github/workflows/test-extra.yml @@ -238,7 +238,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential ffmpeg + sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch espeak espeak-ng python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh @@ -257,7 +257,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential ffmpeg + sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh @@ -276,7 +276,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential ffmpeg + sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh @@ -295,7 +295,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential ffmpeg + sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh @@ -303,4 +303,23 @@ jobs: - name: Test qwen-tts run: | make --jobs=5 --output-sync=target -C backend/python/qwen-tts - make --jobs=5 --output-sync=target -C backend/python/qwen-tts test \ No newline at end of file + make --jobs=5 --output-sync=target -C backend/python/qwen-tts test + # tests-vibevoice: + # runs-on: bigger-runner + # steps: + # - name: Clone + # uses: actions/checkout@v6 + # with: + # submodules: true + # - name: Dependencies + # run: | + # sudo apt-get update + # sudo apt-get install -y build-essential ffmpeg + # sudo apt-get install -y ca-certificates cmake curl 
patch python3-pip wget + # # Install UV + # curl -LsSf https://astral.sh/uv/install.sh | sh + # pip install --user --no-cache-dir --break-system-packages grpcio-tools==1.64.1 + # - name: Test vibevoice + # run: | + # make --jobs=5 --output-sync=target -C backend/python/vibevoice + # make --jobs=5 --output-sync=target -C backend/python/vibevoice test \ No newline at end of file diff --git a/backend/python/vibevoice/Makefile b/backend/python/vibevoice/Makefile index 2fd2297be..141f06f66 100644 --- a/backend/python/vibevoice/Makefile +++ b/backend/python/vibevoice/Makefile @@ -2,6 +2,43 @@ vibevoice: bash install.sh +.PHONY: download-voices +download-voices: + @echo "Downloading voice preset files..." + @mkdir -p voices/streaming_model + @if command -v wget >/dev/null 2>&1; then \ + wget -q -O voices/streaming_model/en-Frank_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Frank_man.pt && \ + wget -q -O voices/streaming_model/en-Grace_woman.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Grace_woman.pt && \ + wget -q -O voices/streaming_model/en-Mike_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Mike_man.pt && \ + wget -q -O voices/streaming_model/en-Emma_woman.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Emma_woman.pt && \ + wget -q -O voices/streaming_model/en-Carter_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Carter_man.pt && \ + wget -q -O voices/streaming_model/en-Davis_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Davis_man.pt && \ + echo "Voice files downloaded successfully"; \ + elif command -v curl >/dev/null 2>&1; then \ + curl -sL -o voices/streaming_model/en-Frank_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Frank_man.pt && \ + curl -sL -o voices/streaming_model/en-Grace_woman.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Grace_woman.pt && \ + curl -sL -o voices/streaming_model/en-Mike_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Mike_man.pt && \ + curl -sL -o voices/streaming_model/en-Emma_woman.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Emma_woman.pt && \ + curl -sL -o voices/streaming_model/en-Carter_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Carter_man.pt && \ + curl -sL -o voices/streaming_model/en-Davis_man.pt \ + https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Davis_man.pt && \ + echo "Voice files downloaded successfully"; \ + else \ + echo "Error: Neither wget nor curl found. Cannot download voice files."; \ + exit 1; \ + fi + .PHONY: run run: vibevoice @echo "Running vibevoice..." @@ -9,7 +46,7 @@ run: vibevoice @echo "vibevoice run." .PHONY: test -test: vibevoice +test: vibevoice download-voices @echo "Testing vibevoice..." bash test.sh @echo "vibevoice tested." 
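Note on configuration: the backend.py changes below key all of the new ASR behavior off a free-form `asr_mode` option. As a rough standalone sketch of how LocalAI-style "key:value" option strings reduce to the boolean the backend checks (an assumption distilled from the parsing shown in the diff, not the backend's exact code):

    def parse_options(options):
        """Reduce LocalAI-style "key:value" strings to a dict (sketch only)."""
        parsed = {}
        for opt in options:
            key, _, value = opt.partition(":")
            # booleans arrive as the strings "true"/"false"
            if value.lower() in ("true", "false"):
                parsed[key] = value.lower() == "true"
            else:
                parsed[key] = value
        return parsed

    # e.g. LoadModel(Options=["asr_mode:true"]) yields {"asr_mode": True},
    # which routes loading to the VibeVoiceASR* classes instead of the
    # streaming TTS ones.
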
diff --git a/backend/python/vibevoice/backend.py b/backend/python/vibevoice/backend.py index 418940bcb..2344188d2 100644 --- a/backend/python/vibevoice/backend.py +++ b/backend/python/vibevoice/backend.py @@ -16,6 +16,8 @@ import backend_pb2_grpc import torch from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor +from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration +from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor import grpc @@ -95,21 +97,72 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): value = value.lower() == "true" self.options[key] = value + # Check if ASR mode is enabled + self.asr_mode = self.options.get("asr_mode", False) + if not isinstance(self.asr_mode, bool): + # Handle string "true"/"false" case + self.asr_mode = str(self.asr_mode).lower() == "true" + # Get model path from request model_path = request.Model if not model_path: - model_path = "microsoft/VibeVoice-Realtime-0.5B" + if self.asr_mode: + model_path = "microsoft/VibeVoice-ASR" # Default ASR model + else: + model_path = "microsoft/VibeVoice-Realtime-0.5B" # Default TTS model - # Get inference steps from options, default to 5 + default_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 + + load_dtype = default_dtype + if "torch_dtype" in self.options: + torch_dtype_str = str(self.options["torch_dtype"]).lower() + if torch_dtype_str == "fp16": + load_dtype = torch.float16 + elif torch_dtype_str == "bf16": + load_dtype = torch.bfloat16 + elif torch_dtype_str == "fp32": + load_dtype = torch.float32 + # remove it from options after reading + del self.options["torch_dtype"] + + # Get inference steps from options, default to 5 (TTS only) self.inference_steps = self.options.get("inference_steps", 5) if not isinstance(self.inference_steps, int) or self.inference_steps <= 0: self.inference_steps = 5 - # Get cfg_scale from options, default to 1.5 + # Get cfg_scale from options, default to 1.5 (TTS only) self.cfg_scale = self.options.get("cfg_scale", 1.5) if not isinstance(self.cfg_scale, (int, float)) or self.cfg_scale <= 0: self.cfg_scale = 1.5 + # Get ASR generation parameters from options + self.max_new_tokens = self.options.get("max_new_tokens", 512) + if not isinstance(self.max_new_tokens, int) or self.max_new_tokens <= 0: + self.max_new_tokens = 512 + + self.temperature = self.options.get("temperature", 0.0) + if not isinstance(self.temperature, (int, float)) or self.temperature < 0: + self.temperature = 0.0 + + self.top_p = self.options.get("top_p", 1.0) + if not isinstance(self.top_p, (int, float)) or self.top_p <= 0: + self.top_p = 1.0 + + self.do_sample = self.options.get("do_sample", None) + if self.do_sample is None: + # Default: use sampling if temperature > 0 + self.do_sample = self.temperature > 0 + elif not isinstance(self.do_sample, bool): + self.do_sample = str(self.do_sample).lower() == "true" + + self.num_beams = self.options.get("num_beams", 1) + if not isinstance(self.num_beams, int) or self.num_beams < 1: + self.num_beams = 1 + + self.repetition_penalty = self.options.get("repetition_penalty", 1.0) + if not isinstance(self.repetition_penalty, (int, float)) or self.repetition_penalty <= 0: + self.repetition_penalty = 1.0 + # Determine voices directory # Priority order: # 1. 
voices_dir option (explicitly set by user - highest priority) @@ -163,91 +216,151 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): else: voices_dir = None + # Initialize voice-related attributes (TTS only) self.voices_dir = voices_dir self.voice_presets = {} self._voice_cache = {} self.default_voice_key = None - - # Load voice presets if directory exists - if self.voices_dir and os.path.exists(self.voices_dir): - self._load_voice_presets() - else: - print(f"Warning: Voices directory not found. Voice presets will not be available.", file=sys.stderr) + + # Store AudioPath, ModelFile, and ModelPath from LoadModel request for use in TTS + self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None + self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None + self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None + + # Decide attention implementation and device_map (matching upstream example) + if self.device == "mps": + device_map = None + attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS + elif self.device == "cuda": + device_map = "cuda" + attn_impl_primary = "flash_attention_2" + else: # cpu + device_map = "cpu" # Match upstream example: use "cpu" for CPU device_map + attn_impl_primary = "sdpa" try: - print(f"Loading processor & model from {model_path}", file=sys.stderr) - self.processor = VibeVoiceStreamingProcessor.from_pretrained(model_path) + if self.asr_mode: + # Load ASR model and processor + print(f"Loading ASR processor & model from {model_path}", file=sys.stderr) + + # Load ASR processor + self.processor = VibeVoiceASRProcessor.from_pretrained( + model_path, + language_model_pretrained_name="Qwen/Qwen2.5-7B" + ) - # Decide dtype & attention implementation - if self.device == "mps": - load_dtype = torch.float32 # MPS requires float32 - device_map = None - attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS - elif self.device == "cuda": - load_dtype = torch.bfloat16 - device_map = "cuda" - attn_impl_primary = "flash_attention_2" - else: # cpu - load_dtype = torch.float32 - device_map = "cpu" - attn_impl_primary = "sdpa" + print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr) - print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr) - - # Load model with device-specific logic - try: - if self.device == "mps": - self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + # Load ASR model - use device_map=None and move manually to avoid JSON serialization issues + # Load with dtype to ensure all components are in correct dtype from the start + try: + print(f"Using attention implementation: {attn_impl_primary}", file=sys.stderr) + # Load model with dtype to ensure all components are in correct dtype + self.model = VibeVoiceASRForConditionalGeneration.from_pretrained( model_path, - torch_dtype=load_dtype, + dtype=load_dtype, + device_map=None, # Always use None, move manually to avoid JSON serialization issues attn_implementation=attn_impl_primary, - device_map=None, # load then move + trust_remote_code=True ) - self.model.to("mps") - elif self.device == "cuda": - self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( - model_path, - torch_dtype=load_dtype, - device_map="cuda", - attn_implementation=attn_impl_primary, - ) - else: # 
cpu - self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( - model_path, - torch_dtype=load_dtype, - device_map="cpu", - attn_implementation=attn_impl_primary, - ) - except Exception as e: - if attn_impl_primary == 'flash_attention_2': - print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.", file=sys.stderr) - self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( - model_path, - torch_dtype=load_dtype, - device_map=(self.device if self.device in ("cuda", "cpu") else None), - attn_implementation='sdpa' - ) - if self.device == "mps": - self.model.to("mps") - else: - raise e + # Move to device manually + self.model = self.model.to(self.device) + except Exception as e: + if attn_impl_primary == 'flash_attention_2': + print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print("Error loading the ASR model. Trying to use SDPA.", file=sys.stderr) + self.model = VibeVoiceASRForConditionalGeneration.from_pretrained( + model_path, + dtype=load_dtype, + device_map=None, + attn_implementation='sdpa', + trust_remote_code=True + ) + # Move to device manually + self.model = self.model.to(self.device) + else: + raise e - self.model.eval() - self.model.set_ddpm_inference_steps(num_steps=self.inference_steps) - - # Set default voice key - if self.voice_presets: - # Try to get default from environment or use first available - preset_name = os.environ.get("VOICE_PRESET") - self.default_voice_key = self._determine_voice_key(preset_name) - print(f"Default voice preset: {self.default_voice_key}", file=sys.stderr) + self.model.eval() + print(f"ASR model loaded successfully", file=sys.stderr) else: - print("Warning: No voice presets available. Voice selection will not work.", file=sys.stderr) + # Load TTS model and processor (existing logic) + # Load voice presets if directory exists + if self.voices_dir and os.path.exists(self.voices_dir): + self._load_voice_presets() + else: + print(f"Warning: Voices directory not found. 
Voice presets will not be available.", file=sys.stderr) + + print(f"Loading TTS processor & model from {model_path}", file=sys.stderr) + self.processor = VibeVoiceStreamingProcessor.from_pretrained(model_path) + + + print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr) + + # Load model with device-specific logic (matching upstream example exactly) + try: + if self.device == "mps": + self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + model_path, + torch_dtype=load_dtype, + attn_implementation=attn_impl_primary, + device_map=None, # load then move + ) + self.model.to("mps") + elif self.device == "cuda": + self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + model_path, + torch_dtype=load_dtype, + device_map=device_map, + attn_implementation=attn_impl_primary, + ) + else: # cpu + # Match upstream example: use device_map="cpu" for CPU + self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + model_path, + torch_dtype=load_dtype, + device_map="cpu", + attn_implementation=attn_impl_primary, + ) + except Exception as e: + if attn_impl_primary == 'flash_attention_2': + print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.", file=sys.stderr) + # Match upstream example fallback pattern + self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( + model_path, + torch_dtype=load_dtype, + device_map=(self.device if self.device in ("cuda", "cpu") else None), + attn_implementation='sdpa' + ) + if self.device == "mps": + self.model.to("mps") + else: + raise e + + self.model.eval() + self.model.set_ddpm_inference_steps(num_steps=self.inference_steps) + + # Set default voice key + if self.voice_presets: + # Try to get default from environment or use first available + preset_name = os.environ.get("VOICE_PRESET") + self.default_voice_key = self._determine_voice_key(preset_name) + print(f"Default voice preset: {self.default_voice_key}", file=sys.stderr) + else: + print("Warning: No voice presets available. 
Voice selection will not work.", file=sys.stderr) except Exception as err: - return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + # Format error message safely, avoiding JSON serialization issues + error_msg = str(err) + error_type = type(err).__name__ + # Include traceback for debugging + tb_str = traceback.format_exc() + print(f"[ERROR] LoadModel failed: {error_type}: {error_msg}", file=sys.stderr) + print(tb_str, file=sys.stderr) + return backend_pb2.Result(success=False, message=f"{error_type}: {error_msg}") return backend_pb2.Result(message="Model loaded successfully", success=True) @@ -327,14 +440,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): if not voice_path or not os.path.exists(voice_path): return None + # Ensure cache exists (should be initialized in LoadModel) + if not hasattr(self, '_voice_cache'): + self._voice_cache = {} + # Use path as cache key if voice_path not in self._voice_cache: print(f"Loading prefilled prompt from {voice_path}", file=sys.stderr) - prefilled_outputs = torch.load( - voice_path, - map_location=self._torch_device, - weights_only=False, - ) + # Match self-test.py: use string device name for map_location + # Ensure self.device exists (should be set in LoadModel) + try: + if not hasattr(self, 'device'): + # Fallback to CPU if device not set + device_str = "cpu" + else: + device_str = str(self.device) + except AttributeError as e: + print(f"Error accessing self.device: {e}, falling back to CPU", file=sys.stderr) + device_str = "cpu" + if device_str != "cpu": + map_loc = device_str + else: + map_loc = "cpu" + # Call torch.load with explicit arguments + prefilled_outputs = torch.load(voice_path, map_location=map_loc, weights_only=False) self._voice_cache[voice_path] = prefilled_outputs return self._voice_cache[voice_path] @@ -351,17 +480,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): voice_path = self._get_voice_path(request.voice) if voice_path: voice_key = request.voice - elif request.AudioPath: - # Use AudioPath as voice file - if os.path.isabs(request.AudioPath): - voice_path = request.AudioPath - elif request.ModelFile: - model_file_base = os.path.dirname(request.ModelFile) - voice_path = os.path.join(model_file_base, request.AudioPath) - elif hasattr(request, 'ModelPath') and request.ModelPath: - voice_path = os.path.join(request.ModelPath, request.AudioPath) + elif self.audio_path: + # Use AudioPath from LoadModel as voice file + if os.path.isabs(self.audio_path): + voice_path = self.audio_path + elif self.model_file: + model_file_base = os.path.dirname(self.model_file) + voice_path = os.path.join(model_file_base, self.audio_path) + elif self.model_path: + voice_path = os.path.join(self.model_path, self.audio_path) else: - voice_path = request.AudioPath + voice_path = self.audio_path elif self.default_voice_key: voice_path = self._get_voice_path(self.default_voice_key) voice_key = self.default_voice_key @@ -404,8 +533,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return_attention_mask=True, ) - # Move tensors to target device - target_device = self._torch_device + # Move tensors to target device (matching self-test.py exactly) + # Explicitly ensure it's a string to avoid any variable name collisions + target_device = str(self.device) if str(self.device) != "cpu" else "cpu" for k, v in inputs.items(): if torch.is_tensor(v): inputs[k] = v.to(target_device) @@ -447,6 +577,147 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=True) 
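+    # New ASR entrypoint added by this patch: the input audio path arrives
+    # via TranscriptRequest.dst, and VibeVoice's structured segments are
+    # mapped onto backend_pb2.TranscriptSegment (start/end converted to ms).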
+ def AudioTranscription(self, request, context): + """Transcribe audio file to text using ASR model.""" + try: + # Validate ASR mode is active + if not self.asr_mode: + return backend_pb2.TranscriptResult( + segments=[], + text="", + ) + # Note: We return empty result instead of error to match faster-whisper behavior + + # Get audio file path + audio_path = request.dst + if not audio_path or not os.path.exists(audio_path): + print(f"Error: Audio file not found: {audio_path}", file=sys.stderr) + return backend_pb2.TranscriptResult( + segments=[], + text="", + ) + + print(f"Transcribing audio file: {audio_path}", file=sys.stderr) + + # Get context_info from options if available + context_info = self.options.get("context_info", None) + if context_info and isinstance(context_info, str) and context_info.strip(): + context_info = context_info.strip() + else: + context_info = None + + # Process audio with ASR processor (matching gradio example) + inputs = self.processor( + audio=audio_path, + sampling_rate=None, + return_tensors="pt", + add_generation_prompt=True, + context_info=context_info + ) + + # Move to device (matching gradio example) + inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in inputs.items()} + + # Prepare generation config (matching gradio example) + generation_config = { + "max_new_tokens": self.max_new_tokens, + "temperature": self.temperature if self.temperature > 0 else None, + "top_p": self.top_p if self.do_sample else None, + "do_sample": self.do_sample, + "num_beams": self.num_beams, + "repetition_penalty": self.repetition_penalty, + "pad_token_id": self.processor.pad_id, + "eos_token_id": self.processor.tokenizer.eos_token_id, + } + + # Remove None values (matching gradio example) + generation_config = {k: v for k, v in generation_config.items() if v is not None} + + print(f"Generating transcription with max_new_tokens: {self.max_new_tokens}, temperature: {self.temperature}, do_sample: {self.do_sample}, num_beams: {self.num_beams}, repetition_penalty: {self.repetition_penalty}", file=sys.stderr) + + # Generate transcription (matching gradio example) + with torch.no_grad(): + output_ids = self.model.generate( + **inputs, + **generation_config + ) + + # Decode output (matching gradio example) + generated_ids = output_ids[0, inputs['input_ids'].shape[1]:] + generated_text = self.processor.decode(generated_ids, skip_special_tokens=True) + + # Parse structured output to get segments + result_segments = [] + try: + transcription_segments = self.processor.post_process_transcription(generated_text) + + if transcription_segments: + # Map segments to TranscriptSegment format + for idx, seg in enumerate(transcription_segments): + # Extract timing information (if available) + # Handle both dict and object with attributes + if isinstance(seg, dict): + start_time = seg.get('start_time', 0) + end_time = seg.get('end_time', 0) + text = seg.get('text', '') + speaker_id = seg.get('speaker_id', None) + else: + # Handle object with attributes + start_time = getattr(seg, 'start_time', 0) + end_time = getattr(seg, 'end_time', 0) + text = getattr(seg, 'text', '') + speaker_id = getattr(seg, 'speaker_id', None) + + # Convert time to milliseconds (assuming seconds) + start_ms = int(start_time * 1000) if isinstance(start_time, (int, float)) else 0 + end_ms = int(end_time * 1000) if isinstance(end_time, (int, float)) else 0 + + # Add speaker info to text if available + if speaker_id is not None: + text = f"[Speaker {speaker_id}] {text}" + + 
result_segments.append(backend_pb2.TranscriptSegment( + id=idx, + start=start_ms, + end=end_ms, + text=text, + tokens=[] # Token IDs not extracted for now + )) + except Exception as e: + print(f"Warning: Failed to parse structured output: {e}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # Fallback: create a single segment with the full text + if generated_text: + result_segments.append(backend_pb2.TranscriptSegment( + id=0, + start=0, + end=0, + text=generated_text, + tokens=[] + )) + + # Combine all segment texts into full transcription + if result_segments: + full_text = " ".join([seg.text for seg in result_segments]) + else: + full_text = generated_text if generated_text else "" + + print(f"Transcription completed: {len(result_segments)} segments", file=sys.stderr) + + return backend_pb2.TranscriptResult( + segments=result_segments, + text=full_text + ) + + except Exception as err: + print(f"Error in AudioTranscription: {err}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return backend_pb2.TranscriptResult( + segments=[], + text="", + ) + def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ diff --git a/backend/python/vibevoice/install.sh b/backend/python/vibevoice/install.sh index 3f669c6fe..199c46e42 100755 --- a/backend/python/vibevoice/install.sh +++ b/backend/python/vibevoice/install.sh @@ -29,11 +29,13 @@ fi installRequirements -git clone https://github.com/microsoft/VibeVoice.git -cd VibeVoice/ +if [ ! -d VibeVoice ]; then + git clone https://github.com/microsoft/VibeVoice.git + cd VibeVoice/ -if [ "x${USE_PIP}" == "xtrue" ]; then - pip install ${EXTRA_PIP_INSTALL_FLAGS:-} . -else - uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} . + if [ "x${USE_PIP}" == "xtrue" ]; then + pip install ${EXTRA_PIP_INSTALL_FLAGS:-} . + else + uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} . 
+ fi fi \ No newline at end of file diff --git a/backend/python/vibevoice/requirements-cpu.txt b/backend/python/vibevoice/requirements-cpu.txt index 607db4ae3..553394261 100644 --- a/backend/python/vibevoice/requirements-cpu.txt +++ b/backend/python/vibevoice/requirements-cpu.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cpu git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 torchvision==0.22.1 accelerate compel diff --git a/backend/python/vibevoice/requirements-cublas12.txt b/backend/python/vibevoice/requirements-cublas12.txt index 267a0313e..c8973eb3e 100644 --- a/backend/python/vibevoice/requirements-cublas12.txt +++ b/backend/python/vibevoice/requirements-cublas12.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu121 git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 torchvision accelerate compel diff --git a/backend/python/vibevoice/requirements-cublas13.txt b/backend/python/vibevoice/requirements-cublas13.txt index 372be740b..b75b15728 100644 --- a/backend/python/vibevoice/requirements-cublas13.txt +++ b/backend/python/vibevoice/requirements-cublas13.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu130 git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 torchvision accelerate compel diff --git a/backend/python/vibevoice/requirements-hipblas.txt b/backend/python/vibevoice/requirements-hipblas.txt index 291096c3f..931dd1e0a 100644 --- a/backend/python/vibevoice/requirements-hipblas.txt +++ b/backend/python/vibevoice/requirements-hipblas.txt @@ -3,7 +3,7 @@ torch==2.7.1+rocm6.3 torchvision==0.22.1+rocm6.3 git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 accelerate compel peft diff --git a/backend/python/vibevoice/requirements-intel.txt b/backend/python/vibevoice/requirements-intel.txt index af061781b..7dee3464f 100644 --- a/backend/python/vibevoice/requirements-intel.txt +++ b/backend/python/vibevoice/requirements-intel.txt @@ -5,7 +5,7 @@ optimum[openvino] setuptools git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 accelerate compel peft diff --git a/backend/python/vibevoice/requirements-l4t12.txt b/backend/python/vibevoice/requirements-l4t12.txt index 4e033c0f6..e959f0cb2 100644 --- a/backend/python/vibevoice/requirements-l4t12.txt +++ b/backend/python/vibevoice/requirements-l4t12.txt @@ -1,7 +1,7 @@ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch git+https://github.com/huggingface/diffusers -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 accelerate compel peft diff --git a/backend/python/vibevoice/requirements-l4t13.txt b/backend/python/vibevoice/requirements-l4t13.txt index ca4848d50..8e67d8503 100644 --- a/backend/python/vibevoice/requirements-l4t13.txt +++ b/backend/python/vibevoice/requirements-l4t13.txt @@ -1,7 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu130 torch git+https://github.com/huggingface/diffusers -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 accelerate compel peft diff --git a/backend/python/vibevoice/requirements-mps.txt b/backend/python/vibevoice/requirements-mps.txt index 11757190e..41f8a0b65 100644 --- a/backend/python/vibevoice/requirements-mps.txt +++ b/backend/python/vibevoice/requirements-mps.txt @@ -2,7 +2,7 @@ torch==2.7.1 torchvision==0.22.1 
git+https://github.com/huggingface/diffusers opencv-python -transformers==4.51.3 +transformers>=4.51.3,<5.0.0 accelerate compel peft diff --git a/backend/python/vibevoice/test.py b/backend/python/vibevoice/test.py index e0b1a0bdd..95d38f562 100644 --- a/backend/python/vibevoice/test.py +++ b/backend/python/vibevoice/test.py @@ -1,14 +1,21 @@ """ -A test script to test the gRPC service +A test script to test the gRPC service for VibeVoice TTS and ASR """ import unittest import subprocess import time +import os +import tempfile +import shutil import backend_pb2 import backend_pb2_grpc import grpc +# Check if we should skip ASR tests (they require large models ~14B parameters total) +# Skip in CI or if explicitly disabled +SKIP_ASR_TESTS = os.environ.get("SKIP_ASR_TESTS", "false").lower() == "true" + class TestBackendServicer(unittest.TestCase): """ @@ -44,15 +51,15 @@ class TestBackendServicer(unittest.TestCase): finally: self.tearDown() - def test_load_model(self): + def test_load_tts_model(self): """ - This method tests if the model is loaded successfully + This method tests if the TTS model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits")) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="microsoft/VibeVoice-Realtime-0.5B")) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") @@ -62,21 +69,142 @@ class TestBackendServicer(unittest.TestCase): finally: self.tearDown() - def test_tts(self): + @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI") + def test_load_asr_model(self): """ - This method tests if the embeddings are generated successfully + This method tests if the ASR model is loaded successfully with asr_mode option """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) - response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits")) + response = stub.LoadModel(backend_pb2.ModelOptions( + Model="microsoft/VibeVoice-ASR", + Options=["asr_mode:true"] + )) + print(f"LoadModel response: {response}") + if not response.success: + print(f"LoadModel failed with message: {response.message}") + self.assertTrue(response.success, f"LoadModel failed: {response.message}") + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(f"Exception during LoadModel: {err}") + import traceback + traceback.print_exc() + self.fail("LoadModel service failed for ASR mode") + finally: + self.tearDown() + + def test_tts(self): + """ + This method tests if TTS generation works successfully + """ + # Create a temporary directory for the output audio file + temp_dir = tempfile.mkdtemp() + output_file = os.path.join(temp_dir, 'output.wav') + + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + # Load TTS model + response = stub.LoadModel(backend_pb2.ModelOptions(Model="microsoft/VibeVoice-Realtime-0.5B")) self.assertTrue(response.success) - tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story") + + # Generate TTS + tts_request = backend_pb2.TTSRequest( + text="Hello, this is a test of the VibeVoice text to speech system.", + dst=output_file + ) 
tts_response = stub.TTS(tts_request) + + # Verify response self.assertIsNotNone(tts_response) + self.assertTrue(tts_response.success) + + # Verify output file was created + self.assertTrue(os.path.exists(output_file), f"Output file was not created: {output_file}") + self.assertGreater(os.path.getsize(output_file), 0, "Output file is empty") except Exception as err: print(err) self.fail("TTS service failed") finally: - self.tearDown() \ No newline at end of file + self.tearDown() + # Clean up the temporary directory + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI") + def test_audio_transcription(self): + """ + This method tests if audio transcription works successfully + """ + # Create a temporary directory for the audio file + temp_dir = tempfile.mkdtemp() + audio_file = os.path.join(temp_dir, 'audio.wav') + + try: + # Download the audio file to the temporary directory + print(f"Downloading audio file to {audio_file}...") + url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav" + result = subprocess.run( + ["wget", "-q", url, "-O", audio_file], + capture_output=True, + text=True + ) + if result.returncode != 0: + self.fail(f"Failed to download audio file: {result.stderr}") + + # Verify the file was downloaded + if not os.path.exists(audio_file): + self.fail(f"Audio file was not downloaded to {audio_file}") + + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + # Load the ASR model first + load_response = stub.LoadModel(backend_pb2.ModelOptions( + Model="microsoft/VibeVoice-ASR", + Options=["asr_mode:true"] + )) + print(f"LoadModel response: {load_response}") + if not load_response.success: + print(f"LoadModel failed with message: {load_response.message}") + self.assertTrue(load_response.success, f"LoadModel failed: {load_response.message}") + + # Perform transcription + transcript_request = backend_pb2.TranscriptRequest(dst=audio_file) + transcript_response = stub.AudioTranscription(transcript_request) + + # Print the transcribed text for debugging + print(f"Transcribed text: {transcript_response.text}") + print(f"Number of segments: {len(transcript_response.segments)}") + + # Verify response structure + self.assertIsNotNone(transcript_response) + self.assertIsNotNone(transcript_response.text) + # Protobuf repeated fields return a sequence, not a list + self.assertIsNotNone(transcript_response.segments) + # Check if segments is iterable (has length) + self.assertGreaterEqual(len(transcript_response.segments), 0) + + # Verify the transcription contains some text + self.assertGreater(len(transcript_response.text), 0, "Transcription should not be empty") + + # If we got segments, verify they have the expected structure + if len(transcript_response.segments) > 0: + segment = transcript_response.segments[0] + self.assertIsNotNone(segment.text) + self.assertIsInstance(segment.id, int) + else: + # Even if no segments, we should have text + self.assertIsNotNone(transcript_response.text) + self.assertGreater(len(transcript_response.text), 0) + except Exception as err: + print(err) + self.fail("AudioTranscription service failed") + finally: + self.tearDown() + # Clean up the temporary directory + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) \ No newline at end of file
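
For reference, a minimal gRPC client exercising the new ASR path end to end, distilled from test_audio_transcription above (the address and audio path are placeholders, and the generated backend_pb2/backend_pb2_grpc stubs are assumed to be importable):

    import grpc
    import backend_pb2
    import backend_pb2_grpc

    def transcribe(audio_path, address="localhost:50051"):
        """Load the VibeVoice ASR model and transcribe one file (sketch)."""
        with grpc.insecure_channel(address) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            # "asr_mode:true" selects the ASR model/processor in LoadModel
            load = stub.LoadModel(backend_pb2.ModelOptions(
                Model="microsoft/VibeVoice-ASR",
                Options=["asr_mode:true"],
            ))
            assert load.success, load.message
            # TranscriptRequest.dst carries the input audio path for this backend
            result = stub.AudioTranscription(
                backend_pb2.TranscriptRequest(dst=audio_path))
            for seg in result.segments:
                print(seg.id, seg.start, seg.end, seg.text)
            return result.text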