From b2a8a638995016e66573dbf0cbc8f8cc920f74f1 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 24 Jan 2026 22:23:30 +0100 Subject: [PATCH] feat(vllm-omni): add new backend (#8188) * feat(vllm-omni: add new backend Signed-off-by: Ettore Di Giacinto * default to py3.12 Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 26 + Makefile | 8 +- backend/index.yaml | 52 ++ backend/python/vllm-omni/Makefile | 23 + backend/python/vllm-omni/backend.py | 682 ++++++++++++++++++ backend/python/vllm-omni/install.sh | 62 ++ .../python/vllm-omni/requirements-after.txt | 2 + .../vllm-omni/requirements-cublas12-after.txt | 1 + .../vllm-omni/requirements-cublas12.txt | 4 + .../python/vllm-omni/requirements-hipblas.txt | 5 + backend/python/vllm-omni/requirements.txt | 7 + backend/python/vllm-omni/run.sh | 11 + backend/python/vllm-omni/test.py | 82 +++ backend/python/vllm-omni/test.sh | 12 + 14 files changed, 975 insertions(+), 2 deletions(-) create mode 100644 backend/python/vllm-omni/Makefile create mode 100644 backend/python/vllm-omni/backend.py create mode 100755 backend/python/vllm-omni/install.sh create mode 100644 backend/python/vllm-omni/requirements-after.txt create mode 100644 backend/python/vllm-omni/requirements-cublas12-after.txt create mode 100644 backend/python/vllm-omni/requirements-cublas12.txt create mode 100644 backend/python/vllm-omni/requirements-hipblas.txt create mode 100644 backend/python/vllm-omni/requirements.txt create mode 100755 backend/python/vllm-omni/run.sh create mode 100644 backend/python/vllm-omni/test.py create mode 100755 backend/python/vllm-omni/test.sh diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 479e06e80..e6da180c0 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -170,6 +170,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'cublas' + cuda-major-version: "12" + cuda-minor-version: "9" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-nvidia-cuda-12-vllm-omni' + runs-on: 'arc-runner-set' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "vllm-omni" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "9" @@ -653,6 +666,19 @@ jobs: dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' + - build-type: 'hipblas' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64' + tag-latest: 'auto' + tag-suffix: '-gpu-rocm-hipblas-vllm-omni' + runs-on: 'arc-runner-set' + base-image: "rocm/dev-ubuntu-24.04:6.4.4" + skip-drivers: 'false' + backend: "vllm-omni" + dockerfile: "./backend/Dockerfile.python" + context: "./" + ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" diff --git a/Makefile b/Makefile index 301adf72e..b5c3e64c8 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio 
backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts backends/qwen-tts +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts GOCMD=go GOTEST=$(GOCMD) test @@ -314,6 +314,7 @@ prepare-test-extra: protogen-python $(MAKE) -C backend/python/diffusers $(MAKE) -C backend/python/chatterbox $(MAKE) -C backend/python/vllm + $(MAKE) -C backend/python/vllm-omni $(MAKE) -C backend/python/vibevoice $(MAKE) -C backend/python/moonshine $(MAKE) -C backend/python/pocket-tts @@ -324,6 +325,7 @@ test-extra: prepare-test-extra $(MAKE) -C backend/python/diffusers test $(MAKE) -C backend/python/chatterbox test $(MAKE) -C backend/python/vllm test + $(MAKE) -C backend/python/vllm-omni test $(MAKE) -C backend/python/vibevoice test $(MAKE) -C backend/python/moonshine test $(MAKE) -C backend/python/pocket-tts test @@ -455,6 +457,7 @@ BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true BACKEND_NEUTTS = neutts|python|.|false|true BACKEND_KOKORO = kokoro|python|.|false|true BACKEND_VLLM = vllm|python|.|false|true +BACKEND_VLLM_OMNI = vllm-omni|python|.|false|true BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true BACKEND_CHATTERBOX = chatterbox|python|.|false|true BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true @@ -501,6 +504,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS))) $(eval $(call generate-docker-build-target,$(BACKEND_KOKORO))) $(eval $(call generate-docker-build-target,$(BACKEND_VLLM))) +$(eval $(call generate-docker-build-target,$(BACKEND_VLLM_OMNI))) $(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS))) $(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX))) $(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE))) @@ -512,7 +516,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS))) docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar -docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts +docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts ######################################################## ### END Backends diff --git a/backend/index.yaml b/backend/index.yaml index 3b9d3ee69..7fa205f24 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -142,6 +142,31 @@ amd: "rocm-vllm" intel: "intel-vllm" nvidia-cuda-12: "cuda12-vllm" +- &vllm-omni + name: "vllm-omni" + license: apache-2.0 + urls: + - 
https://github.com/vllm-project/vllm-omni + tags: + - text-to-image + - image-generation + - text-to-video + - video-generation + - text-to-speech + - TTS + - multimodal + - LLM + icon: https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png + description: | + vLLM-Omni is a unified interface for multimodal generation with vLLM. + It supports image generation (text-to-image, image editing), video generation + (text-to-video, image-to-video), text generation with multimodal inputs, and + text-to-speech generation. Only supports NVIDIA (CUDA) and ROCm platforms. + alias: "vllm-omni" + capabilities: + nvidia: "cuda12-vllm-omni" + amd: "rocm-vllm-omni" + nvidia-cuda-12: "cuda12-vllm-omni" - &mlx name: "mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx" @@ -973,6 +998,33 @@ uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm" mirrors: - localai/localai-backends:master-gpu-intel-vllm +# vllm-omni +- !!merge <<: *vllm-omni + name: "vllm-omni-development" + capabilities: + nvidia: "cuda12-vllm-omni-development" + amd: "rocm-vllm-omni-development" + nvidia-cuda-12: "cuda12-vllm-omni-development" +- !!merge <<: *vllm-omni + name: "cuda12-vllm-omni" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm-omni" + mirrors: + - localai/localai-backends:latest-gpu-nvidia-cuda-12-vllm-omni +- !!merge <<: *vllm-omni + name: "rocm-vllm-omni" + uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vllm-omni" + mirrors: + - localai/localai-backends:latest-gpu-rocm-hipblas-vllm-omni +- !!merge <<: *vllm-omni + name: "cuda12-vllm-omni-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm-omni" + mirrors: + - localai/localai-backends:master-gpu-nvidia-cuda-12-vllm-omni +- !!merge <<: *vllm-omni + name: "rocm-vllm-omni-development" + uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vllm-omni" + mirrors: + - localai/localai-backends:master-gpu-rocm-hipblas-vllm-omni # rfdetr - !!merge <<: *rfdetr name: "rfdetr-development" diff --git a/backend/python/vllm-omni/Makefile b/backend/python/vllm-omni/Makefile new file mode 100644 index 000000000..aaeb3188a --- /dev/null +++ b/backend/python/vllm-omni/Makefile @@ -0,0 +1,23 @@ +.PHONY: vllm-omni +vllm-omni: + bash install.sh + +.PHONY: run +run: vllm-omni + @echo "Running vllm-omni..." + bash run.sh + @echo "vllm-omni run." + +.PHONY: test +test: vllm-omni + @echo "Testing vllm-omni..." + bash test.sh + @echo "vllm-omni tested." 
+ +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/vllm-omni/backend.py b/backend/python/vllm-omni/backend.py new file mode 100644 index 000000000..c21aeb0ad --- /dev/null +++ b/backend/python/vllm-omni/backend.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python3 +""" +LocalAI vLLM-Omni Backend + +This backend provides gRPC access to vllm-omni for multimodal generation: +- Image generation (text-to-image, image editing) +- Video generation (text-to-video, image-to-video) +- Text generation with multimodal inputs (LLM) +- Text-to-speech generation +""" +from concurrent import futures +import traceback +import argparse +import signal +import sys +import time +import os +import base64 +import io + +from PIL import Image +import torch +import numpy as np +import soundfile as sf + +import backend_pb2 +import backend_pb2_grpc + +import grpc + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.utils.platform_utils import detect_device_type, is_npu +from vllm import SamplingParams +from diffusers.utils import export_to_video + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +def is_float(s): + """Check if a string can be converted to float.""" + try: + float(s) + return True + except ValueError: + return False + + +def is_int(s): + """Check if a string can be converted to int.""" + try: + int(s) + return True + except ValueError: + return False + + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + + def _detect_model_type(self, model_name): + """Detect model type from model name.""" + model_lower = model_name.lower() + if "tts" in model_lower or "qwen3-tts" in model_lower: + return "tts" + elif "omni" in model_lower and "qwen3" in model_lower: + return "llm" + elif "wan" in model_lower or "t2v" in model_lower or "i2v" in model_lower: + return "video" + elif "image" in model_lower or "z-image" in model_lower or "qwen-image" in model_lower: + return "image" + else: + # Default to image for diffusion models, llm for others + return "image" + + def _detect_tts_task_type(self): + """Detect TTS task type from model name.""" + model_lower = self.model_name.lower() + if "customvoice" in model_lower: + return "CustomVoice" + elif "voicedesign" in model_lower: + return "VoiceDesign" + elif "base" in model_lower: + return "Base" + else: + # Default to CustomVoice + return "CustomVoice" + + def _load_image(self, image_path): + """Load an image from file path or base64 encoded data.""" + # Try file path first + if os.path.exists(image_path): + return Image.open(image_path) + # Try base64 decode + try: + image_data = base64.b64decode(image_path) + return Image.open(io.BytesIO(image_data)) + except: + return None + + def _load_video(self, video_path): + """Load a video from file path or base64 encoded data.""" + from vllm.assets.video import VideoAsset, video_to_ndarrays + if os.path.exists(video_path): + return video_to_ndarrays(video_path, num_frames=16) + # Try base64 decode + try: + timestamp = str(int(time.time() * 1000)) + p = f"/tmp/vl-{timestamp}.data" + with open(p, "wb") as f: + f.write(base64.b64decode(video_path)) + video = 
VideoAsset(name=p).np_ndarrays + os.remove(p) + return video + except: + return None + + def _load_audio(self, audio_path): + """Load audio from file path or base64 encoded data.""" + import librosa + if os.path.exists(audio_path): + audio_signal, sr = librosa.load(audio_path, sr=16000) + return (audio_signal.astype(np.float32), sr) + # Try base64 decode + try: + audio_data = base64.b64decode(audio_path) + # Save to temp file and load + timestamp = str(int(time.time() * 1000)) + p = f"/tmp/audio-{timestamp}.wav" + with open(p, "wb") as f: + f.write(audio_data) + audio_signal, sr = librosa.load(p, sr=16000) + os.remove(p) + return (audio_signal.astype(np.float32), sr) + except: + return None + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + try: + print(f"Loading model {request.Model}...", file=sys.stderr) + print(f"Request {request}", file=sys.stderr) + + # Parse options from request.Options (key:value pairs) + self.options = {} + for opt in request.Options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + # Convert value to appropriate type + if is_float(value): + value = float(value) + elif is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + self.options[key] = value + + print(f"Options: {self.options}", file=sys.stderr) + + # Detect model type + self.model_name = request.Model + self.model_type = request.Type if request.Type else self._detect_model_type(request.Model) + print(f"Detected model type: {self.model_type}", file=sys.stderr) + + # Build DiffusionParallelConfig if diffusion model (image or video) + parallel_config = None + if self.model_type in ["image", "video"]: + parallel_config = DiffusionParallelConfig( + ulysses_degree=self.options.get("ulysses_degree", 1), + ring_degree=self.options.get("ring_degree", 1), + cfg_parallel_size=self.options.get("cfg_parallel_size", 1), + tensor_parallel_size=self.options.get("tensor_parallel_size", 1), + ) + + # Build cache_config dict if cache_backend specified + cache_backend = self.options.get("cache_backend") # "cache_dit" or "tea_cache" + cache_config = None + if cache_backend == "cache_dit": + cache_config = { + "Fn_compute_blocks": self.options.get("cache_dit_fn_compute_blocks", 1), + "Bn_compute_blocks": self.options.get("cache_dit_bn_compute_blocks", 0), + "max_warmup_steps": self.options.get("cache_dit_max_warmup_steps", 4), + "residual_diff_threshold": self.options.get("cache_dit_residual_diff_threshold", 0.24), + "max_continuous_cached_steps": self.options.get("cache_dit_max_continuous_cached_steps", 3), + "enable_taylorseer": self.options.get("cache_dit_enable_taylorseer", False), + "taylorseer_order": self.options.get("cache_dit_taylorseer_order", 1), + "scm_steps_mask_policy": self.options.get("cache_dit_scm_steps_mask_policy"), + "scm_steps_policy": self.options.get("cache_dit_scm_steps_policy", "dynamic"), + } + elif cache_backend == "tea_cache": + cache_config = { + "rel_l1_thresh": self.options.get("tea_cache_rel_l1_thresh", 0.2), + } + + # Base Omni initialization parameters + omni_kwargs = { + "model": request.Model, + } + + # Add diffusion-specific parameters (image/video models) + if self.model_type in ["image", "video"]: + omni_kwargs.update({ + "vae_use_slicing": is_npu(), + "vae_use_tiling": is_npu(), + "cache_backend": cache_backend, + "cache_config": cache_config, + "parallel_config": parallel_config, + "enforce_eager": 
self.options.get("enforce_eager", request.EnforceEager), + "enable_cpu_offload": self.options.get("enable_cpu_offload", False), + }) + # Video-specific parameters + if self.model_type == "video": + omni_kwargs.update({ + "boundary_ratio": self.options.get("boundary_ratio", 0.875), + "flow_shift": self.options.get("flow_shift", 5.0), + }) + + # Add LLM/TTS-specific parameters + if self.model_type in ["llm", "tts"]: + omni_kwargs.update({ + "stage_configs_path": self.options.get("stage_configs_path"), + "log_stats": self.options.get("enable_stats", False), + "stage_init_timeout": self.options.get("stage_init_timeout", 300), + }) + # vllm engine options (passed through Omni for LLM/TTS) + if request.GPUMemoryUtilization > 0: + omni_kwargs["gpu_memory_utilization"] = request.GPUMemoryUtilization + if request.TensorParallelSize > 0: + omni_kwargs["tensor_parallel_size"] = request.TensorParallelSize + if request.TrustRemoteCode: + omni_kwargs["trust_remote_code"] = request.TrustRemoteCode + if request.MaxModelLen > 0: + omni_kwargs["max_model_len"] = request.MaxModelLen + + self.omni = Omni(**omni_kwargs) + print("Model loaded successfully", file=sys.stderr) + return backend_pb2.Result(message="Model loaded successfully", success=True) + + except Exception as err: + print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr) + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + + def GenerateImage(self, request, context): + try: + # Validate model is loaded and is image/diffusion type + if not hasattr(self, 'omni'): + return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.") + if self.model_type not in ["image"]: + return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support image generation") + + # Extract parameters + prompt = request.positive_prompt + negative_prompt = request.negative_prompt if request.negative_prompt else None + width = request.width if request.width > 0 else 1024 + height = request.height if request.height > 0 else 1024 + seed = request.seed if request.seed > 0 else None + num_inference_steps = request.step if request.step > 0 else 50 + cfg_scale = self.options.get("cfg_scale", 4.0) + guidance_scale = self.options.get("guidance_scale", 1.0) + + # Create generator if seed provided + generator = None + if seed: + device = detect_device_type() + generator = torch.Generator(device=device).manual_seed(seed) + + # Handle image input for image editing + pil_image = None + if request.src or (request.ref_images and len(request.ref_images) > 0): + image_path = request.ref_images[0] if request.ref_images else request.src + pil_image = self._load_image(image_path) + if pil_image is None: + return backend_pb2.Result(success=False, message=f"Invalid image source: {image_path}") + pil_image = pil_image.convert("RGB") + + # Build generate kwargs + generate_kwargs = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "generator": generator, + "true_cfg_scale": cfg_scale, + "guidance_scale": guidance_scale, + "num_inference_steps": num_inference_steps, + } + if pil_image: + generate_kwargs["pil_image"] = pil_image + + # Call omni.generate() + outputs = self.omni.generate(**generate_kwargs) + + # Extract images (following example pattern) + if not outputs or len(outputs) == 0: + return backend_pb2.Result(success=False, message="No output generated") + + first_output = outputs[0] + if not hasattr(first_output, 
"request_output") or not first_output.request_output: + return backend_pb2.Result(success=False, message="Invalid output structure") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + return backend_pb2.Result(success=False, message="No images in output") + + images = req_out.images + if not images or len(images) == 0: + return backend_pb2.Result(success=False, message="Empty images list") + + # Save image + output_image = images[0] + output_image.save(request.dst) + return backend_pb2.Result(message="Image generated successfully", success=True) + + except Exception as err: + print(f"Error generating image: {err}", file=sys.stderr) + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"Error generating image: {err}") + + def GenerateVideo(self, request, context): + try: + # Validate model is loaded and is video/diffusion type + if not hasattr(self, 'omni'): + return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.") + if self.model_type not in ["video"]: + return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support video generation") + + # Extract parameters + prompt = request.prompt + negative_prompt = request.negative_prompt if request.negative_prompt else "" + width = request.width if request.width > 0 else 1280 + height = request.height if request.height > 0 else 720 + num_frames = request.num_frames if request.num_frames > 0 else 81 + fps = request.fps if request.fps > 0 else 24 + seed = request.seed if request.seed > 0 else None + guidance_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0 + guidance_scale_high = self.options.get("guidance_scale_high") + num_inference_steps = request.step if request.step > 0 else 40 + + # Create generator + generator = None + if seed: + device = detect_device_type() + generator = torch.Generator(device=device).manual_seed(seed) + + # Handle image input for image-to-video + pil_image = None + if request.start_image: + pil_image = self._load_image(request.start_image) + if pil_image is None: + return backend_pb2.Result(success=False, message=f"Invalid start_image: {request.start_image}") + pil_image = pil_image.convert("RGB") + # Resize to target dimensions + pil_image = pil_image.resize((width, height), Image.Resampling.LANCZOS) + + # Build generate kwargs + generate_kwargs = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "generator": generator, + "guidance_scale": guidance_scale, + "num_inference_steps": num_inference_steps, + "num_frames": num_frames, + } + if pil_image: + generate_kwargs["pil_image"] = pil_image + if guidance_scale_high: + generate_kwargs["guidance_scale_2"] = guidance_scale_high + + # Call omni.generate() + frames = self.omni.generate(**generate_kwargs) + + # Extract video frames (following example pattern) + if isinstance(frames, list) and len(frames) > 0: + first_item = frames[0] + + if hasattr(first_item, "final_output_type"): + if first_item.final_output_type != "image": + return backend_pb2.Result(success=False, message=f"Unexpected output type: {first_item.final_output_type}") + + # Pipeline mode: extract from nested request_output + if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: + if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: + inner_output = first_item.request_output[0] + if isinstance(inner_output, OmniRequestOutput) 
and hasattr(inner_output, "images"): + frames = inner_output.images[0] if inner_output.images else None + # Diffusion mode: use direct images field + elif hasattr(first_item, "images") and first_item.images: + frames = first_item.images + else: + return backend_pb2.Result(success=False, message="No video frames found") + + if frames is None: + return backend_pb2.Result(success=False, message="No video frames found in output") + + # Convert frames to numpy array (following example) + if isinstance(frames, torch.Tensor): + video_tensor = frames.detach().cpu() + # Handle different tensor shapes [B, C, F, H, W] or [B, F, H, W, C] + if video_tensor.dim() == 5: + if video_tensor.shape[1] in (3, 4): + video_tensor = video_tensor[0].permute(1, 2, 3, 0) + else: + video_tensor = video_tensor[0] + elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): + video_tensor = video_tensor.permute(1, 2, 3, 0) + # Normalize from [-1,1] to [0,1] if float + if video_tensor.is_floating_point(): + video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 + video_array = video_tensor.float().numpy() + else: + video_array = frames + if hasattr(video_array, "shape") and video_array.ndim == 5: + video_array = video_array[0] + + # Convert 4D array (frames, H, W, C) to list of frames + if isinstance(video_array, np.ndarray) and video_array.ndim == 4: + video_array = list(video_array) + + # Save video + export_to_video(video_array, request.dst, fps=fps) + return backend_pb2.Result(message="Video generated successfully", success=True) + + except Exception as err: + print(f"Error generating video: {err}", file=sys.stderr) + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"Error generating video: {err}") + + def Predict(self, request, context): + """Non-streaming text generation with multimodal inputs.""" + gen = self._predict(request, context, streaming=False) + try: + res = next(gen) + return res + except StopIteration: + return backend_pb2.Reply(message=bytes("", 'utf-8')) + + def PredictStream(self, request, context): + """Streaming text generation with multimodal inputs.""" + return self._predict(request, context, streaming=True) + + def _predict(self, request, context, streaming=False): + """Internal method for text generation (streaming and non-streaming).""" + try: + # Validate model is loaded and is LLM type + if not hasattr(self, 'omni'): + yield backend_pb2.Reply(message=bytes("Model not loaded. 
Call LoadModel first.", 'utf-8')) + return + if self.model_type not in ["llm"]: + yield backend_pb2.Reply(message=bytes(f"Model type {self.model_type} does not support text generation", 'utf-8')) + return + + # Extract prompt + if request.Prompt: + prompt = request.Prompt + elif request.Messages and request.UseTokenizerTemplate: + # Build prompt from messages (simplified - would need tokenizer for full template) + prompt = "" + for msg in request.Messages: + role = msg.role + content = msg.content + prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n" + prompt += "<|im_start|>assistant\n" + else: + yield backend_pb2.Reply(message=bytes("", 'utf-8')) + return + + # Build multi_modal_data dict + multi_modal_data = {} + + # Process images + if request.Images: + image_data = [] + for img_path in request.Images: + img = self._load_image(img_path) + if img: + # Convert to format expected by vllm + from vllm.multimodal.image import convert_image_mode + img_data = convert_image_mode(img, "RGB") + image_data.append(img_data) + if image_data: + multi_modal_data["image"] = image_data + + # Process videos + if request.Videos: + video_data = [] + for video_path in request.Videos: + video = self._load_video(video_path) + if video is not None: + video_data.append(video) + if video_data: + multi_modal_data["video"] = video_data + + # Process audio + if request.Audios: + audio_data = [] + for audio_path in request.Audios: + audio = self._load_audio(audio_path) + if audio is not None: + audio_data.append(audio) + if audio_data: + multi_modal_data["audio"] = audio_data + + # Build inputs dict + inputs = { + "prompt": prompt, + "multi_modal_data": multi_modal_data if multi_modal_data else None, + } + + # Build sampling params + sampling_params = SamplingParams( + temperature=request.Temperature if request.Temperature > 0 else 0.7, + top_p=request.TopP if request.TopP > 0 else 0.9, + top_k=request.TopK if request.TopK > 0 else -1, + max_tokens=request.Tokens if request.Tokens > 0 else 200, + presence_penalty=request.PresencePenalty if request.PresencePenalty != 0 else 0.0, + frequency_penalty=request.FrequencyPenalty if request.FrequencyPenalty != 0 else 0.0, + repetition_penalty=request.RepetitionPenalty if request.RepetitionPenalty != 0 else 1.0, + seed=request.Seed if request.Seed > 0 else None, + stop=request.StopPrompts if request.StopPrompts else None, + stop_token_ids=request.StopTokenIds if request.StopTokenIds else None, + ignore_eos=request.IgnoreEOS, + ) + sampling_params_list = [sampling_params] + + # Call omni.generate() (returns generator for LLM mode) + omni_generator = self.omni.generate([inputs], sampling_params_list) + + # Extract text from outputs + generated_text = "" + for stage_outputs in omni_generator: + if stage_outputs.final_output_type == "text": + for output in stage_outputs.request_output: + text_output = output.outputs[0].text + if streaming: + # Remove already sent text (vllm concatenates) + delta_text = text_output.removeprefix(generated_text) + yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8')) + generated_text = text_output + + if not streaming: + yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) + + except Exception as err: + print(f"Error in Predict: {err}", file=sys.stderr) + traceback.print_exc() + yield backend_pb2.Reply(message=bytes(f"Error: {err}", encoding='utf-8')) + + def TTS(self, request, context): + try: + # Validate model is loaded and is TTS type + if not hasattr(self, 'omni'): + return 
backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.") + if self.model_type not in ["tts"]: + return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support TTS") + + # Extract parameters + text = request.text + language = request.language if request.language else "Auto" + voice = request.voice if request.voice else None + task_type = self._detect_tts_task_type() + + # Build prompt with chat template + # TODO: for now vllm-omni supports only qwen3-tts, so we hardcode it, however, we want to support other models in the future. + # and we might need to use the chat template here + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + # Build inputs dict + inputs = { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "language": [language], + "max_new_tokens": [2048], + } + } + + # Add task-specific fields + if task_type == "CustomVoice": + if voice: + inputs["additional_information"]["speaker"] = [voice] + # Add instruct if provided in options + if "instruct" in self.options: + inputs["additional_information"]["instruct"] = [self.options["instruct"]] + elif task_type == "VoiceDesign": + if "instruct" in self.options: + inputs["additional_information"]["instruct"] = [self.options["instruct"]] + inputs["additional_information"]["non_streaming_mode"] = [True] + elif task_type == "Base": + # Voice cloning requires ref_audio and ref_text + if "ref_audio" in self.options: + inputs["additional_information"]["ref_audio"] = [self.options["ref_audio"]] + if "ref_text" in self.options: + inputs["additional_information"]["ref_text"] = [self.options["ref_text"]] + if "x_vector_only_mode" in self.options: + inputs["additional_information"]["x_vector_only_mode"] = [self.options["x_vector_only_mode"]] + + # Build sampling params + sampling_params = SamplingParams( + temperature=0.9, + top_p=1.0, + top_k=50, + max_tokens=2048, + seed=42, + detokenize=False, + repetition_penalty=1.05, + ) + sampling_params_list = [sampling_params] + + # Call omni.generate() + omni_generator = self.omni.generate(inputs, sampling_params_list) + + # Extract audio (following TTS example) + for stage_outputs in omni_generator: + for output in stage_outputs.request_output: + if "audio" in output.multimodal_output: + audio_tensor = output.multimodal_output["audio"] + audio_samplerate = output.multimodal_output["sr"].item() + + # Convert to numpy + audio_numpy = audio_tensor.float().detach().cpu().numpy() + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save audio file + sf.write(request.dst, audio_numpy, samplerate=audio_samplerate, format="WAV") + return backend_pb2.Result(message="TTS audio generated successfully", success=True) + + return backend_pb2.Result(success=False, message="No audio output generated") + + except Exception as err: + print(f"Error generating TTS: {err}", file=sys.stderr) + traceback.print_exc() + return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}") + + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ]) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. 
Listening on: " + address, file=sys.stderr) + + # Signal handlers for graceful shutdown + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + serve(args.addr) diff --git a/backend/python/vllm-omni/install.sh b/backend/python/vllm-omni/install.sh new file mode 100755 index 000000000..3aa6367d3 --- /dev/null +++ b/backend/python/vllm-omni/install.sh @@ -0,0 +1,62 @@ +#!/bin/bash +set -e + +PYTHON_VERSION="3.12" +PYTHON_PATCH="12" +PY_STANDALONE_TAG="20251120" + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +# Handle l4t build profiles (Python 3.12, pip fallback) if needed +if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then + PYTHON_VERSION="3.12" + PYTHON_PATCH="12" + PY_STANDALONE_TAG="20251120" +fi + +if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then + USE_PIP=true +fi + +# Install base requirements first +installRequirements + +# Install vllm based on build type +if [ "x${BUILD_TYPE}" == "xhipblas" ]; then + # ROCm + if [ "x${USE_PIP}" == "xtrue" ]; then + pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700 + else + uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700 + fi +elif [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "x" ]; then + # CUDA (default) or CPU + if [ "x${USE_PIP}" == "xtrue" ]; then + pip install vllm==0.14.0 --torch-backend=auto + else + uv pip install vllm==0.14.0 --torch-backend=auto + fi +else + echo "Unsupported build type: ${BUILD_TYPE}" >&2 + exit 1 +fi + +# Clone and install vllm-omni from source +if [ ! -d vllm-omni ]; then + git clone https://github.com/vllm-project/vllm-omni.git +fi + +cd vllm-omni/ + +if [ "x${USE_PIP}" == "xtrue" ]; then + pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e . +else + uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e . +fi + +cd .. 
diff --git a/backend/python/vllm-omni/requirements-after.txt b/backend/python/vllm-omni/requirements-after.txt new file mode 100644 index 000000000..e20a00869 --- /dev/null +++ b/backend/python/vllm-omni/requirements-after.txt @@ -0,0 +1,2 @@ +diffusers +librosa diff --git a/backend/python/vllm-omni/requirements-cublas12-after.txt b/backend/python/vllm-omni/requirements-cublas12-after.txt new file mode 100644 index 000000000..9251ba608 --- /dev/null +++ b/backend/python/vllm-omni/requirements-cublas12-after.txt @@ -0,0 +1 @@ +https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl diff --git a/backend/python/vllm-omni/requirements-cublas12.txt b/backend/python/vllm-omni/requirements-cublas12.txt new file mode 100644 index 000000000..deafb2c6e --- /dev/null +++ b/backend/python/vllm-omni/requirements-cublas12.txt @@ -0,0 +1,4 @@ +accelerate +torch==2.7.0 +transformers +bitsandbytes diff --git a/backend/python/vllm-omni/requirements-hipblas.txt b/backend/python/vllm-omni/requirements-hipblas.txt new file mode 100644 index 000000000..426494ec0 --- /dev/null +++ b/backend/python/vllm-omni/requirements-hipblas.txt @@ -0,0 +1,5 @@ +--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.4 +accelerate +torch +transformers +bitsandbytes diff --git a/backend/python/vllm-omni/requirements.txt b/backend/python/vllm-omni/requirements.txt new file mode 100644 index 000000000..e6eb4e177 --- /dev/null +++ b/backend/python/vllm-omni/requirements.txt @@ -0,0 +1,7 @@ +grpcio==1.76.0 +protobuf +certifi +setuptools +pillow +numpy +soundfile diff --git a/backend/python/vllm-omni/run.sh b/backend/python/vllm-omni/run.sh new file mode 100755 index 000000000..8f608a541 --- /dev/null +++ b/backend/python/vllm-omni/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/vllm-omni/test.py b/backend/python/vllm-omni/test.py new file mode 100644 index 000000000..ef8255725 --- /dev/null +++ b/backend/python/vllm-omni/test.py @@ -0,0 +1,82 @@ +import unittest +import subprocess +import time +import backend_pb2 +import backend_pb2_grpc + +import grpc + +class TestBackendServicer(unittest.TestCase): + """ + TestBackendServicer is the class that tests the gRPC service. + + This class contains methods to test the startup and shutdown of the gRPC service. 
+    """
+    def setUp(self):
+        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
+        time.sleep(10)
+
+    def tearDown(self) -> None:
+        self.service.terminate()
+        self.service.wait()
+
+    def test_server_startup(self):
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.Health(backend_pb2.HealthMessage())
+                self.assertEqual(response.message, b'OK')
+        except Exception as err:
+            print(err)
+            self.fail("Server failed to start")
+        finally:
+            self.tearDown()
+
+    def test_load_model(self):
+        """
+        This method tests if the model is loaded successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                # Use a small image generation model for testing
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo"))
+                self.assertTrue(response.success)
+                self.assertEqual(response.message, "Model loaded successfully")
+        except Exception as err:
+            print(err)
+            self.fail("LoadModel service failed")
+        finally:
+            self.tearDown()
+
+    def test_generate_image(self):
+        """
+        This method tests if image generation works
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo"))
+                self.assertTrue(response.success)
+
+                req = backend_pb2.GenerateImageRequest(
+                    positive_prompt="a cup of coffee on the table",
+                    dst="/tmp/test_output.png",
+                    width=512,
+                    height=512,
+                    step=20,
+                    seed=42
+                )
+                resp = stub.GenerateImage(req)
+                self.assertTrue(resp.success)
+        except Exception as err:
+            print(err)
+            self.fail("GenerateImage service failed")
+        finally:
+            self.tearDown()
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/backend/python/vllm-omni/test.sh b/backend/python/vllm-omni/test.sh
new file mode 100755
index 000000000..f31ae54e4
--- /dev/null
+++ b/backend/python/vllm-omni/test.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+backend_dir=$(dirname $0)
+
+if [ -d $backend_dir/common ]; then
+    source $backend_dir/common/libbackend.sh
+else
+    source $backend_dir/../common/libbackend.sh
+fi
+
+runUnittests