mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
* feat: add distributed mode (experimental) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix data races, mutexes, transactions Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix events and tool stream in agent chat Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * use ginkgo Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix(cron): compute correctly time boundaries avoiding re-triggering Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not flood of healthy checks Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * do not list obvious backends as text backends Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * tests fixups Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * refactoring and consolidation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Drop redundant healthcheck Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * enhancements, refactorings Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
464 lines
17 KiB
Python
464 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
This is an extra gRPC server of LocalAI for fish-speech TTS
|
|
"""
|
|
|
|
from concurrent import futures
|
|
import time
|
|
import argparse
|
|
import signal
|
|
import sys
|
|
import os
|
|
import traceback
|
|
import backend_pb2
|
|
import backend_pb2_grpc
|
|
import torch
|
|
import soundfile as sf
|
|
import numpy as np
|
|
|
|
import json
|
|
|
|
import grpc
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
|
|
from grpc_auth import get_auth_interceptors
|
|
|
|
|
|
|
|
def is_float(s):
    """Return True if ``float(s)`` succeeds, False if it raises ValueError.

    Only ValueError counts as "not a float"; any other exception raised by
    ``float`` (for example TypeError for non-string, non-numeric input)
    propagates to the caller.
    """
    try:
        float(s)
    except ValueError:
        return False
    return True
|
|
|
|
|
|
def is_int(s):
    """Return True if ``int(s)`` succeeds, False if it raises ValueError.

    Only ValueError counts as "not an int"; any other exception raised by
    ``int`` propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    return True
|
|
|
|
|
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
|
|
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
|
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
|
|
|
|
|
|
# Implement the BackendServicer class with the service methods
|
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    It exposes fish-speech text-to-speech through three RPCs: ``Health``,
    ``LoadModel`` (builds the LLaMA queue, the DAC decoder and the
    TTSInferenceEngine) and ``TTS`` (synthesizes ``request.text`` and writes
    the audio to ``request.dst`` with soundfile).
    """

    def Health(self, request, context):
        """Liveness probe; always replies with ``OK``."""
        return backend_pb2.Reply(message=bytes("OK", "utf-8"))

    @staticmethod
    def _select_device():
        """Return the preferred torch device name: ``cuda``, else ``mps``, else ``cpu``."""
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            return "cuda"
        print("CUDA is not available", file=sys.stderr)
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _parse_options(raw_options):
        """Convert ``["key:value", ...]`` strings into a dict with typed values.

        Entries without ``:`` are ignored. Only the first ``:`` splits, so a
        value may itself contain colons (e.g. the JSON ``voices`` list).
        Values become int, then float, then bool ("true"/"false",
        case-insensitive); anything else stays a string.
        """
        parsed = {}
        for opt in raw_options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # int must be tried before float: every integer string also parses
            # as a float, so the reverse order would turn "250" into 250.0.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ("true", "false"):
                value = value.lower() == "true"
            parsed[key] = value
        return parsed

    @staticmethod
    def _parse_voices(voices_data):
        """Build ``{name: {"audio": path, "ref_text": text}}`` from the ``voices`` option.

        ``voices_data`` is either a JSON string or an already-decoded list of
        objects with ``name``, ``audio`` and optional ``ref_text``. Invalid
        entries are skipped with a warning. Errors are logged and never
        raised; voices registered before an error are kept.
        """
        voices = {}
        try:
            if isinstance(voices_data, str):
                voices_list = json.loads(voices_data)
            else:
                voices_list = voices_data

            if not isinstance(voices_list, list):
                print(
                    f"[ERROR] voices option must be a list of voice objects, "
                    f"got {type(voices_list).__name__}",
                    file=sys.stderr,
                )
                return voices

            for voice_entry in voices_list:
                if not isinstance(voice_entry, dict):
                    print(
                        f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
                        file=sys.stderr,
                    )
                    continue

                name = voice_entry.get("name")
                audio = voice_entry.get("audio")
                # A JSON null transcript is treated like a missing one.
                ref_text = voice_entry.get("ref_text") or ""

                if not name or not isinstance(name, str):
                    print(
                        f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
                        file=sys.stderr,
                    )
                    continue
                if not audio or not isinstance(audio, str):
                    print(
                        f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
                        file=sys.stderr,
                    )
                    continue

                voices[name] = {"audio": audio, "ref_text": ref_text}
                print(
                    f"[INFO] Registered voice '{name}' with audio: {audio}",
                    file=sys.stderr,
                )

            print(f"[INFO] Loaded {len(voices)} voice(s)", file=sys.stderr)
        except json.JSONDecodeError as e:
            print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
        except Exception as e:
            print(
                f"[ERROR] Error processing voices configuration: {e}",
                file=sys.stderr,
            )
            print(traceback.format_exc(), file=sys.stderr)
        return voices

    def LoadModel(self, request, context):
        """Load the fish-speech text2semantic model and DAC decoder.

        Reads from *request*:
        - ``Model``: a local directory or HuggingFace repo id (default
          ``fishaudio/s2-pro``).
        - ``CUDA``.
        - ``Options``: ``key:value`` strings.
        - ``AudioPath``, ``ModelFile`` and ``ModelPath``: used later to
          resolve reference audio.

        Options consumed here:
        - ``voices``.
        - ``compile``.
        - ``decoder_checkpoint``: defaults to ``<model dir>/codec.pth``.
        - ``decoder_config``: defaults to ``modded_dac_vq``.

        Returns ``Result(success=True)`` on success. Every failure is reported
        as ``Result(success=False, message=...)`` rather than raised.
        """
        try:
            device = self._select_device()
            if not torch.cuda.is_available() and request.CUDA:
                return backend_pb2.Result(success=False, message="CUDA is not available")

            self.device = device
            self._torch_device = torch.device(device)

            self.options = self._parse_options(request.Options)

            self.voices = {}
            if "voices" in self.options:
                self.voices = self._parse_voices(self.options["voices"])

            # Store AudioPath, ModelFile, and ModelPath from LoadModel request
            # (empty strings are normalised to None).
            self.audio_path = getattr(request, "AudioPath", None) or None
            self.model_file = getattr(request, "ModelFile", None) or None
            self.model_path = getattr(request, "ModelPath", None) or None

            model_path = request.Model or "fishaudio/s2-pro"

            # If model_path looks like a HuggingFace repo ID (e.g. "fishaudio/fish-speech-1.5"),
            # download it locally first since fish-speech expects a local directory
            if "/" in model_path and not os.path.exists(model_path):
                from huggingface_hub import snapshot_download

                print(
                    f"Downloading model from HuggingFace: {model_path}",
                    file=sys.stderr,
                )
                model_path = snapshot_download(repo_id=model_path)
                print(f"Model downloaded to: {model_path}", file=sys.stderr)

            precision = torch.float32 if device in ("mps", "cpu") else torch.bfloat16
            compile_model = bool(self.options.get("compile", False))

            print(
                f"Using device: {device}, precision: {precision}, compile: {compile_model}",
                file=sys.stderr,
            )
            print(f"Loading model from: {model_path}", file=sys.stderr)

            # Resolve the DAC decoder checkpoint BEFORE launching the LLaMA
            # queue. Otherwise a missing codec would leave an already-loaded
            # model running in its background thread.
            decoder_checkpoint = self.options.get("decoder_checkpoint", None)
            if not decoder_checkpoint and os.path.isdir(model_path):
                candidate = os.path.join(model_path, "codec.pth")
                if os.path.exists(candidate):
                    decoder_checkpoint = candidate
            if not decoder_checkpoint:
                return backend_pb2.Result(
                    success=False,
                    message="Decoder checkpoint (codec.pth) not found. "
                    "Ensure the model directory contains codec.pth or set "
                    "decoder_checkpoint option.",
                )

            # Import fish-speech modules
            from fish_speech.inference_engine import TTSInferenceEngine
            from fish_speech.models.dac.inference import load_model as load_decoder_model
            from fish_speech.models.text2semantic.inference import (
                launch_thread_safe_queue,
            )

            print("Launching LLaMA queue...", file=sys.stderr)
            llama_queue = launch_thread_safe_queue(
                checkpoint_path=model_path,
                device=device,
                precision=precision,
                compile=compile_model,
            )

            decoder_config = self.options.get("decoder_config", "modded_dac_vq")
            print(
                f"Loading DAC decoder (config={decoder_config}, checkpoint={decoder_checkpoint})...",
                file=sys.stderr,
            )
            decoder_model = load_decoder_model(
                config_name=decoder_config,
                checkpoint_path=decoder_checkpoint,
                device=device,
            )

            self.engine = TTSInferenceEngine(
                llama_queue=llama_queue,
                decoder_model=decoder_model,
                precision=precision,
                compile=compile_model,
            )

            print(f"Model loaded successfully: {model_path}", file=sys.stderr)
            return backend_pb2.Result(message="Model loaded successfully", success=True)

        except Exception as e:
            print(f"[ERROR] Loading model: {type(e).__name__}: {e}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Failed to load model: {e}"
            )

    def _resolve_audio_path(self, audio_path):
        """Resolve a possibly-relative reference audio path.

        Absolute paths are returned unchanged. A relative path is tried
        against the directory of ModelFile, then against ModelPath; the first
        existing candidate wins. If none exists, the path is returned as given.
        """
        if os.path.isabs(audio_path):
            return audio_path
        bases = []
        if self.model_file:
            bases.append(os.path.dirname(self.model_file))
        if self.model_path:
            bases.append(self.model_path)
        for base in bases:
            candidate = os.path.join(base, audio_path)
            if os.path.exists(candidate):
                return candidate
        return audio_path

    def _get_ref_audio_path(self, voice_name=None):
        """Get reference audio path from voices dict or stored AudioPath.

        A configured *voice_name* takes precedence. Otherwise the legacy
        single-voice ``AudioPath`` is used, and None is returned when that is
        unset too.
        """
        if voice_name and voice_name in self.voices:
            return self._resolve_audio_path(self.voices[voice_name]["audio"])
        if not self.audio_path:
            return None
        return self._resolve_audio_path(self.audio_path)

    def TTS(self, request, context):
        """Synthesize ``request.text`` and write the audio to ``request.dst``.

        The voice-cloning reference is chosen in this order:
        1. ``request.voice``, if it names a voice from the ``voices`` option.
        2. Otherwise the LoadModel ``AudioPath``, with the ``ref_text`` option
           as its transcript.
        3. Otherwise no reference.

        A reference file that does not exist is skipped with a warning.

        Generation parameters come from the LoadModel options ``top_p``,
        ``temperature``, ``repetition_penalty``, ``max_new_tokens`` and
        ``chunk_length``.

        Returns ``Result(success=True)`` once the file is written. Otherwise
        returns ``Result(success=False, message=...)``; exceptions are never
        propagated to gRPC.
        """
        try:
            from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio

            if not request.dst:
                return backend_pb2.Result(
                    success=False, message="dst (output path) is required"
                )

            text = request.text.strip()
            if not text:
                return backend_pb2.Result(success=False, message="Text is empty")

            engine = getattr(self, "engine", None)
            if engine is None:
                return backend_pb2.Result(
                    success=False, message="Model is not loaded; call LoadModel first"
                )

            # Coerce explicitly: option parsing yields int or float depending
            # on how the user wrote the value, and the request schema may
            # validate numeric types strictly (NOTE(review): confirm against
            # the installed fish_speech version).
            top_p = float(self.options.get("top_p", 0.8))
            temperature = float(self.options.get("temperature", 0.8))
            repetition_penalty = float(self.options.get("repetition_penalty", 1.1))
            max_new_tokens = int(self.options.get("max_new_tokens", 1024))
            chunk_length = int(self.options.get("chunk_length", 200))

            # Pick the reference audio (if any) for voice cloning.
            voice_name = request.voice if request.voice else None
            ref_audio_path = None
            ref_text = ""
            used_message = ""
            if voice_name and voice_name in self.voices:
                ref_audio_path = self._get_ref_audio_path(voice_name)
                ref_text = self.voices[voice_name].get("ref_text", "")
                used_message = (
                    f"[INFO] Using voice '{voice_name}' with reference audio: {ref_audio_path}"
                )
            else:
                if voice_name:
                    print(
                        f"[WARNING] Voice '{voice_name}' is not configured; ignoring it",
                        file=sys.stderr,
                    )
                if self.audio_path:
                    ref_audio_path = self._get_ref_audio_path()
                    ref_text = self.options.get("ref_text", "")
                    used_message = f"[INFO] Using reference audio: {ref_audio_path}"

            references = []
            if ref_audio_path:
                if os.path.exists(ref_audio_path):
                    with open(ref_audio_path, "rb") as f:
                        audio_bytes = f.read()
                    # The transcript may have been parsed as a number or bool
                    # by option parsing; the reference expects text.
                    references.append(
                        ServeReferenceAudio(audio=audio_bytes, text=str(ref_text))
                    )
                    print(used_message, file=sys.stderr)
                else:
                    print(
                        f"[WARNING] Reference audio not found: {ref_audio_path}; "
                        "generating without voice cloning",
                        file=sys.stderr,
                    )

            tts_request = ServeTTSRequest(
                text=text,
                references=references,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
                chunk_length=chunk_length,
            )

            print(f"Generating speech for text: {text[:100]}...", file=sys.stderr)
            start_time = time.time()

            sample_rate = None
            audio_data = None
            for result in engine.inference(tts_request):
                if result.code == "final":
                    sample_rate, audio_data = result.audio
                elif result.code == "error":
                    error_msg = str(result.error) if result.error else "Unknown error"
                    print(f"[ERROR] TTS inference error: {error_msg}", file=sys.stderr)
                    return backend_pb2.Result(
                        success=False, message=f"TTS inference error: {error_msg}"
                    )

            generation_duration = time.time() - start_time

            if audio_data is None or sample_rate is None:
                return backend_pb2.Result(
                    success=False, message="No audio output generated"
                )

            if not isinstance(audio_data, np.ndarray):
                audio_data = np.array(audio_data)

            audio_duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
            print(
                f"[INFO] TTS generation completed: {generation_duration:.2f}s, "
                f"audio_duration={audio_duration:.2f}s, sample_rate={sample_rate}",
                file=sys.stderr,
                flush=True,
            )

            sf.write(request.dst, audio_data, sample_rate)
            print(f"Saved {audio_duration:.2f}s audio to {request.dst}", file=sys.stderr)

        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Unexpected {err=}, {type(err)=}"
            )

        return backend_pb2.Result(success=True)
|
|
|
|
|
|
def serve(address):
    """Start the fish-speech gRPC server on *address* and block until stopped.

    The server runs on a thread pool of MAX_WORKERS threads and accepts
    messages up to 50 MB in each direction. It installs the interceptors
    returned by get_auth_interceptors() and binds with add_insecure_port
    (no TLS). SIGINT/SIGTERM stop the server immediately and exit with
    status 0.
    """
    max_message_bytes = 50 * 1024 * 1024  # 50MB
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ("grpc.max_message_length", max_message_bytes),
            ("grpc.max_send_message_length", max_message_bytes),
            ("grpc.max_receive_message_length", max_message_bytes),
        ],
        interceptors=get_auth_interceptors(),
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Log to stderr like every other message emitted by this backend.
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Keep the main thread alive; requests are served by the pool threads.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Read the listen address from the command line, then hand control to
    # serve(), which blocks until the process is signalled.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    serve(cli.parse_args().addr)
|