From c8245d069dcdb9a30b13ed4339786170dd0e9bde Mon Sep 17 00:00:00 2001
From: eureka928
Date: Fri, 30 Jan 2026 10:10:05 +0100
Subject: [PATCH] feat(whisperx): add whisperx backend for transcription with
 diarization

Add Python gRPC backend using WhisperX for speech-to-text with
word-level timestamps, forced alignment, and speaker diarization via
pyannote-audio when HF_TOKEN is provided.

Signed-off-by: eureka928
---
 backend/python/whisperx/Makefile              |  16 ++
 backend/python/whisperx/backend.py            | 173 +++++++++++++++++
 backend/python/whisperx/install.sh            |  11 ++
 backend/python/whisperx/protogen.sh           |  11 ++
 backend/python/whisperx/requirements-cpu.txt  |   2 +
 .../python/whisperx/requirements-cublas12.txt |   2 +
 .../python/whisperx/requirements-cublas13.txt |   3 +
 .../python/whisperx/requirements-hipblas.txt  |   3 +
 backend/python/whisperx/requirements.txt      |   3 +
 backend/python/whisperx/run.sh                |   9 +
 backend/python/whisperx/test.py               | 124 +++++++++++++
 backend/python/whisperx/test.sh               |  11 ++
 12 files changed, 368 insertions(+)
 create mode 100644 backend/python/whisperx/Makefile
 create mode 100644 backend/python/whisperx/backend.py
 create mode 100755 backend/python/whisperx/install.sh
 create mode 100755 backend/python/whisperx/protogen.sh
 create mode 100644 backend/python/whisperx/requirements-cpu.txt
 create mode 100644 backend/python/whisperx/requirements-cublas12.txt
 create mode 100644 backend/python/whisperx/requirements-cublas13.txt
 create mode 100644 backend/python/whisperx/requirements-hipblas.txt
 create mode 100644 backend/python/whisperx/requirements.txt
 create mode 100755 backend/python/whisperx/run.sh
 create mode 100644 backend/python/whisperx/test.py
 create mode 100755 backend/python/whisperx/test.sh

diff --git a/backend/python/whisperx/Makefile b/backend/python/whisperx/Makefile
new file mode 100644
index 000000000..8ad2368ab
--- /dev/null
+++ b/backend/python/whisperx/Makefile
@@ -0,0 +1,16 @@
+.DEFAULT_GOAL := install
+
+.PHONY: install
+install:
+	bash install.sh
+
+.PHONY: protogen-clean
+protogen-clean:
+	$(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+	rm -rf venv __pycache__
+
+test: install
+	bash test.sh
diff --git a/backend/python/whisperx/backend.py b/backend/python/whisperx/backend.py
new file mode 100644
index 000000000..7fd5cfb42
--- /dev/null
+++ b/backend/python/whisperx/backend.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+"""
+This is an extra gRPC server of LocalAI for WhisperX transcription
+with speaker diarization, word-level timestamps, and forced alignment.
+"""
+from concurrent import futures
+import time
+import argparse
+import signal
+import sys
+import os
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+    """
+    BackendServicer is the class that implements the gRPC service
+    """
+    def Health(self, request, context):
+        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+    def LoadModel(self, request, context):
+        import whisperx
+        import torch
+
+        device = "cpu"
+        if request.CUDA:
+            device = "cuda"
+        # NOTE(review): ctranslate2 (WhisperX's inference engine) may not
+        # support MPS — confirm before enabling Apple-GPU selection here.
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if mps_available:
+            device = "mps"
+
+        try:
+            print("Preparing WhisperX model, please wait", file=sys.stderr)
+            compute_type = "float16" if device != "cpu" else "int8"
+            self.model = whisperx.load_model(
+                request.Model,
+                device,
+                compute_type=compute_type,
+            )
+            self.device = device
+            self.model_name = request.Model
+
+            # Store HF token for diarization if available
+            self.hf_token = os.environ.get("HF_TOKEN", None)
+            self.diarize_pipeline = None
+
+            # Cache for alignment models keyed by language code
+            self.align_cache = {}
+
+            print(f"WhisperX model loaded: {request.Model} on {device}", file=sys.stderr)
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        return backend_pb2.Result(message="Model loaded successfully", success=True)
+
+    def _get_align_model(self, language_code):
+        """Load or return cached alignment model for a given language."""
+        import whisperx
+        if language_code not in self.align_cache:
+            model_a, metadata = whisperx.load_align_model(
+                language_code=language_code,
+                device=self.device,
+            )
+            self.align_cache[language_code] = (model_a, metadata)
+        return self.align_cache[language_code]
+
+    def AudioTranscription(self, request, context):
+        import whisperx
+
+        resultSegments = []
+        text = ""
+        try:
+            audio = whisperx.load_audio(request.dst)
+
+            # Transcribe
+            transcript = self.model.transcribe(
+                audio,
+                batch_size=16,
+                language=request.language if request.language else None,
+            )
+
+            # Align for word-level timestamps
+            model_a, metadata = self._get_align_model(transcript["language"])
+            transcript = whisperx.align(
+                transcript["segments"],
+                model_a,
+                metadata,
+                audio,
+                self.device,
+                return_char_alignments=False,
+            )
+
+            # Diarize if requested and HF token is available
+            if request.diarize and self.hf_token:
+                if self.diarize_pipeline is None:
+                    self.diarize_pipeline = whisperx.DiarizationPipeline(
+                        use_auth_token=self.hf_token,
+                        device=self.device,
+                    )
+                diarize_segments = self.diarize_pipeline(audio)
+                transcript = whisperx.assign_word_speakers(diarize_segments, transcript)
+
+            # Build result segments
+            for idx, seg in enumerate(transcript["segments"]):
+                seg_text = seg.get("text", "")
+                start = int(seg.get("start", 0))
+                end = int(seg.get("end", 0))
+                speaker = seg.get("speaker", "")
+
+                resultSegments.append(backend_pb2.TranscriptSegment(
+                    id=idx,
+                    start=start,
+                    end=end,
+                    text=seg_text,
+                    speaker=speaker,
+                ))
+                text += seg_text
+
+        except Exception as err:
+            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
+            # Surface the failure to the client instead of silently
+            # returning an empty transcript (indistinguishable from
+            # successfully transcribing silent audio).
+            context.set_code(grpc.StatusCode.INTERNAL)
+            context.set_details(f"Transcription failed: {err}")
+            return backend_pb2.TranscriptResult(segments=[], text="")
+
+        return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
+
+def serve(address):
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+        options=[
+            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
+            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
+            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
+        ])
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    server.add_insecure_port(address)
+    server.start()
+    print("Server started. Listening on: " + address, file=sys.stderr)
+
+    # Define the signal handler function
+    def signal_handler(sig, frame):
+        print("Received termination signal. Shutting down...")
+        server.stop(0)
+        sys.exit(0)
+
+    # Set the signal handlers for SIGINT and SIGTERM
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run the gRPC server.")
+    parser.add_argument(
+        "--addr", default="localhost:50051", help="The address to bind the server to."
+    )
+    args = parser.parse_args()
+
+    serve(args.addr)
diff --git a/backend/python/whisperx/install.sh b/backend/python/whisperx/install.sh
new file mode 100755
index 000000000..4136d8765
--- /dev/null
+++ b/backend/python/whisperx/install.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+installRequirements
diff --git a/backend/python/whisperx/protogen.sh b/backend/python/whisperx/protogen.sh
new file mode 100755
index 000000000..1ad37dee1
--- /dev/null
+++ b/backend/python/whisperx/protogen.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
diff --git a/backend/python/whisperx/requirements-cpu.txt b/backend/python/whisperx/requirements-cpu.txt
new file mode 100644
index 000000000..9e9dd9f7d
--- /dev/null
+++ b/backend/python/whisperx/requirements-cpu.txt
@@ -0,0 +1,2 @@
+torch==2.4.1
+whisperx @ git+https://github.com/m-bain/whisperX.git
diff --git a/backend/python/whisperx/requirements-cublas12.txt b/backend/python/whisperx/requirements-cublas12.txt
new file mode 100644
index 000000000..9e9dd9f7d
--- /dev/null
+++ b/backend/python/whisperx/requirements-cublas12.txt
@@ -0,0 +1,2 @@
+torch==2.4.1
+whisperx @ git+https://github.com/m-bain/whisperX.git
diff --git a/backend/python/whisperx/requirements-cublas13.txt b/backend/python/whisperx/requirements-cublas13.txt
new file mode 100644
index 000000000..8a8507199
--- /dev/null
+++ b/backend/python/whisperx/requirements-cublas13.txt
@@ -0,0 +1,3 @@
+--extra-index-url https://download.pytorch.org/whl/cu130
+torch==2.9.1
+whisperx @ git+https://github.com/m-bain/whisperX.git
diff --git a/backend/python/whisperx/requirements-hipblas.txt b/backend/python/whisperx/requirements-hipblas.txt
new file mode 100644
index 000000000..9f7a1778d
--- /dev/null
+++ b/backend/python/whisperx/requirements-hipblas.txt
@@ -0,0 +1,3 @@
+--extra-index-url https://download.pytorch.org/whl/rocm6.4
+torch
+whisperx @ git+https://github.com/m-bain/whisperX.git
diff --git a/backend/python/whisperx/requirements.txt b/backend/python/whisperx/requirements.txt
new file mode 100644
index 000000000..44b40efd0
--- /dev/null
+++ b/backend/python/whisperx/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.71.0
+protobuf
+grpcio-tools
diff --git a/backend/python/whisperx/run.sh b/backend/python/whisperx/run.sh
new file mode 100755
index 000000000..eae121f37
--- /dev/null
+++ b/backend/python/whisperx/run.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+startBackend $@
diff --git a/backend/python/whisperx/test.py b/backend/python/whisperx/test.py
new file mode 100644
index 000000000..c2b4db8b5
--- /dev/null
+++ b/backend/python/whisperx/test.py
@@ -0,0 +1,124 @@
+"""
+A test script to test the gRPC service for WhisperX transcription
+"""
+import unittest
+import subprocess
+import time
+import os
+import tempfile
+import shutil
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+    """
+    TestBackendServicer is the class that tests the gRPC service
+    """
+    def setUp(self):
+        """
+        This method sets up the gRPC service by starting the server
+        """
+        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
+        time.sleep(10)
+
+    def tearDown(self) -> None:
+        """
+        This method tears down the gRPC service by terminating the server
+        """
+        self.service.terminate()
+        self.service.wait()
+
+    def test_server_startup(self):
+        """
+        This method tests if the server starts up successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.Health(backend_pb2.HealthMessage())
+                self.assertEqual(response.message, b'OK')
+        except Exception as err:
+            print(err)
+            self.fail("Server failed to start")
+        finally:
+            self.tearDown()
+
+    def test_load_model(self):
+        """
+        This method tests if the model is loaded successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tiny"))
+                self.assertTrue(response.success)
+                self.assertEqual(response.message, "Model loaded successfully")
+        except Exception as err:
+            print(err)
+            self.fail("LoadModel service failed")
+        finally:
+            self.tearDown()
+
+    def test_audio_transcription(self):
+        """
+        This method tests if audio transcription works successfully
+        """
+        # Create a temporary directory for the audio file
+        temp_dir = tempfile.mkdtemp()
+        audio_file = os.path.join(temp_dir, 'audio.wav')
+
+        try:
+            # Download the audio file to the temporary directory
+            print(f"Downloading audio file to {audio_file}...")
+            url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
+            result = subprocess.run(
+                ["wget", "-q", url, "-O", audio_file],
+                capture_output=True,
+                text=True
+            )
+            if result.returncode != 0:
+                self.fail(f"Failed to download audio file: {result.stderr}")
+
+            # Verify the file was downloaded
+            if not os.path.exists(audio_file):
+                self.fail(f"Audio file was not downloaded to {audio_file}")
+
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                # Load the model first
+                load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="tiny"))
+                self.assertTrue(load_response.success)
+
+                # Perform transcription without diarization
+                transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
+                transcript_response = stub.AudioTranscription(transcript_request)
+
+                # Print the transcribed text for debugging
+                print(f"Transcribed text: {transcript_response.text}")
+                print(f"Number of segments: {len(transcript_response.segments)}")
+
+                # Verify response structure
+                self.assertIsNotNone(transcript_response)
+                self.assertIsNotNone(transcript_response.text)
+                self.assertGreater(len(transcript_response.text), 0)
+                self.assertGreater(len(transcript_response.segments), 0)
+
+                # Verify segments have timing info
+                segment = transcript_response.segments[0]
+                self.assertIsNotNone(segment.text)
+                self.assertIsInstance(segment.id, int)
+
+        except Exception as err:
+            print(err)
+            self.fail("AudioTranscription service failed")
+        finally:
+            self.tearDown()
+            # Clean up the temporary directory
+            if os.path.exists(temp_dir):
+                shutil.rmtree(temp_dir)
diff --git a/backend/python/whisperx/test.sh b/backend/python/whisperx/test.sh
new file mode 100755
index 000000000..eb59f2aaf
--- /dev/null
+++ b/backend/python/whisperx/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests