#!/usr/bin/env python3
"""
LocalAI vLLM-Omni Backend

This backend provides gRPC access to vllm-omni for multimodal generation:
- Image generation (text-to-image, image editing)
- Video generation (text-to-video, image-to-video)
- Text generation with multimodal inputs (LLM)
- Text-to-speech generation
"""
from concurrent import futures
import traceback
import argparse
import signal
import sys
import time
import os
import base64
import io

from PIL import Image
import torch
import numpy as np
import soundfile as sf

import backend_pb2
import backend_pb2_grpc

import grpc

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.utils.platform_utils import detect_device_type, is_npu
from vllm import SamplingParams
from diffusers.utils import export_to_video

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If PYTHON_GRPC_MAX_WORKERS is specified in the environment, use it; otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):

    def _detect_model_type(self, model_name):
        """Detect model type from model name."""
        model_lower = model_name.lower()
        if "tts" in model_lower or "qwen3-tts" in model_lower:
            return "tts"
        elif "omni" in model_lower and "qwen3" in model_lower:
            return "llm"
        elif "wan" in model_lower or "t2v" in model_lower or "i2v" in model_lower:
            return "video"
        elif "image" in model_lower or "z-image" in model_lower or "qwen-image" in model_lower:
            return "image"
        else:
            # Default to image for diffusion models, llm for others
            return "image"

    def _detect_tts_task_type(self):
        """Detect TTS task type from model name."""
        model_lower = self.model_name.lower()
        if "customvoice" in model_lower:
            return "CustomVoice"
        elif "voicedesign" in model_lower:
            return "VoiceDesign"
        elif "base" in model_lower:
            return "Base"
        else:
            # Default to CustomVoice
            return "CustomVoice"

    def _load_image(self, image_path):
        """Load an image from a file path or base64-encoded data."""
        # Try file path first
        if os.path.exists(image_path):
            return Image.open(image_path)
        # Try base64 decode
        try:
            image_data = base64.b64decode(image_path)
            return Image.open(io.BytesIO(image_data))
        except Exception:
            return None

    def _load_video(self, video_path):
        """Load a video from a file path or base64-encoded data."""
        from vllm.assets.video import VideoAsset, video_to_ndarrays
        if os.path.exists(video_path):
            return video_to_ndarrays(video_path, num_frames=16)
        # Try base64 decode
        try:
            timestamp = str(int(time.time() * 1000))
            p = f"/tmp/vl-{timestamp}.data"
            with open(p, "wb") as f:
                f.write(base64.b64decode(video_path))
            video = VideoAsset(name=p).np_ndarrays
            os.remove(p)
            return video
        except Exception:
            return None

    def _load_audio(self, audio_path):
        """Load audio from a file path or base64-encoded data."""
        import librosa
        if os.path.exists(audio_path):
            audio_signal, sr = librosa.load(audio_path, sr=16000)
            return (audio_signal.astype(np.float32), sr)
        # Try base64 decode
        try:
            audio_data = base64.b64decode(audio_path)
            # Save to a temp file and load
            timestamp = str(int(time.time() * 1000))
            p = f"/tmp/audio-{timestamp}.wav"
            with open(p, "wb") as f:
                f.write(audio_data)
            audio_signal, sr = librosa.load(p, sr=16000)
            os.remove(p)
            return (audio_signal.astype(np.float32), sr)
        except Exception:
            return None

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        try:
            print(f"Loading model {request.Model}...", file=sys.stderr)
            print(f"Request {request}", file=sys.stderr)

            # Parse options from request.Options (key:value pairs)
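            # e.g. Options: ["cache_backend:cache_dit", "cfg_scale:4.0", "enable_cpu_offload:true"]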
            self.options = {}
            for opt in request.Options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)
                # Convert value to the appropriate type (check int before float,
                # since is_float() also accepts integer strings)
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"
                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Detect model type
            self.model_name = request.Model
            self.model_type = request.Type if request.Type else self._detect_model_type(request.Model)
            print(f"Detected model type: {self.model_type}", file=sys.stderr)

            # Build DiffusionParallelConfig if diffusion model (image or video)
            parallel_config = None
            if self.model_type in ["image", "video"]:
                parallel_config = DiffusionParallelConfig(
                    ulysses_degree=self.options.get("ulysses_degree", 1),
                    ring_degree=self.options.get("ring_degree", 1),
                    cfg_parallel_size=self.options.get("cfg_parallel_size", 1),
                    tensor_parallel_size=self.options.get("tensor_parallel_size", 1),
                )

            # Build cache_config dict if cache_backend specified
            cache_backend = self.options.get("cache_backend")  # "cache_dit" or "tea_cache"
            cache_config = None
            if cache_backend == "cache_dit":
                cache_config = {
                    "Fn_compute_blocks": self.options.get("cache_dit_fn_compute_blocks", 1),
                    "Bn_compute_blocks": self.options.get("cache_dit_bn_compute_blocks", 0),
                    "max_warmup_steps": self.options.get("cache_dit_max_warmup_steps", 4),
                    "residual_diff_threshold": self.options.get("cache_dit_residual_diff_threshold", 0.24),
                    "max_continuous_cached_steps": self.options.get("cache_dit_max_continuous_cached_steps", 3),
                    "enable_taylorseer": self.options.get("cache_dit_enable_taylorseer", False),
                    "taylorseer_order": self.options.get("cache_dit_taylorseer_order", 1),
                    "scm_steps_mask_policy": self.options.get("cache_dit_scm_steps_mask_policy"),
                    "scm_steps_policy": self.options.get("cache_dit_scm_steps_policy", "dynamic"),
                }
            elif cache_backend == "tea_cache":
                cache_config = {
                    "rel_l1_thresh": self.options.get("tea_cache_rel_l1_thresh", 0.2),
                }

            # Base Omni initialization parameters
            omni_kwargs = {
                "model": request.Model,
            }

            # Add diffusion-specific parameters (image/video models)
            if self.model_type in ["image", "video"]:
                omni_kwargs.update({
                    "vae_use_slicing": is_npu(),
                    "vae_use_tiling": is_npu(),
                    "cache_backend": cache_backend,
                    "cache_config": cache_config,
                    "parallel_config": parallel_config,
                    "enforce_eager": self.options.get("enforce_eager", request.EnforceEager),
                    "enable_cpu_offload": self.options.get("enable_cpu_offload", False),
                })
                # Video-specific parameters
                if self.model_type == "video":
                    omni_kwargs.update({
                        "boundary_ratio": self.options.get("boundary_ratio", 0.875),
                        "flow_shift": self.options.get("flow_shift", 5.0),
                    })

            # Add LLM/TTS-specific parameters
            if self.model_type in ["llm", "tts"]:
                omni_kwargs.update({
                    "stage_configs_path": self.options.get("stage_configs_path"),
                    "log_stats": self.options.get("enable_stats", False),
                    "stage_init_timeout": self.options.get("stage_init_timeout", 300),
                })
                # vllm engine options (passed through Omni for LLM/TTS)
                if request.GPUMemoryUtilization > 0:
                    omni_kwargs["gpu_memory_utilization"] = request.GPUMemoryUtilization
                if request.TensorParallelSize > 0:
                    omni_kwargs["tensor_parallel_size"] = request.TensorParallelSize
                if request.TrustRemoteCode:
                    omni_kwargs["trust_remote_code"] = request.TrustRemoteCode
                if request.MaxModelLen > 0:
                    omni_kwargs["max_model_len"] = request.MaxModelLen

            self.omni = Omni(**omni_kwargs)
            print("Model loaded successfully", file=sys.stderr)
            return backend_pb2.Result(message="Model loaded successfully", success=True)

        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    def GenerateImage(self, request, context):
        try:
            # Validate model is loaded and is image/diffusion type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["image"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support image generation")

            # Extract parameters
            prompt = request.positive_prompt
            negative_prompt = request.negative_prompt if request.negative_prompt else None
            width = request.width if request.width > 0 else 1024
            height = request.height if request.height > 0 else 1024
            seed = request.seed if request.seed > 0 else None
            num_inference_steps = request.step if request.step > 0 else 50
            cfg_scale = self.options.get("cfg_scale", 4.0)
            guidance_scale = self.options.get("guidance_scale", 1.0)

            # Create generator if seed provided
            generator = None
            if seed:
                device = detect_device_type()
                generator = torch.Generator(device=device).manual_seed(seed)

            # Handle image input for image editing
            pil_image = None
            if request.src or (request.ref_images and len(request.ref_images) > 0):
                image_path = request.ref_images[0] if request.ref_images else request.src
                pil_image = self._load_image(image_path)
                if pil_image is None:
                    return backend_pb2.Result(success=False, message=f"Invalid image source: {image_path}")
                pil_image = pil_image.convert("RGB")

            # Build generate kwargs
            generate_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": height,
                "width": width,
                "generator": generator,
                "true_cfg_scale": cfg_scale,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
            }
            if pil_image:
                generate_kwargs["pil_image"] = pil_image

            # Call omni.generate()
            outputs = self.omni.generate(**generate_kwargs)

            # Extract images (following example pattern)
            if not outputs or len(outputs) == 0:
                return backend_pb2.Result(success=False, message="No output generated")

            first_output = outputs[0]
            if not hasattr(first_output, "request_output") or not first_output.request_output:
                return backend_pb2.Result(success=False, message="Invalid output structure")

            req_out = first_output.request_output[0]
            if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
                return backend_pb2.Result(success=False, message="No images in output")

            images = req_out.images
            if not images or len(images) == 0:
                return backend_pb2.Result(success=False, message="Empty images list")

            # Save image
            output_image = images[0]
            output_image.save(request.dst)
            return backend_pb2.Result(message="Image generated successfully", success=True)

        except Exception as err:
            print(f"Error generating image: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating image: {err}")

    def GenerateVideo(self, request, context):
        try:
            # Validate model is loaded and is video/diffusion type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["video"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support video generation")

            # Extract parameters
            prompt = request.prompt
            negative_prompt = request.negative_prompt if request.negative_prompt else ""
            width = request.width if request.width > 0 else 1280
            height = request.height if request.height > 0 else 720
            num_frames = request.num_frames if request.num_frames > 0 else 81
            fps = request.fps if request.fps > 0 else 24
            seed = request.seed if request.seed > 0 else None
            guidance_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
            guidance_scale_high = self.options.get("guidance_scale_high")
            num_inference_steps = request.step if request.step > 0 else 40

            # Create generator
            generator = None
            if seed:
                device = detect_device_type()
                generator = torch.Generator(device=device).manual_seed(seed)

            # Handle image input for image-to-video
            pil_image = None
            if request.start_image:
                pil_image = self._load_image(request.start_image)
                if pil_image is None:
                    return backend_pb2.Result(success=False, message=f"Invalid start_image: {request.start_image}")
                pil_image = pil_image.convert("RGB")
                # Resize to target dimensions
                pil_image = pil_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generate kwargs
            generate_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": height,
                "width": width,
                "generator": generator,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_frames": num_frames,
            }
            if pil_image:
                generate_kwargs["pil_image"] = pil_image
            if guidance_scale_high:
                generate_kwargs["guidance_scale_2"] = guidance_scale_high

            # Call omni.generate()
            frames = self.omni.generate(**generate_kwargs)

            # Extract video frames (following example pattern)
            if isinstance(frames, list) and len(frames) > 0:
                first_item = frames[0]

                if hasattr(first_item, "final_output_type"):
                    if first_item.final_output_type != "image":
                        return backend_pb2.Result(success=False, message=f"Unexpected output type: {first_item.final_output_type}")

                # Pipeline mode: extract from nested request_output
                if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output:
                    if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0:
                        inner_output = first_item.request_output[0]
                        if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"):
                            frames = inner_output.images[0] if inner_output.images else None
                # Diffusion mode: use direct images field
                elif hasattr(first_item, "images") and first_item.images:
                    frames = first_item.images
                else:
                    return backend_pb2.Result(success=False, message="No video frames found")

            if frames is None:
                return backend_pb2.Result(success=False, message="No video frames found in output")

            # Convert frames to numpy array (following example)
            if isinstance(frames, torch.Tensor):
                video_tensor = frames.detach().cpu()
                # Handle different tensor shapes [B, C, F, H, W] or [B, F, H, W, C]
                if video_tensor.dim() == 5:
                    if video_tensor.shape[1] in (3, 4):
                        video_tensor = video_tensor[0].permute(1, 2, 3, 0)
                    else:
                        video_tensor = video_tensor[0]
                elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4):
                    video_tensor = video_tensor.permute(1, 2, 3, 0)
                # Normalize from [-1,1] to [0,1] if float
                if video_tensor.is_floating_point():
                    video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5
                video_array = video_tensor.float().numpy()
            else:
                video_array = frames
                if hasattr(video_array, "shape") and video_array.ndim == 5:
                    video_array = video_array[0]

            # Convert 4D array (frames, H, W, C) to list of frames
            if isinstance(video_array, np.ndarray) and video_array.ndim == 4:
                video_array = list(video_array)

            # Save video
            export_to_video(video_array, request.dst, fps=fps)
            return backend_pb2.Result(message="Video generated successfully", success=True)

        except Exception as err:
            print(f"Error generating video: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")

    def Predict(self, request, context):
        """Non-streaming text generation with multimodal inputs."""
        gen = self._predict(request, context, streaming=False)
        try:
            res = next(gen)
            return res
        except StopIteration:
            return backend_pb2.Reply(message=bytes("", 'utf-8'))

    def PredictStream(self, request, context):
        """Streaming text generation with multimodal inputs."""
        return self._predict(request, context, streaming=True)

    def _predict(self, request, context, streaming=False):
        """Internal method for text generation (streaming and non-streaming)."""
        try:
            # Validate model is loaded and is LLM type
            if not hasattr(self, 'omni'):
                yield backend_pb2.Reply(message=bytes("Model not loaded. Call LoadModel first.", 'utf-8'))
                return
            if self.model_type not in ["llm"]:
                yield backend_pb2.Reply(message=bytes(f"Model type {self.model_type} does not support text generation", 'utf-8'))
                return

            # Extract prompt
            if request.Prompt:
                prompt = request.Prompt
            elif request.Messages and request.UseTokenizerTemplate:
                # Build prompt from messages (simplified - would need tokenizer for full template)
                prompt = ""
                for msg in request.Messages:
                    role = msg.role
                    content = msg.content
                    prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
                prompt += "<|im_start|>assistant\n"
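                # Resulting ChatML-style prompt, e.g.:
                #   "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"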
            else:
                yield backend_pb2.Reply(message=bytes("", 'utf-8'))
                return

            # Build multi_modal_data dict
            multi_modal_data = {}

            # Process images
            if request.Images:
                image_data = []
                for img_path in request.Images:
                    img = self._load_image(img_path)
                    if img:
                        # Convert to format expected by vllm
                        from vllm.multimodal.image import convert_image_mode
                        img_data = convert_image_mode(img, "RGB")
                        image_data.append(img_data)
                if image_data:
                    multi_modal_data["image"] = image_data

            # Process videos
            if request.Videos:
                video_data = []
                for video_path in request.Videos:
                    video = self._load_video(video_path)
                    if video is not None:
                        video_data.append(video)
                if video_data:
                    multi_modal_data["video"] = video_data

            # Process audio
            if request.Audios:
                audio_data = []
                for audio_path in request.Audios:
                    audio = self._load_audio(audio_path)
                    if audio is not None:
                        audio_data.append(audio)
                if audio_data:
                    multi_modal_data["audio"] = audio_data

            # Build inputs dict
            inputs = {
                "prompt": prompt,
                "multi_modal_data": multi_modal_data if multi_modal_data else None,
            }

            # Build sampling params
            sampling_params = SamplingParams(
                temperature=request.Temperature if request.Temperature > 0 else 0.7,
                top_p=request.TopP if request.TopP > 0 else 0.9,
                top_k=request.TopK if request.TopK > 0 else -1,
                max_tokens=request.Tokens if request.Tokens > 0 else 200,
                presence_penalty=request.PresencePenalty if request.PresencePenalty != 0 else 0.0,
                frequency_penalty=request.FrequencyPenalty if request.FrequencyPenalty != 0 else 0.0,
                repetition_penalty=request.RepetitionPenalty if request.RepetitionPenalty != 0 else 1.0,
                seed=request.Seed if request.Seed > 0 else None,
                stop=request.StopPrompts if request.StopPrompts else None,
                stop_token_ids=request.StopTokenIds if request.StopTokenIds else None,
                ignore_eos=request.IgnoreEOS,
            )
            sampling_params_list = [sampling_params]

            # Call omni.generate() (returns generator for LLM mode)
            omni_generator = self.omni.generate([inputs], sampling_params_list)

            # Extract text from outputs
            generated_text = ""
            for stage_outputs in omni_generator:
                if stage_outputs.final_output_type == "text":
                    for output in stage_outputs.request_output:
                        text_output = output.outputs[0].text
                        if streaming:
                            # Remove already sent text (vllm concatenates)
                            delta_text = text_output.removeprefix(generated_text)
                            yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8'))
                        generated_text = text_output

            if not streaming:
                yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

        except Exception as err:
            print(f"Error in Predict: {err}", file=sys.stderr)
            traceback.print_exc()
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", encoding='utf-8'))

    def TTS(self, request, context):
        try:
            # Validate model is loaded and is TTS type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["tts"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support TTS")

            # Extract parameters
            text = request.text
            language = request.language if request.language else "Auto"
            voice = request.voice if request.voice else None
            task_type = self._detect_tts_task_type()

            # Build prompt with chat template
            # TODO: vllm-omni currently supports only Qwen3-TTS, so the prompt format is
            # hardcoded here; once other models are supported, this should use the model's chat template.
            prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"

            # Build inputs dict
            inputs = {
                "prompt": prompt,
                "additional_information": {
                    "task_type": [task_type],
                    "text": [text],
                    "language": [language],
                    "max_new_tokens": [2048],
                }
            }

            # Add task-specific fields
            if task_type == "CustomVoice":
                if voice:
                    inputs["additional_information"]["speaker"] = [voice]
                # Add instruct if provided in options
                if "instruct" in self.options:
                    inputs["additional_information"]["instruct"] = [self.options["instruct"]]
            elif task_type == "VoiceDesign":
                if "instruct" in self.options:
                    inputs["additional_information"]["instruct"] = [self.options["instruct"]]
                inputs["additional_information"]["non_streaming_mode"] = [True]
            elif task_type == "Base":
                # Voice cloning requires ref_audio and ref_text
                if "ref_audio" in self.options:
                    inputs["additional_information"]["ref_audio"] = [self.options["ref_audio"]]
                if "ref_text" in self.options:
                    inputs["additional_information"]["ref_text"] = [self.options["ref_text"]]
                if "x_vector_only_mode" in self.options:
                    inputs["additional_information"]["x_vector_only_mode"] = [self.options["x_vector_only_mode"]]

            # Build sampling params
            sampling_params = SamplingParams(
                temperature=0.9,
                top_p=1.0,
                top_k=50,
                max_tokens=2048,
                seed=42,
                detokenize=False,
                repetition_penalty=1.05,
            )
            sampling_params_list = [sampling_params]

            # Call omni.generate()
            omni_generator = self.omni.generate(inputs, sampling_params_list)

            # Extract audio (following TTS example)
            for stage_outputs in omni_generator:
                for output in stage_outputs.request_output:
                    if "audio" in output.multimodal_output:
                        audio_tensor = output.multimodal_output["audio"]
                        audio_samplerate = output.multimodal_output["sr"].item()

                        # Convert to numpy
                        audio_numpy = audio_tensor.float().detach().cpu().numpy()
                        if audio_numpy.ndim > 1:
                            audio_numpy = audio_numpy.flatten()

                        # Save audio file
                        sf.write(request.dst, audio_numpy, samplerate=audio_samplerate, format="WAV")
                        return backend_pb2.Result(message="TTS audio generated successfully", success=True)

            return backend_pb2.Result(success=False, message="No audio output generated")

        except Exception as err:
            print(f"Error generating TTS: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}")


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Signal handlers for graceful shutdown
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)