Files
LocalAI/backend/python/moonshine/backend.py
2026-02-12 23:22:37 +01:00

162 lines
5.8 KiB
Python

#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Moonshine transcription
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
from moonshine_voice import (
Transcriber,
get_model_for_language,
load_wav_file,
)
import grpc
# Length of each sleep in the keep-alive loop that serve() runs after starting the server
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# Worker-thread count for the gRPC server's thread pool: read from the
# PYTHON_GRPC_MAX_WORKERS environment variable when it is set, otherwise 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    LoadModel builds a moonshine_voice Transcriber and stores it on the
    instance; AudioTranscription uses that stored transcriber to turn a WAV
    file into text plus per-line segments.
    """

    def __init__(self):
        # Populated by a successful LoadModel call; None means "no model yet".
        self.transcriber = None
        self.model_name = None
        # Defined up front so the attribute exists even before LoadModel runs.
        self.options = {}

    def Health(self, request, context):
        """
        Liveness probe: always replies with the bytes of "OK".
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    @staticmethod
    def _parse_options(raw_options):
        """
        Turn a list of "optname:optvalue" strings into a dict.

        Entries without a ":" are ignored. Only the first ":" splits, so the
        value may itself contain colons. Surrounding whitespace is stripped
        from both key and value so that "language: en" yields "en" rather
        than " en". A later duplicate key overrides an earlier one.
        """
        parsed = {}
        for opt in raw_options:
            key, separator, value = opt.partition(":")
            if not separator:
                continue
            parsed[key.strip()] = value.strip()
        return parsed

    @staticmethod
    def _log_failure(err):
        """
        Write the error summary and the active traceback to stderr.

        Must be called from inside an ``except`` block, because
        traceback.print_exc() reports the exception currently being handled.
        """
        import traceback
        print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
        traceback.print_exc()

    def LoadModel(self, request, context):
        """
        Resolve and load a Moonshine model, then store it on the servicer.

        Reads ``request.Model`` (kept as ``self.model_name``) and
        ``request.Options``. Recognised options:
          - ``language``: passed to get_model_for_language, default "en".
          - ``model_arch``: passed to get_model_for_language as the raw
            option string, default None.

        Returns a backend_pb2.Result with success=True on success. Any
        exception is logged to stderr and reported as success=False with the
        error text in ``message``; in that case the servicer's previous
        model_name/options/transcriber are left untouched.
        """
        try:
            print("Preparing models, please wait", file=sys.stderr)
            model_name = request.Model
            print(f"Model name set to: {model_name}", file=sys.stderr)

            options = self._parse_options(request.Options)
            print(f"Options: {options}", file=sys.stderr)

            language = options.get("language", "en")
            model_arch = options.get("model_arch")

            # get_model_for_language returns the model path together with the
            # architecture to use, which replaces the requested one.
            model_path, model_arch = get_model_for_language(language, model_arch)
            print(f"Loading model: {model_path} with architecture: {model_arch} for language: {language}", file=sys.stderr)

            transcriber = Transcriber(model_path=model_path, model_arch=model_arch)
            print("Model loaded successfully", file=sys.stderr)
        except Exception as err:
            # Log on the server as well as reporting to the caller, matching
            # what AudioTranscription does on failure.
            self._log_failure(err)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        # Commit only after everything succeeded, so a failed (re)load never
        # leaves new metadata paired with an old transcriber.
        self.model_name = model_name
        self.options = options
        self.transcriber = transcriber
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def AudioTranscription(self, request, context):
        """
        Transcribe the WAV file at ``request.dst`` with the loaded model.

        Returns a backend_pb2.TranscriptResult whose ``segments`` hold one
        TranscriptSegment per transcript line (id = line index, start/end as
        integers obtained by multiplying the line's start_time and
        start_time + duration by 1000) and whose ``text`` is the stripped
        line texts joined by single spaces.

        Best-effort contract: if no model is loaded or anything else raises,
        the error and traceback go to stderr and an empty result
        (segments=[], text="") is returned.
        NOTE(review): the caller therefore cannot tell a failure from silent
        audio — confirm that is intended.
        """
        resultSegments = []
        text = ""
        try:
            if self.transcriber is None:
                raise Exception("Model not loaded. Call LoadModel first.")

            audio_data, sample_rate = load_wav_file(request.dst)
            print(f"Loaded audio file: {request.dst} with sample rate: {sample_rate}", file=sys.stderr)

            # Whole-file (non-streaming) transcription
            transcript = self.transcriber.transcribe_without_streaming(
                audio_data, sample_rate=sample_rate, flags=0
            )

            full_text_parts = []
            for idx, line in enumerate(transcript.lines):
                line_text = line.text.strip()
                full_text_parts.append(line_text)

                line_end = line.start_time + line.duration
                resultSegments.append(backend_pb2.TranscriptSegment(
                    id=idx,
                    start=int(line.start_time * 1000),
                    end=int(line_end * 1000),
                    text=line_text
                ))
                print(f"Segment {idx}: [{line.start_time:.2f}s - {line_end:.2f}s] {line_text}", file=sys.stderr)

            # Combine all transcriptions into a single text
            text = " ".join(full_text_parts)
        except Exception as err:
            self._log_failure(err)
            # Discard any partially built segments: failures yield an empty result.
            return backend_pb2.TranscriptResult(segments=[], text="")

        return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
def serve(address):
    """
    Start the gRPC backend server on ``address`` and block until terminated.

    The server runs on a thread pool of MAX_WORKERS workers with 50MB
    message-size limits, registers a BackendServicer, and listens on an
    insecure port. SIGINT and SIGTERM stop the server immediately and exit
    the process with status 0.
    """
    message_limit = 50 * 1024 * 1024  # 50MB
    server_options = [
        ('grpc.max_message_length', message_limit),
        ('grpc.max_send_message_length', message_limit),
        ('grpc.max_receive_message_length', message_limit),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.server(worker_pool, options=server_options)

    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def handle_termination(sig, frame):
        # Stop without a grace period, then leave the process cleanly.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Same handler for both interrupt and terminate requests
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, handle_termination)

    # server.start() does not block, so keep the main thread alive here.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
if __name__ == "__main__":
    # Script entry point: read the bind address from --addr and run the server.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    serve(cli.parse_args().addr)