mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 05:36:49 -04:00
feat: add fish-speech backend (#8962)
* feat: add fish-speech backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * drop portaudio Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
17f36e73b5
commit
7dc691c171
23
backend/python/fish-speech/Makefile
Normal file
23
backend/python/fish-speech/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
# Build, run, test and clean targets for the fish-speech Python backend.
# None of the targets produce a file of the same name, so all are phony.
.PHONY: fish-speech run test protogen-clean clean

# Default target: set up the backend via install.sh.
fish-speech:
	bash install.sh

# Start the backend through run.sh (installs first).
run: fish-speech
	@echo "Running fish-speech..."
	bash run.sh
	@echo "fish-speech run."

# Run the test script through test.sh (installs first).
test: fish-speech
	@echo "Testing fish-speech..."
	bash test.sh
	@echo "fish-speech tested."

# Remove the generated protobuf/gRPC Python stubs.
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

# Remove generated stubs, the virtualenv and bytecode caches.
clean: protogen-clean
	rm -rf venv __pycache__
|
||||
457
backend/python/fish-speech/backend.py
Normal file
457
backend/python/fish-speech/backend.py
Normal file
@@ -0,0 +1,457 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for fish-speech TTS
|
||||
"""
|
||||
|
||||
from concurrent import futures
|
||||
import time
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import traceback
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
import json
|
||||
|
||||
import grpc
|
||||
|
||||
|
||||
def is_float(s):
    """Check if a value can be converted to float.

    Args:
        s: Candidate value, normally the value part of an
            ``optname:optvalue`` option string.

    Returns:
        True if ``float(s)`` succeeds, False otherwise.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: malformed text; TypeError: unsupported type such as None.
        return False
    return True
|
||||
|
||||
|
||||
def is_int(s):
    """Check if a value can be converted to int.

    Args:
        s: Candidate value, normally the value part of an
            ``optname:optvalue`` option string.

    Returns:
        True if ``int(s)`` succeeds, False otherwise.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        # ValueError: malformed text; TypeError: unsupported type such as None.
        return False
    return True
|
||||
|
||||
|
||||
# Length of one sleep in the idle loop that keeps the server process alive.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Size of the gRPC thread pool: taken from PYTHON_GRPC_MAX_WORKERS when it is
# set in the environment, otherwise 1.
MAX_WORKERS = int(os.getenv("PYTHON_GRPC_MAX_WORKERS", "1"))
|
||||
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    ``LoadModel`` populates the instance state (``device``, ``options``,
    ``voices``, ``audio_path``, ``model_file``, ``model_path``, ``engine``)
    that ``TTS`` reads afterwards, so ``LoadModel`` has to succeed before
    ``TTS`` can produce audio.
    """

    def Health(self, request, context):
        """Liveness probe: always replies with the bytes ``b"OK"``."""
        return backend_pb2.Reply(message=bytes("OK", "utf-8"))

    @staticmethod
    def _parse_options(raw_options):
        """Turn ``optname:optvalue`` strings into a dict with typed values.

        Entries without a ``:`` are ignored. Values are converted to int,
        float or bool when they look like one and are kept as strings
        otherwise.
        """
        parsed = {}
        for opt in raw_options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # int must be tried before float: every integer literal is also a
            # valid float literal, so the opposite order turns "50" into 50.0
            # and never reaches the int branch.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ("true", "false"):
                value = value.lower() == "true"
            parsed[key] = value
        return parsed

    def _register_voices(self):
        """Fill ``self.voices`` from the ``voices`` option, if present.

        The option is expected to hold a list of dicts (or its JSON text)
        with a ``name``, an ``audio`` path and an optional ``ref_text``.
        Invalid entries are skipped with a warning; parsing errors are logged
        and leave the voices registered so far in place.
        """
        self.voices = {}
        if "voices" not in self.options:
            return

        try:
            voices_data = self.options["voices"]
            if isinstance(voices_data, str):
                voices_list = json.loads(voices_data)
            else:
                voices_list = voices_data

            for voice_entry in voices_list:
                if not isinstance(voice_entry, dict):
                    print(
                        f"[WARNING] Invalid voice entry (not a dict): {voice_entry}",
                        file=sys.stderr,
                    )
                    continue

                name = voice_entry.get("name")
                audio = voice_entry.get("audio")
                ref_text = voice_entry.get("ref_text", "")

                if not name or not isinstance(name, str):
                    print(
                        f"[WARNING] Voice entry missing required 'name' field: {voice_entry}",
                        file=sys.stderr,
                    )
                    continue
                if not audio or not isinstance(audio, str):
                    print(
                        f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}",
                        file=sys.stderr,
                    )
                    continue

                self.voices[name] = {"audio": audio, "ref_text": ref_text}
                print(
                    f"[INFO] Registered voice '{name}' with audio: {audio}",
                    file=sys.stderr,
                )

            print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr)
        except json.JSONDecodeError as e:
            print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr)
        except Exception as e:
            # Voices are optional: a broken configuration must not prevent
            # the model itself from loading.
            print(
                f"[ERROR] Error processing voices configuration: {e}",
                file=sys.stderr,
            )
            print(traceback.format_exc(), file=sys.stderr)

    def LoadModel(self, request, context):
        """Load the fish-speech model and build the TTS inference engine.

        Args:
            request: Load request; the fields ``CUDA``, ``Options``,
                ``AudioPath``, ``ModelFile``, ``ModelPath`` and ``Model`` are
                read.
            context: gRPC context (unused).

        Returns:
            ``backend_pb2.Result`` with ``success=True`` once the engine is
            ready, or ``success=False`` and an explanatory message otherwise.
            Exceptions are caught and reported through the result.
        """
        try:
            # Get device
            if torch.cuda.is_available():
                print("CUDA is available", file=sys.stderr)
                device = "cuda"
            else:
                print("CUDA is not available", file=sys.stderr)
                device = "cpu"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"
            if request.CUDA and not torch.cuda.is_available():
                return backend_pb2.Result(success=False, message="CUDA is not available")

            self.device = device
            self._torch_device = torch.device(device)

            # The options are a list of strings in this form optname:optvalue
            self.options = self._parse_options(request.Options)

            # Parse voices configuration from options
            self._register_voices()

            # Store AudioPath, ModelFile, and ModelPath from LoadModel request
            self.audio_path = getattr(request, "AudioPath", None) or None
            self.model_file = getattr(request, "ModelFile", None) or None
            self.model_path = getattr(request, "ModelPath", None) or None

            # Get model path from request
            model_path = request.Model or "fishaudio/s2-pro"

            # If model_path looks like a HuggingFace repo ID (e.g. "fishaudio/fish-speech-1.5"),
            # download it locally first since fish-speech expects a local directory
            if "/" in model_path and not os.path.exists(model_path):
                from huggingface_hub import snapshot_download

                print(
                    f"Downloading model from HuggingFace: {model_path}",
                    file=sys.stderr,
                )
                model_path = snapshot_download(repo_id=model_path)
                print(f"Model downloaded to: {model_path}", file=sys.stderr)

            # Determine precision: bfloat16 only on CUDA
            precision = torch.float32 if device in ("mps", "cpu") else torch.bfloat16

            # Whether to use torch.compile
            compile_model = self.options.get("compile", False)

            print(
                f"Using device: {device}, precision: {precision}, compile: {compile_model}",
                file=sys.stderr,
            )
            print(f"Loading model from: {model_path}", file=sys.stderr)

            # Import fish-speech modules
            from fish_speech.inference_engine import TTSInferenceEngine
            from fish_speech.models.dac.inference import load_model as load_decoder_model
            from fish_speech.models.text2semantic.inference import (
                launch_thread_safe_queue,
            )

            # Determine decoder checkpoint path
            # The codec model is typically at <checkpoint_path>/codec.pth
            decoder_checkpoint = self.options.get("decoder_checkpoint", None)
            if not decoder_checkpoint and os.path.isdir(model_path):
                candidate = os.path.join(model_path, "codec.pth")
                if os.path.exists(candidate):
                    decoder_checkpoint = candidate

            # Fail before launching the LLaMA queue: starting it first would
            # load the language model only to discard the work afterwards.
            if not decoder_checkpoint:
                return backend_pb2.Result(
                    success=False,
                    message="Decoder checkpoint (codec.pth) not found. "
                    "Ensure the model directory contains codec.pth or set "
                    "decoder_checkpoint option.",
                )

            # Launch LLaMA queue (runs in daemon thread)
            print("Launching LLaMA queue...", file=sys.stderr)
            llama_queue = launch_thread_safe_queue(
                checkpoint_path=model_path,
                device=device,
                precision=precision,
                compile=compile_model,
            )

            # Load DAC decoder
            decoder_config = self.options.get("decoder_config", "modded_dac_vq")
            print(
                f"Loading DAC decoder (config={decoder_config}, checkpoint={decoder_checkpoint})...",
                file=sys.stderr,
            )
            decoder_model = load_decoder_model(
                config_name=decoder_config,
                checkpoint_path=decoder_checkpoint,
                device=device,
            )

            # Create TTS inference engine
            self.engine = TTSInferenceEngine(
                llama_queue=llama_queue,
                decoder_model=decoder_model,
                precision=precision,
                compile=compile_model,
            )

            print(f"Model loaded successfully: {model_path}", file=sys.stderr)

            return backend_pb2.Result(message="Model loaded successfully", success=True)

        except Exception as e:
            print(f"[ERROR] Loading model: {type(e).__name__}: {e}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Failed to load model: {e}"
            )

    def _resolve_audio_path(self, audio_path):
        """Resolve a reference audio path against the model locations.

        Absolute paths are returned unchanged. Relative paths are tried next
        to ``self.model_file`` and then inside ``self.model_path``; the first
        existing candidate wins. If none exists the path is returned as given.
        """
        if os.path.isabs(audio_path):
            return audio_path

        candidates = []
        if self.model_file:
            candidates.append(
                os.path.join(os.path.dirname(self.model_file), audio_path)
            )
        if self.model_path:
            candidates.append(os.path.join(self.model_path, audio_path))

        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return audio_path

    def _get_ref_audio_path(self, voice_name=None):
        """Get reference audio path from voices dict or stored AudioPath.

        Args:
            voice_name: Name of a registered voice. When it is missing or not
                registered, the single ``AudioPath`` given to ``LoadModel``
                is used instead.

        Returns:
            The resolved path, or None when no reference audio is configured.
        """
        if voice_name and voice_name in self.voices:
            return self._resolve_audio_path(self.voices[voice_name]["audio"])

        # Fall back to legacy single-voice mode
        if not self.audio_path:
            return None
        return self._resolve_audio_path(self.audio_path)

    def _build_references(self, voice_name, reference_cls):
        """Build the reference list used for voice cloning.

        Args:
            voice_name: Requested voice name or None.
            reference_cls: Class used to wrap the audio bytes and reference
                text (``ServeReferenceAudio`` in ``TTS``).

        Returns:
            A list with one reference, or an empty list when no reference
            audio is configured or the file cannot be found.
        """
        if voice_name and voice_name in self.voices:
            ref_audio_path = self._get_ref_audio_path(voice_name)
            ref_text = self.voices[voice_name].get("ref_text", "")
            label = f"voice '{voice_name}' with reference audio"
        elif self.audio_path:
            ref_audio_path = self._get_ref_audio_path()
            ref_text = self.options.get("ref_text", "")
            label = "reference audio"
        else:
            return []

        if not ref_audio_path or not os.path.exists(ref_audio_path):
            # Generation still proceeds, but say why no cloning happens.
            print(
                f"[WARNING] Reference audio not found: {ref_audio_path}. "
                "Generating without reference audio.",
                file=sys.stderr,
            )
            return []

        with open(ref_audio_path, "rb") as f:
            audio_bytes = f.read()
        print(f"[INFO] Using {label}: {ref_audio_path}", file=sys.stderr)
        return [reference_cls(audio=audio_bytes, text=ref_text)]

    def TTS(self, request, context):
        """Synthesize ``request.text`` and write the audio to ``request.dst``.

        Args:
            request: TTS request; the fields ``dst``, ``text`` and ``voice``
                are read.
            context: gRPC context (unused).

        Returns:
            ``backend_pb2.Result`` with ``success=True`` when the file was
            written, otherwise ``success=False`` with a message. Exceptions
            are caught and reported through the result.
        """
        try:
            from fish_speech.utils.schema import ServeTTSRequest, ServeReferenceAudio

            if not request.dst:
                return backend_pb2.Result(
                    success=False, message="dst (output path) is required"
                )

            text = request.text.strip()
            if not text:
                return backend_pb2.Result(success=False, message="Text is empty")

            # Without this check a call before LoadModel only surfaces as an
            # opaque AttributeError on the missing instance state.
            if getattr(self, "engine", None) is None:
                return backend_pb2.Result(
                    success=False, message="Model not loaded. Call LoadModel first."
                )

            # Get generation parameters from options
            top_p = self.options.get("top_p", 0.8)
            temperature = self.options.get("temperature", 0.8)
            repetition_penalty = self.options.get("repetition_penalty", 1.1)
            max_new_tokens = self.options.get("max_new_tokens", 1024)
            chunk_length = self.options.get("chunk_length", 200)

            # Build references list for voice cloning
            voice_name = request.voice if request.voice else None
            references = self._build_references(voice_name, ServeReferenceAudio)

            # Build ServeTTSRequest
            tts_request = ServeTTSRequest(
                text=text,
                references=references,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
                chunk_length=chunk_length,
            )

            # Run inference
            print(f"Generating speech for text: {text[:100]}...", file=sys.stderr)
            start_time = time.time()

            sample_rate = None
            audio_data = None

            # Only "final" and "error" results are acted upon; any other
            # result code is ignored.
            for result in self.engine.inference(tts_request):
                if result.code == "final":
                    sample_rate, audio_data = result.audio
                elif result.code == "error":
                    error_msg = str(result.error) if result.error else "Unknown error"
                    print(f"[ERROR] TTS inference error: {error_msg}", file=sys.stderr)
                    return backend_pb2.Result(
                        success=False, message=f"TTS inference error: {error_msg}"
                    )

            generation_duration = time.time() - start_time

            if audio_data is None or sample_rate is None:
                return backend_pb2.Result(
                    success=False, message="No audio output generated"
                )

            # Ensure audio_data is a numpy array
            if not isinstance(audio_data, np.ndarray):
                audio_data = np.array(audio_data)

            audio_duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
            print(
                f"[INFO] TTS generation completed: {generation_duration:.2f}s, "
                f"audio_duration={audio_duration:.2f}s, sample_rate={sample_rate}",
                file=sys.stderr,
                flush=True,
            )

            # Save output
            sf.write(request.dst, audio_data, sample_rate)
            print(f"Saved {audio_duration:.2f}s audio to {request.dst}", file=sys.stderr)

        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(
                success=False, message=f"Unexpected {err=}, {type(err)=}"
            )

        return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def serve(address):
    """Run the gRPC backend on ``address`` until the process is terminated.

    Builds a server backed by a thread pool of ``MAX_WORKERS`` workers with
    50MB message limits, registers a ``BackendServicer``, listens on an
    insecure port and then sleeps in a loop. SIGINT and SIGTERM stop the
    server and exit the process with status 0.
    """
    message_limit = 50 * 1024 * 1024  # 50MB
    channel_options = [
        ("grpc.max_message_length", message_limit),
        ("grpc.max_send_message_length", message_limit),
        ("grpc.max_receive_message_length", message_limit),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.server(worker_pool, options=channel_options)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        """Stop the server immediately and exit the process."""
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Same handler for both termination signals.
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    # server.start() does not block, so keep the main thread alive.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    cli_args = arg_parser.parse_args()

    serve(cli_args.addr)
|
||||
51
backend/python/fish-speech/install.sh
Normal file
51
backend/python/fish-speech/install.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
# Install the fish-speech backend: base requirements, a source checkout of
# fish-speech (installed in editable mode) and a protobuf version compatible
# with the generated gRPC stubs.
set -e

EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"

# Quote $0 and derived paths so directories containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# fish-speech uses pyrootutils which requires a .project-root marker
touch "${backend_dir}/.project-root"

installRequirements

# Clone fish-speech source (the pip package doesn't include inference modules)
FISH_SPEECH_DIR="${EDIR}/fish-speech-src"
FISH_SPEECH_REPO="https://github.com/fishaudio/fish-speech.git"
FISH_SPEECH_BRANCH="main"

if [ ! -d "${FISH_SPEECH_DIR}" ]; then
    echo "Cloning fish-speech source..."
    git clone --depth 1 --branch "${FISH_SPEECH_BRANCH}" "${FISH_SPEECH_REPO}" "${FISH_SPEECH_DIR}"
else
    echo "Updating fish-speech source..."
    # Run in a subshell so the working directory of this script never changes:
    # with "cd dir && git pull && cd -" a failing pull skips the "cd -" and,
    # because "set -e" ignores non-final commands of an && list, the script
    # carries on inside the checkout.
    # pyproject.toml is restored first because the sed patch below leaves it
    # modified after every run, which can make the pull refuse to update.
    (
        cd "${FISH_SPEECH_DIR}" &&
            git checkout -- pyproject.toml &&
            git pull
    ) || echo "WARNING: could not update fish-speech source, using the existing checkout" >&2
fi

# Remove pyaudio from fish-speech deps — it's only used by the upstream client tool
# (tools/api_client.py) for speaker playback, not by our gRPC backend server.
# It requires native portaudio libs which aren't available on all build environments.
sed -i.bak '/"pyaudio"/d' "${FISH_SPEECH_DIR}/pyproject.toml"

# Install fish-speech deps from source (without the package itself since we use PYTHONPATH)
# EXTRA_PIP_INSTALL_FLAGS is intentionally unquoted so multiple flags split into words.
ensureVenv
if [ "x${USE_PIP}" == "xtrue" ]; then
    pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e "${FISH_SPEECH_DIR}"
else
    uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e "${FISH_SPEECH_DIR}"
fi

# fish-speech transitive deps (wandb, tensorboard) may downgrade protobuf to 3.x
# but our generated backend_pb2.py requires protobuf 5+
ensureVenv
if [ "x${USE_PIP}" == "xtrue" ]; then
    pip install "protobuf>=5.29.0"
else
    uv pip install "protobuf>=5.29.0"
fi
|
||||
15
backend/python/fish-speech/package.sh
Executable file
15
backend/python/fish-speech/package.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash

# Script to package runtime libraries for the fish-speech backend
# This is needed because the final Docker image is FROM scratch,
# so system libraries must be explicitly included.

set -e

# Quote $0 and CURDIR so a checkout path containing spaces is not word-split.
CURDIR=$(dirname "$(realpath "$0")")

# Create lib directory
mkdir -p "$CURDIR/lib"

echo "fish-speech packaging completed successfully"
ls -liah "$CURDIR/lib/"
|
||||
3
backend/python/fish-speech/requirements-cpu.txt
Normal file
3
backend/python/fish-speech/requirements-cpu.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
torchaudio
|
||||
3
backend/python/fish-speech/requirements-cublas12.txt
Normal file
3
backend/python/fish-speech/requirements-cublas12.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch
|
||||
torchaudio
|
||||
3
backend/python/fish-speech/requirements-cublas13.txt
Normal file
3
backend/python/fish-speech/requirements-cublas13.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
3
backend/python/fish-speech/requirements-hipblas.txt
Normal file
3
backend/python/fish-speech/requirements-hipblas.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||
torch==2.7.1+rocm6.3
|
||||
torchaudio==2.7.1+rocm6.3
|
||||
3
backend/python/fish-speech/requirements-intel.txt
Normal file
3
backend/python/fish-speech/requirements-intel.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/xpu
|
||||
torch
|
||||
torchaudio
|
||||
3
backend/python/fish-speech/requirements-l4t12.txt
Normal file
3
backend/python/fish-speech/requirements-l4t12.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/
|
||||
torch
|
||||
torchaudio
|
||||
3
backend/python/fish-speech/requirements-l4t13.txt
Normal file
3
backend/python/fish-speech/requirements-l4t13.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
torch
|
||||
torchaudio
|
||||
2
backend/python/fish-speech/requirements-mps.txt
Normal file
2
backend/python/fish-speech/requirements-mps.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
torch
|
||||
torchaudio
|
||||
9
backend/python/fish-speech/requirements.txt
Normal file
9
backend/python/fish-speech/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
packaging==24.1
|
||||
soundfile
|
||||
setuptools
|
||||
six
|
||||
scipy
|
||||
numpy
|
||||
9
backend/python/fish-speech/run.sh
Normal file
9
backend/python/fish-speech/run.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
# Start the fish-speech backend through the shared libbackend.sh helpers.

# Quote $0 and derived paths so directories containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# "$@" forwards every argument unchanged; a bare $@ would re-split arguments
# that contain whitespace.
startBackend "$@"
|
||||
175
backend/python/fish-speech/test.py
Normal file
175
backend/python/fish-speech/test.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
A test script to test the gRPC service
|
||||
"""
|
||||
import signal
|
||||
import threading
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
# File that receives the stdout/stderr of the backend process under test.
BACKEND_LOG = "/tmp/fish-speech-backend.log"


def _dump_backend_log():
    """Write the whole backend log to stderr so CI output always contains it.

    Nothing is printed when the log file is missing or empty.
    """
    if not os.path.exists(BACKEND_LOG):
        return

    with open(BACKEND_LOG, "r") as log_file:
        log_text = log_file.read()

    if not log_text:
        return

    print("=== Backend Log ===", file=sys.stderr, flush=True)
    print(log_text, file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
def _sigterm_handler(signum, frame):
    """Dump the backend log when the test run is terminated, then exit.

    Args:
        signum: Number of the received signal (only used in the message).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: always, with exit status 143.
    """
    notice = f"\nReceived signal {signum}, dumping backend log before exit..."
    print(notice, file=sys.stderr, flush=True)
    _dump_backend_log()
    sys.exit(143)


# Installed at import time so a terminated test run still shows the log.
signal.signal(signal.SIGTERM, _sigterm_handler)
|
||||
|
||||
|
||||
def _tail_log(path, stop_event, interval=10):
    """Periodically echo newly written backend log content to stderr.

    Intended to run in a background thread.

    Args:
        path: Log file to follow; it may not exist yet.
        stop_event: Object with ``is_set()`` and ``wait(timeout)`` (such as a
            ``threading.Event``); the loop ends once it is set.
        interval: Seconds to wait between two reads.
    """
    offset = 0
    while True:
        if stop_event.is_set():
            return
        # wait() returns early when the event is set, so shutdown is prompt.
        stop_event.wait(interval)

        try:
            log_file = open(path, "r")
        except FileNotFoundError:
            # The backend has not created the log yet; try again next round.
            continue

        with log_file:
            log_file.seek(offset)
            chunk = log_file.read()
            offset = log_file.tell()

        if chunk:
            print(f"[backend log] {chunk}", file=sys.stderr, end="", flush=True)
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    Every test starts ``backend.py`` as a subprocess, waits until its Health
    endpoint answers and stops it again afterwards.
    """

    # Address the backend subprocess listens on and the tests connect to.
    BACKEND_ADDR = "localhost:50051"

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        print("Starting backend server...", file=sys.stderr, flush=True)
        self._backend_stopped = False
        self.backend_log = open(BACKEND_LOG, "w")
        # unittest skips tearDown when setUp fails but always runs cleanups,
        # so register the shutdown before anything below can fail; otherwise
        # a backend that never becomes ready would be left running.
        self.addCleanup(self._stop_backend)

        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", self.BACKEND_ADDR],
            stdout=self.backend_log,
            stderr=self.backend_log,
        )

        # Start tailing backend log so CI sees progress in real time
        self._log_stop = threading.Event()
        self._log_thread = threading.Thread(
            target=_tail_log, args=(BACKEND_LOG, self._log_stop), daemon=True
        )
        self._log_thread.start()

        # Poll for readiness instead of a fixed sleep
        print("Waiting for backend to be ready...", file=sys.stderr, flush=True)
        max_wait = 60
        start = time.time()
        ready = False
        while time.time() - start < max_wait:
            try:
                with grpc.insecure_channel(self.BACKEND_ADDR) as channel:
                    stub = backend_pb2_grpc.BackendStub(channel)
                    resp = stub.Health(backend_pb2.HealthMessage(), timeout=2.0)
                    if resp.message:
                        ready = True
                        break
            except Exception:
                # Connection errors are expected until the server is up.
                pass
            # Check if process died
            if self.service.poll() is not None:
                self.fail(f"Backend process exited early with code {self.service.returncode}")
            time.sleep(2)

        elapsed = time.time() - start
        if not ready:
            self.fail(f"Backend not ready after {max_wait}s")
        print(f"Backend ready after {elapsed:.1f}s", file=sys.stderr, flush=True)

    def _stop_backend(self):
        """Stop the log tail thread and the backend process, then dump the log.

        Safe to call more than once and after a partially completed setUp.
        """
        if self._backend_stopped:
            return
        self._backend_stopped = True

        log_stop = getattr(self, "_log_stop", None)
        if log_stop is not None:
            log_stop.set()
        log_thread = getattr(self, "_log_thread", None)
        if log_thread is not None and log_thread.is_alive():
            log_thread.join(timeout=2)

        service = getattr(self, "service", None)
        if service is not None:
            service.terminate()
            try:
                service.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # The backend ignored the termination request: force it down.
                service.kill()
                service.wait()

        self.backend_log.close()
        _dump_backend_log()

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path``, ignoring the case where it is already gone."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self._stop_backend()

    def test_tts(self):
        """
        This method tests if the TTS generation works successfully
        """
        with grpc.insecure_channel(self.BACKEND_ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            # Limit max_new_tokens for CPU testing (generation is very slow on CPU)
            print("Loading model fishaudio/s2-pro...", file=sys.stderr, flush=True)
            load_start = time.time()
            response = stub.LoadModel(
                backend_pb2.ModelOptions(
                    Model="fishaudio/s2-pro",
                    Options=["max_new_tokens:50"],
                ),
                timeout=1800.0,
            )
            print(
                f"LoadModel response: success={response.success}, "
                f"message={response.message}, "
                f"took {time.time() - load_start:.1f}s",
                file=sys.stderr, flush=True,
            )
            self.assertTrue(response.success, f"LoadModel failed: {response.message}")

            # Create temporary output file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                output_path = tmp_file.name
            # Remove the file even when one of the assertions below fails.
            self.addCleanup(self._remove_if_exists, output_path)

            tts_request = backend_pb2.TTSRequest(
                text="Hi.",
                dst=output_path,
            )
            # Allow up to 10 minutes for TTS generation on CPU
            print("Starting TTS generation...", file=sys.stderr, flush=True)
            tts_start = time.time()
            tts_response = stub.TTS(tts_request, timeout=600.0)
            print(
                f"TTS response: success={tts_response.success}, "
                f"took {time.time() - tts_start:.1f}s",
                file=sys.stderr, flush=True,
            )
            self.assertIsNotNone(tts_response)
            self.assertTrue(tts_response.success, f"TTS failed: {tts_response.message}")

            # Verify output file exists and is not empty
            self.assertTrue(os.path.exists(output_path))
            file_size = os.path.getsize(output_path)
            print(f"Output file size: {file_size} bytes", file=sys.stderr, flush=True)
            self.assertGreater(file_size, 0)


if __name__ == "__main__":
    unittest.main()
|
||||
11
backend/python/fish-speech/test.sh
Normal file
11
backend/python/fish-speech/test.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Run the fish-speech backend unit tests through the shared libbackend.sh helpers.
set -e

# Quote $0 and derived paths so directories containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
Reference in New Issue
Block a user