Files
LocalAI/backend/python/fish-speech/backend.py
Ettore Di Giacinto 59108fbe32 feat: add distributed mode (#9124)
* feat: add distributed mode (experimental)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix data races, mutexes, transactions

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix events and tool stream in agent chat

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* use ginkgo

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix(cron): compute correctly time boundaries avoiding re-triggering

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not flood of healthy checks

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* do not list obvious backends as text backends

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* tests fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* refactoring and consolidation

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Drop redundant healthcheck

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* enhancements, refactorings

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-30 00:47:27 +02:00

464 lines
17 KiB
Python

#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for fish-speech TTS
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import traceback
import backend_pb2
import backend_pb2_grpc
import torch
import soundfile as sf
import numpy as np
import json
import grpc
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
def is_float(s):
"""Check if a string can be converted to float."""
try:
float(s)
return True
except ValueError:
return False
def is_int(s):
"""Check if a string can be converted to int."""
try:
int(s)
return True
except ValueError:
return False
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1"))
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
"""
BackendServicer is the class that implements the gRPC service
"""
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", "utf-8"))
def LoadModel(self, request, context):
try:
# Get device
if torch.cuda.is_available():
print("CUDA is available", file=sys.stderr)
device = "cuda"
else:
print("CUDA is not available", file=sys.stderr)
device = "cpu"
mps_available = (
hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
)
if mps_available:
device = "mps"
if not torch.cuda.is_available() and request.CUDA:
return backend_pb2.Result(success=False, message="CUDA is not available")
# Validate mps availability if requested
if device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
device = "cpu"
self.device = device
self._torch_device = torch.device(device)
options = request.Options
# empty dict
self.options = {}
# The options are a list of strings in this form optname:optvalue
for opt in options:
if ":" not in opt:
continue
key, value = opt.split(":", 1)
if is_float(value):
value = float(value)
elif is_int(value):
value = int(value)
elif value.lower() in ["true", "false"]:
value = value.lower() == "true"
self.options[key] = value
# Parse voices configuration from options
self.voices = {}
if "voices" in self.options:
try:
voices_data = self.options["voices"]
if isinstance(voices_data, str):
voices_list = json.loads(voices_data)
else:
voices_list = voices_data
for voice_entry in voices_list:
if not isinstance(voice_entry, dict):
print(
f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
file=sys.stderr,
)
continue
name = voice_entry.get("name")
audio = voice_entry.get("audio")
ref_text = voice_entry.get("ref_text", "")
if not name or not isinstance(name, str):
print(
f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
file=sys.stderr,
)
continue
if not audio or not isinstance(audio, str):
print(
f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
file=sys.stderr,
)
continue
self.voices[name] = {"audio": audio, "ref_text": ref_text}
print(
f"[INFO] Registered voice '{name}' with audio: {audio}",
file=sys.stderr,
)
print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr)
except json.JSONDecodeError as e:
print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
except Exception as e:
print(
f"[ERROR] Error processing voices configuration: {e}",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
# Store AudioPath, ModelFile, and ModelPath from LoadModel request
self.audio_path = (
request.AudioPath
if hasattr(request, "AudioPath") and request.AudioPath
else None
)
self.model_file = (
request.ModelFile
if hasattr(request, "ModelFile") and request.ModelFile
else None
)
self.model_path = (
request.ModelPath
if hasattr(request, "ModelPath") and request.ModelPath
else None
)
# Get model path from request
model_path = request.Model
if not model_path:
model_path = "fishaudio/s2-pro"
# If model_path looks like a HuggingFace repo ID (e.g. "fishaudio/fish-speech-1.5"),
# download it locally first since fish-speech expects a local directory
if "/" in model_path and not os.path.exists(model_path):
from huggingface_hub import snapshot_download
print(
f"Downloading model from HuggingFace: {model_path}",
file=sys.stderr,
)
model_path = snapshot_download(repo_id=model_path)
print(f"Model downloaded to: {model_path}", file=sys.stderr)
# Determine precision
if device in ("mps", "cpu"):
precision = torch.float32
else:
precision = torch.bfloat16
# Whether to use torch.compile
compile_model = self.options.get("compile", False)
print(
f"Using device: {device}, precision: {precision}, compile: {compile_model}",
file=sys.stderr,
)
print(f"Loading model from: {model_path}", file=sys.stderr)
# Import fish-speech modules
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import (
launch_thread_safe_queue,
)
# Determine decoder checkpoint path
# The codec model is typically at <checkpoint_path>/codec.pth
decoder_checkpoint = self.options.get("decoder_checkpoint", None)
if not decoder_checkpoint:
# Try common locations
if os.path.isdir(model_path):
candidate = os.path.join(model_path, "codec.pth")
if os.path.exists(candidate):
decoder_checkpoint = candidate
# Launch LLaMA queue (runs in daemon thread)
print("Launching LLaMA queue...", file=sys.stderr)
llama_queue = launch_thread_safe_queue(
checkpoint_path=model_path,
device=device,
precision=precision,
compile=compile_model,
)
# Load DAC decoder
decoder_config = self.options.get("decoder_config", "modded_dac_vq")
if not decoder_checkpoint:
return backend_pb2.Result(
success=False,
message="Decoder checkpoint (codec.pth) not found. "
"Ensure the model directory contains codec.pth or set "
"decoder_checkpoint option.",
)
print(
f"Loading DAC decoder (config={decoder_config}, checkpoint={decoder_checkpoint})...",
file=sys.stderr,
)
decoder_model = load_decoder_model(
config_name=decoder_config,
checkpoint_path=decoder_checkpoint,
device=device,
)
# Create TTS inference engine
self.engine = TTSInferenceEngine(
llama_queue=llama_queue,
decoder_model=decoder_model,
precision=precision,
compile=compile_model,
)
print(f"Model loaded successfully: {model_path}", file=sys.stderr)
return backend_pb2.Result(message="Model loaded successfully", success=True)
except Exception as e:
print(f"[ERROR] Loading model: {type(e).__name__}: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return backend_pb2.Result(
success=False, message=f"Failed to load model: {e}"
)
def _get_ref_audio_path(self, voice_name=None):
"""Get reference audio path from voices dict or stored AudioPath."""
if voice_name and voice_name in self.voices:
audio_path = self.voices[voice_name]["audio"]
if os.path.isabs(audio_path):
return audio_path
# Try relative to ModelFile
if self.model_file:
model_file_base = os.path.dirname(self.model_file)
ref_path = os.path.join(model_file_base, audio_path)
if os.path.exists(ref_path):
return ref_path
# Try relative to ModelPath
if self.model_path:
ref_path = os.path.join(self.model_path, audio_path)
if os.path.exists(ref_path):
return ref_path
return audio_path
# Fall back to legacy single-voice mode
if not self.audio_path:
return None
if os.path.isabs(self.audio_path):
return self.audio_path
if self.model_file:
model_file_base = os.path.dirname(self.model_file)
ref_path = os.path.join(model_file_base, self.audio_path)
if os.path.exists(ref_path):
return ref_path
if self.model_path:
ref_path = os.path.join(self.model_path, self.audio_path)
if os.path.exists(ref_path):
return ref_path
return self.audio_path
def TTS(self, request, context):
try:
from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio
if not request.dst:
return backend_pb2.Result(
success=False, message="dst (output path) is required"
)
text = request.text.strip()
if not text:
return backend_pb2.Result(success=False, message="Text is empty")
# Get generation parameters from options
top_p = self.options.get("top_p", 0.8)
temperature = self.options.get("temperature", 0.8)
repetition_penalty = self.options.get("repetition_penalty", 1.1)
max_new_tokens = self.options.get("max_new_tokens", 1024)
chunk_length = self.options.get("chunk_length", 200)
# Build references list for voice cloning
references = []
voice_name = request.voice if request.voice else None
if voice_name and voice_name in self.voices:
ref_audio_path = self._get_ref_audio_path(voice_name)
if ref_audio_path and os.path.exists(ref_audio_path):
with open(ref_audio_path, "rb") as f:
audio_bytes = f.read()
ref_text = self.voices[voice_name].get("ref_text", "")
references.append(
ServeReferenceAudio(audio=audio_bytes, text=ref_text)
)
print(
f"[INFO] Using voice '{voice_name}' with reference audio: {ref_audio_path}",
file=sys.stderr,
)
elif self.audio_path:
ref_audio_path = self._get_ref_audio_path()
if ref_audio_path and os.path.exists(ref_audio_path):
with open(ref_audio_path, "rb") as f:
audio_bytes = f.read()
ref_text = self.options.get("ref_text", "")
references.append(
ServeReferenceAudio(audio=audio_bytes, text=ref_text)
)
print(
f"[INFO] Using reference audio: {ref_audio_path}",
file=sys.stderr,
)
# Build ServeTTSRequest
tts_request = ServeTTSRequest(
text=text,
references=references,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
max_new_tokens=max_new_tokens,
chunk_length=chunk_length,
)
# Run inference
print(f"Generating speech for text: {text[:100]}...", file=sys.stderr)
start_time = time.time()
sample_rate = None
audio_data = None
for result in self.engine.inference(tts_request):
if result.code == "final":
sample_rate, audio_data = result.audio
elif result.code == "error":
error_msg = str(result.error) if result.error else "Unknown error"
print(f"[ERROR] TTS inference error: {error_msg}", file=sys.stderr)
return backend_pb2.Result(
success=False, message=f"TTS inference error: {error_msg}"
)
generation_duration = time.time() - start_time
if audio_data is None or sample_rate is None:
return backend_pb2.Result(
success=False, message="No audio output generated"
)
# Ensure audio_data is a numpy array
if not isinstance(audio_data, np.ndarray):
audio_data = np.array(audio_data)
audio_duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
print(
f"[INFO] TTS generation completed: {generation_duration:.2f}s, "
f"audio_duration={audio_duration:.2f}s, sample_rate={sample_rate}",
file=sys.stderr,
flush=True,
)
# Save output
sf.write(request.dst, audio_data, sample_rate)
print(f"Saved {audio_duration:.2f}s audio to {request.dst}", file=sys.stderr)
except Exception as err:
print(f"Error in TTS: {err}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return backend_pb2.Result(
success=False, message=f"Unexpected {err=}, {type(err)=}"
)
return backend_pb2.Result(success=True)
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
("grpc.max_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
],
interceptors=get_auth_interceptors(),
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
serve(args.addr)