#!/usr/bin/env python3
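"""MLX-Audio TTS gRPC backend for LocalAI.

Implements the Backend service from backend.proto (Health, LoadModel, TTS,
and Predict) and generates speech with MLX-Audio models such as Kokoro.
"""
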
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
import shutil
import glob
from typing import List
import time
import tempfile

import backend_pb2
import backend_pb2_grpc

import grpc
from mlx_audio.tts.utils import load_model
import soundfile as sf
import numpy as np
import uuid


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If PYTHON_GRPC_MAX_WORKERS is set in the environment, use it; otherwise default to 1
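# For example, launching with PYTHON_GRPC_MAX_WORKERS=4 allows up to four worker threads.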
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
    """

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a TTS model using MLX-Audio.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the kokoro backend
            options = request.Options
            self.options = {}

            # The options are a list of strings of the form optname:optvalue.
            # We store all the options in a dict for later use.
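            # Example (illustrative values): Options = ["voice:af_heart", "speed:1.2", "lang_code:a"];
            # the keys consulted later in this backend are "voice", "speed", and "lang_code".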
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types
                # (check int before float so integer-valued strings keep their integer type)
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Load the model using MLX-Audio's load_model function
            try:
                self.tts_model = load_model(request.Model)
                self.model_path = request.Model
                print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
            except Exception as model_err:
                print(f"Error loading TTS model: {model_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")

        except Exception as err:
            print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")

        print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)

    def TTS(self, request, context):
        """
        Generates TTS audio from text using MLX-Audio.

        Args:
            request: A TTSRequest object containing text, model, destination, voice, and language.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object indicating success or failure.
        """
        try:
            # Check if model is loaded
            if not hasattr(self, 'tts_model') or self.tts_model is None:
                return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")

            print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)

            # Handle speed parameter based on model type
            speed_value = self._handle_speed_parameter(request, self.model_path)

            # Map language names to codes if needed
            lang_code = self._map_language_code(request.language, request.voice)

            # Prepare generation parameters
            gen_params = {
                "text": request.text,
                "speed": speed_value,
                "verbose": False,
            }

            # Add model-specific parameters
            if request.voice and request.voice.strip():
                gen_params["voice"] = request.voice

            # Check if model supports language codes (primarily Kokoro)
            if "kokoro" in self.model_path.lower():
                gen_params["lang_code"] = lang_code

            # Add pitch and gender for Spark models
            if "spark" in self.model_path.lower():
                gen_params["pitch"] = 1.0  # Default to moderate
                gen_params["gender"] = "female"  # Default to female

            print(f"Generation parameters: {gen_params}", file=sys.stderr)

            # Generate audio using the loaded model
            try:
                results = self.tts_model.generate(**gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")

            # Process the generated audio segments
            audio_arrays = []
            for segment in results:
                audio_arrays.append(segment.audio)

            # If no segments, return error
            if not audio_arrays:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Result(success=False, message="No audio generated")

            # Concatenate all segments
            cat_audio = np.concatenate(audio_arrays, axis=0)

            # Generate output filename and path
            if request.dst:
                output_path = request.dst
            else:
                unique_id = str(uuid.uuid4())
                filename = f"tts_{unique_id}.wav"
                output_path = filename

            # Write the audio as a WAV
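            # Note: the 24000 below assumes the loaded MLX-Audio model (e.g. Kokoro) emits
            # 24 kHz audio; adjust the sample rate if a different model is used.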
            try:
                sf.write(output_path, cat_audio, 24000)
                print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)

                # Verify the file exists and has content
                if not os.path.exists(output_path):
                    print(f"File was not created at {output_path}", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Failed to create audio file")

                file_size = os.path.getsize(output_path)
                if file_size == 0:
                    print("File was created but is empty", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Generated audio file is empty")

                print(f"Audio file size: {file_size} bytes", file=sys.stderr)

            except Exception as write_err:
                print(f"Error writing audio file: {write_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")

            return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")

        except Exception as e:
            print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")

    async def Predict(self, request, context):
        """
        Generates TTS audio based on the given prompt using MLX-Audio TTS.
        This is a fallback method for compatibility with the Predict endpoint.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        try:
            # Check if model is loaded
            if not hasattr(self, 'tts_model') or self.tts_model is None:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                context.set_details("TTS model not loaded. Please call LoadModel first.")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # For TTS, we expect the prompt to contain the text to synthesize
            if not request.Prompt:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("Prompt is required for TTS generation")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # Handle speed parameter based on model type
            speed_value = self._handle_speed_parameter(request, self.model_path)

            # Map language names to codes if needed
            lang_code = self._map_language_code(None, None)  # Use defaults for Predict

            # Prepare generation parameters
            gen_params = {
                "text": request.Prompt,
                "speed": speed_value,
                "verbose": False,
            }

            # Add model-specific parameters
            if hasattr(self, 'options') and 'voice' in self.options:
                gen_params["voice"] = self.options['voice']

            # Check if model supports language codes (primarily Kokoro)
            if "kokoro" in self.model_path.lower():
                gen_params["lang_code"] = lang_code

            print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)

            # Generate audio using the loaded model
            try:
                results = self.tts_model.generate(**gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details(f"TTS generation failed: {gen_err}")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            # Process the generated audio segments
            audio_arrays = []
            for segment in results:
                audio_arrays.append(segment.audio)

            # If no segments, return error
            if not audio_arrays:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))

            # Concatenate all segments
            cat_audio = np.concatenate(audio_arrays, axis=0)
            duration = len(cat_audio) / 24000  # Assuming 24kHz sample rate

            # Return success message with audio information
            response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"TTS generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def _handle_speed_parameter(self, request, model_path):
        """
        Handle speed parameter based on model type.

        Args:
            request: The TTSRequest object.
            model_path: The model path to determine model type.

        Returns:
            float: The processed speed value.
        """
        # Get speed from options if available
        speed = 1.0
        if hasattr(self, 'options') and 'speed' in self.options:
            speed = self.options['speed']

        # Handle speed parameter based on model type
        if "spark" in model_path.lower():
            # Spark actually expects float values that map to speed descriptions
            speed_map = {
                "very_low": 0.0,
                "low": 0.5,
                "moderate": 1.0,
                "high": 1.5,
                "very_high": 2.0,
            }
            if isinstance(speed, str) and speed in speed_map:
                speed_value = speed_map[speed]
            else:
                # Try to use as float, default to 1.0 (moderate) if invalid
                try:
                    speed_value = float(speed)
                    if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
                        speed_value = 1.0  # Default to moderate
                except (TypeError, ValueError):
                    speed_value = 1.0  # Default to moderate
        else:
            # Other models use float speed values
            try:
                speed_value = float(speed)
                if speed_value < 0.5 or speed_value > 2.0:
                    speed_value = 1.0  # Default to 1.0 if out of range
            except ValueError:
                speed_value = 1.0  # Default to 1.0 if invalid

        return speed_value

    def _map_language_code(self, language, voice):
        """
        Map language names to codes if needed.

        Args:
            language: The language parameter from the request.
            voice: The voice parameter from the request.

        Returns:
            str: The language code.
        """
        if not language:
            # Default to voice[0] if not found
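            # (Kokoro voice names are prefixed with their language code, e.g. "af_heart" -> "a"
            # for American English, so the first letter of the voice is a reasonable fallback.)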
            return voice[0] if voice else "a"

        # Map language names to codes if needed
        language_map = {
            "american_english": "a",
            "british_english": "b",
            "spanish": "e",
            "french": "f",
            "hindi": "h",
            "italian": "i",
            "portuguese": "p",
            "japanese": "j",
            "mandarin_chinese": "z",
            # Also accept direct language codes
            "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
        }

        return language_map.get(language.lower(), language)

    def _build_generation_params(self, request, default_speed=1.0):
        """
        Build generation parameters from request attributes and options for MLX-Audio TTS.

        Args:
            request: The gRPC request.
            default_speed: Default speed if not specified.

        Returns:
            dict: Generation parameters for MLX-Audio.
        """
        # Initialize generation parameters for MLX-Audio TTS
        generation_params = {
            'speed': default_speed,
            'voice': 'af_heart',  # Default voice
            'lang_code': 'a',  # Default language code
        }

        # Extract parameters from request attributes
        if hasattr(request, 'Temperature') and request.Temperature > 0:
            # Temperature could be mapped to speed variation
            generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5

        # Override with options if available
        if hasattr(self, 'options'):
            # Speed from options
            if 'speed' in self.options:
                generation_params['speed'] = self.options['speed']

            # Voice from options
            if 'voice' in self.options:
                generation_params['voice'] = self.options['voice']

            # Language code from options
            if 'lang_code' in self.options:
                generation_params['lang_code'] = self.options['lang_code']

            # Model-specific parameters (kept inside the options guard to avoid
            # an AttributeError when no options were provided)
            param_option_mapping = {
                'temp': 'speed',
                'temperature': 'speed',
                'top_p': 'speed',  # Map top_p to speed variation
            }

            for option_key, param_key in param_option_mapping.items():
                if option_key in self.options:
                    if param_key == 'speed':
                        # Ensure speed is within reasonable bounds
                        speed_val = float(self.options[option_key])
                        if 0.5 <= speed_val <= 2.0:
                            generation_params[param_key] = speed_val

        return generation_params


async def serve(address):
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                             options=[
                                 ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                             ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT
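    # server.stop(5) gives in-flight RPCs up to 5 seconds to complete before shutdown.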
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
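
# Illustrative client sketch (kept as a comment; not executed). The request message
# names below (ModelOptions, TTSRequest) are assumed to follow LocalAI's backend.proto
# as generated into backend_pb2 -- adjust them if the generated classes differ.
#
#   import grpc
#   import backend_pb2
#   import backend_pb2_grpc
#
#   with grpc.insecure_channel("localhost:50051") as channel:
#       stub = backend_pb2_grpc.BackendStub(channel)
#       stub.LoadModel(backend_pb2.ModelOptions(
#           Model="<path-or-HF-repo-of-an-MLX-Audio-TTS-model>",
#           Options=["voice:af_heart", "speed:1.0"]))
#       stub.TTS(backend_pb2.TTSRequest(
#           text="Hello from MLX-Audio", voice="af_heart", dst="/tmp/out.wav"))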