Mirror of https://github.com/mudler/LocalAI.git (synced 2026-01-29 16:52:27 -05:00)
feat(vllm-omni): add new backend (#8188)

* feat(vllm-omni): add new backend

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* default to py3.12

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

Committed by: GitHub
Parent: 05a332cd5f
Commit: b2a8a63899
.github/workflows/backend.yml (vendored): 26 changes
@@ -170,6 +170,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2404'
+          - build-type: 'cublas'
+            cuda-major-version: "12"
+            cuda-minor-version: "9"
+            platforms: 'linux/amd64'
+            tag-latest: 'auto'
+            tag-suffix: '-gpu-nvidia-cuda-12-vllm-omni'
+            runs-on: 'arc-runner-set'
+            base-image: "ubuntu:24.04"
+            skip-drivers: 'false'
+            backend: "vllm-omni"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2404'
           - build-type: 'cublas'
             cuda-major-version: "12"
             cuda-minor-version: "9"
@@ -653,6 +666,19 @@ jobs:
             dockerfile: "./backend/Dockerfile.python"
             context: "./"
             ubuntu-version: '2404'
+          - build-type: 'hipblas'
+            cuda-major-version: ""
+            cuda-minor-version: ""
+            platforms: 'linux/amd64'
+            tag-latest: 'auto'
+            tag-suffix: '-gpu-rocm-hipblas-vllm-omni'
+            runs-on: 'arc-runner-set'
+            base-image: "rocm/dev-ubuntu-24.04:6.4.4"
+            skip-drivers: 'false'
+            backend: "vllm-omni"
+            dockerfile: "./backend/Dockerfile.python"
+            context: "./"
+            ubuntu-version: '2404'
           - build-type: 'hipblas'
             cuda-major-version: ""
             cuda-minor-version: ""
Makefile: 8 changes
@@ -1,5 +1,5 @@
 # Disable parallel execution for backend builds
-.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/moonshine backends/pocket-tts backends/qwen-tts
+.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts
 
 GOCMD=go
 GOTEST=$(GOCMD) test
@@ -314,6 +314,7 @@ prepare-test-extra: protogen-python
 	$(MAKE) -C backend/python/diffusers
 	$(MAKE) -C backend/python/chatterbox
 	$(MAKE) -C backend/python/vllm
+	$(MAKE) -C backend/python/vllm-omni
 	$(MAKE) -C backend/python/vibevoice
 	$(MAKE) -C backend/python/moonshine
 	$(MAKE) -C backend/python/pocket-tts
@@ -324,6 +325,7 @@ test-extra: prepare-test-extra
 	$(MAKE) -C backend/python/diffusers test
 	$(MAKE) -C backend/python/chatterbox test
 	$(MAKE) -C backend/python/vllm test
+	$(MAKE) -C backend/python/vllm-omni test
 	$(MAKE) -C backend/python/vibevoice test
 	$(MAKE) -C backend/python/moonshine test
 	$(MAKE) -C backend/python/pocket-tts test
@@ -455,6 +457,7 @@ BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true
 BACKEND_NEUTTS = neutts|python|.|false|true
 BACKEND_KOKORO = kokoro|python|.|false|true
 BACKEND_VLLM = vllm|python|.|false|true
+BACKEND_VLLM_OMNI = vllm-omni|python|.|false|true
 BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true
 BACKEND_CHATTERBOX = chatterbox|python|.|false|true
 BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true
@@ -501,6 +504,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS)))
 $(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS)))
 $(eval $(call generate-docker-build-target,$(BACKEND_KOKORO)))
 $(eval $(call generate-docker-build-target,$(BACKEND_VLLM)))
+$(eval $(call generate-docker-build-target,$(BACKEND_VLLM_OMNI)))
 $(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS)))
 $(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX)))
 $(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE)))
@@ -512,7 +516,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS)))
 docker-save-%: backend-images
 	docker save local-ai-backend:$* -o backend-images/$*.tar
 
-docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts
+docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-bark docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts
 
 ########################################################
 ### END Backends
@@ -142,6 +142,31 @@
     amd: "rocm-vllm"
     intel: "intel-vllm"
     nvidia-cuda-12: "cuda12-vllm"
+- &vllm-omni
+  name: "vllm-omni"
+  license: apache-2.0
+  urls:
+    - https://github.com/vllm-project/vllm-omni
+  tags:
+    - text-to-image
+    - image-generation
+    - text-to-video
+    - video-generation
+    - text-to-speech
+    - TTS
+    - multimodal
+    - LLM
+  icon: https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png
+  description: |
+    vLLM-Omni is a unified interface for multimodal generation with vLLM.
+    It supports image generation (text-to-image, image editing), video generation
+    (text-to-video, image-to-video), text generation with multimodal inputs, and
+    text-to-speech generation. Only supports NVIDIA (CUDA) and ROCm platforms.
+  alias: "vllm-omni"
+  capabilities:
+    nvidia: "cuda12-vllm-omni"
+    amd: "rocm-vllm-omni"
+    nvidia-cuda-12: "cuda12-vllm-omni"
 - &mlx
   name: "mlx"
   uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx"
@@ -973,6 +998,33 @@
   uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm"
   mirrors:
     - localai/localai-backends:master-gpu-intel-vllm
+# vllm-omni
+- !!merge <<: *vllm-omni
+  name: "vllm-omni-development"
+  capabilities:
+    nvidia: "cuda12-vllm-omni-development"
+    amd: "rocm-vllm-omni-development"
+    nvidia-cuda-12: "cuda12-vllm-omni-development"
+- !!merge <<: *vllm-omni
+  name: "cuda12-vllm-omni"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm-omni"
+  mirrors:
+    - localai/localai-backends:latest-gpu-nvidia-cuda-12-vllm-omni
+- !!merge <<: *vllm-omni
+  name: "rocm-vllm-omni"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vllm-omni"
+  mirrors:
+    - localai/localai-backends:latest-gpu-rocm-hipblas-vllm-omni
+- !!merge <<: *vllm-omni
+  name: "cuda12-vllm-omni-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm-omni"
+  mirrors:
+    - localai/localai-backends:master-gpu-nvidia-cuda-12-vllm-omni
+- !!merge <<: *vllm-omni
+  name: "rocm-vllm-omni-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vllm-omni"
+  mirrors:
+    - localai/localai-backends:master-gpu-rocm-hipblas-vllm-omni
 # rfdetr
 - !!merge <<: *rfdetr
   name: "rfdetr-development"
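For context, a minimal sketch of how a model registered through these gallery entries could be exercised once the backend is installed. It goes through LocalAI's OpenAI-compatible image endpoint; the host, port, endpoint path and payload are assumptions based on LocalAI's standard API, and the model name is the one used in this backend's own test suite, not something this commit configures.

# Hypothetical client sketch: ask a running LocalAI instance (assumed at
# localhost:8080) that has the vllm-omni backend installed to generate an image.
import json
import urllib.request

payload = {
    "model": "Tongyi-MAI/Z-Image-Turbo",   # example model from the backend's tests
    "prompt": "a cup of coffee on the table",
    "size": "512x512",
}
req = urllib.request.Request(
    "http://localhost:8080/v1/images/generations",  # assumed LocalAI address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))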
backend/python/vllm-omni/Makefile: 23 lines (new file)
@@ -0,0 +1,23 @@
.PHONY: vllm-omni
vllm-omni:
	bash install.sh

.PHONY: run
run: vllm-omni
	@echo "Running vllm-omni..."
	bash run.sh
	@echo "vllm-omni run."

.PHONY: test
test: vllm-omni
	@echo "Testing vllm-omni..."
	bash test.sh
	@echo "vllm-omni tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
backend/python/vllm-omni/backend.py: 682 lines (new file)
@@ -0,0 +1,682 @@
#!/usr/bin/env python3
"""
LocalAI vLLM-Omni Backend

This backend provides gRPC access to vllm-omni for multimodal generation:
- Image generation (text-to-image, image editing)
- Video generation (text-to-video, image-to-video)
- Text generation with multimodal inputs (LLM)
- Text-to-speech generation
"""
from concurrent import futures
import traceback
import argparse
import signal
import sys
import time
import os
import base64
import io

from PIL import Image
import torch
import numpy as np
import soundfile as sf

import backend_pb2
import backend_pb2_grpc

import grpc

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.utils.platform_utils import detect_device_type, is_npu
from vllm import SamplingParams
from diffusers.utils import export_to_video

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):

    def _detect_model_type(self, model_name):
        """Detect model type from model name."""
        model_lower = model_name.lower()
        if "tts" in model_lower or "qwen3-tts" in model_lower:
            return "tts"
        elif "omni" in model_lower and "qwen3" in model_lower:
            return "llm"
        elif "wan" in model_lower or "t2v" in model_lower or "i2v" in model_lower:
            return "video"
        elif "image" in model_lower or "z-image" in model_lower or "qwen-image" in model_lower:
            return "image"
        else:
            # Default to image for diffusion models, llm for others
            return "image"

    def _detect_tts_task_type(self):
        """Detect TTS task type from model name."""
        model_lower = self.model_name.lower()
        if "customvoice" in model_lower:
            return "CustomVoice"
        elif "voicedesign" in model_lower:
            return "VoiceDesign"
        elif "base" in model_lower:
            return "Base"
        else:
            # Default to CustomVoice
            return "CustomVoice"

    def _load_image(self, image_path):
        """Load an image from file path or base64 encoded data."""
        # Try file path first
        if os.path.exists(image_path):
            return Image.open(image_path)
        # Try base64 decode
        try:
            image_data = base64.b64decode(image_path)
            return Image.open(io.BytesIO(image_data))
        except:
            return None

    def _load_video(self, video_path):
        """Load a video from file path or base64 encoded data."""
        from vllm.assets.video import VideoAsset, video_to_ndarrays
        if os.path.exists(video_path):
            return video_to_ndarrays(video_path, num_frames=16)
        # Try base64 decode
        try:
            timestamp = str(int(time.time() * 1000))
            p = f"/tmp/vl-{timestamp}.data"
            with open(p, "wb") as f:
                f.write(base64.b64decode(video_path))
            video = VideoAsset(name=p).np_ndarrays
            os.remove(p)
            return video
        except:
            return None

    def _load_audio(self, audio_path):
        """Load audio from file path or base64 encoded data."""
        import librosa
        if os.path.exists(audio_path):
            audio_signal, sr = librosa.load(audio_path, sr=16000)
            return (audio_signal.astype(np.float32), sr)
        # Try base64 decode
        try:
            audio_data = base64.b64decode(audio_path)
            # Save to temp file and load
            timestamp = str(int(time.time() * 1000))
            p = f"/tmp/audio-{timestamp}.wav"
            with open(p, "wb") as f:
                f.write(audio_data)
            audio_signal, sr = librosa.load(p, sr=16000)
            os.remove(p)
            return (audio_signal.astype(np.float32), sr)
        except:
            return None

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        try:
            print(f"Loading model {request.Model}...", file=sys.stderr)
            print(f"Request {request}", file=sys.stderr)

            # Parse options from request.Options (key:value pairs)
            self.options = {}
            for opt in request.Options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)
                # Convert value to appropriate type
                if is_float(value):
                    value = float(value)
                elif is_int(value):
                    value = int(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"
                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Detect model type
            self.model_name = request.Model
            self.model_type = request.Type if request.Type else self._detect_model_type(request.Model)
            print(f"Detected model type: {self.model_type}", file=sys.stderr)

            # Build DiffusionParallelConfig if diffusion model (image or video)
            parallel_config = None
            if self.model_type in ["image", "video"]:
                parallel_config = DiffusionParallelConfig(
                    ulysses_degree=self.options.get("ulysses_degree", 1),
                    ring_degree=self.options.get("ring_degree", 1),
                    cfg_parallel_size=self.options.get("cfg_parallel_size", 1),
                    tensor_parallel_size=self.options.get("tensor_parallel_size", 1),
                )

            # Build cache_config dict if cache_backend specified
            cache_backend = self.options.get("cache_backend")  # "cache_dit" or "tea_cache"
            cache_config = None
            if cache_backend == "cache_dit":
                cache_config = {
                    "Fn_compute_blocks": self.options.get("cache_dit_fn_compute_blocks", 1),
                    "Bn_compute_blocks": self.options.get("cache_dit_bn_compute_blocks", 0),
                    "max_warmup_steps": self.options.get("cache_dit_max_warmup_steps", 4),
                    "residual_diff_threshold": self.options.get("cache_dit_residual_diff_threshold", 0.24),
                    "max_continuous_cached_steps": self.options.get("cache_dit_max_continuous_cached_steps", 3),
                    "enable_taylorseer": self.options.get("cache_dit_enable_taylorseer", False),
                    "taylorseer_order": self.options.get("cache_dit_taylorseer_order", 1),
                    "scm_steps_mask_policy": self.options.get("cache_dit_scm_steps_mask_policy"),
                    "scm_steps_policy": self.options.get("cache_dit_scm_steps_policy", "dynamic"),
                }
            elif cache_backend == "tea_cache":
                cache_config = {
                    "rel_l1_thresh": self.options.get("tea_cache_rel_l1_thresh", 0.2),
                }

            # Base Omni initialization parameters
            omni_kwargs = {
                "model": request.Model,
            }

            # Add diffusion-specific parameters (image/video models)
            if self.model_type in ["image", "video"]:
                omni_kwargs.update({
                    "vae_use_slicing": is_npu(),
                    "vae_use_tiling": is_npu(),
                    "cache_backend": cache_backend,
                    "cache_config": cache_config,
                    "parallel_config": parallel_config,
                    "enforce_eager": self.options.get("enforce_eager", request.EnforceEager),
                    "enable_cpu_offload": self.options.get("enable_cpu_offload", False),
                })
                # Video-specific parameters
                if self.model_type == "video":
                    omni_kwargs.update({
                        "boundary_ratio": self.options.get("boundary_ratio", 0.875),
                        "flow_shift": self.options.get("flow_shift", 5.0),
                    })

            # Add LLM/TTS-specific parameters
            if self.model_type in ["llm", "tts"]:
                omni_kwargs.update({
                    "stage_configs_path": self.options.get("stage_configs_path"),
                    "log_stats": self.options.get("enable_stats", False),
                    "stage_init_timeout": self.options.get("stage_init_timeout", 300),
                })
                # vllm engine options (passed through Omni for LLM/TTS)
                if request.GPUMemoryUtilization > 0:
                    omni_kwargs["gpu_memory_utilization"] = request.GPUMemoryUtilization
                if request.TensorParallelSize > 0:
                    omni_kwargs["tensor_parallel_size"] = request.TensorParallelSize
                if request.TrustRemoteCode:
                    omni_kwargs["trust_remote_code"] = request.TrustRemoteCode
                if request.MaxModelLen > 0:
                    omni_kwargs["max_model_len"] = request.MaxModelLen

            self.omni = Omni(**omni_kwargs)
            print("Model loaded successfully", file=sys.stderr)
            return backend_pb2.Result(message="Model loaded successfully", success=True)

        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    def GenerateImage(self, request, context):
        try:
            # Validate model is loaded and is image/diffusion type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["image"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support image generation")

            # Extract parameters
            prompt = request.positive_prompt
            negative_prompt = request.negative_prompt if request.negative_prompt else None
            width = request.width if request.width > 0 else 1024
            height = request.height if request.height > 0 else 1024
            seed = request.seed if request.seed > 0 else None
            num_inference_steps = request.step if request.step > 0 else 50
            cfg_scale = self.options.get("cfg_scale", 4.0)
            guidance_scale = self.options.get("guidance_scale", 1.0)

            # Create generator if seed provided
            generator = None
            if seed:
                device = detect_device_type()
                generator = torch.Generator(device=device).manual_seed(seed)

            # Handle image input for image editing
            pil_image = None
            if request.src or (request.ref_images and len(request.ref_images) > 0):
                image_path = request.ref_images[0] if request.ref_images else request.src
                pil_image = self._load_image(image_path)
                if pil_image is None:
                    return backend_pb2.Result(success=False, message=f"Invalid image source: {image_path}")
                pil_image = pil_image.convert("RGB")

            # Build generate kwargs
            generate_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": height,
                "width": width,
                "generator": generator,
                "true_cfg_scale": cfg_scale,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
            }
            if pil_image:
                generate_kwargs["pil_image"] = pil_image

            # Call omni.generate()
            outputs = self.omni.generate(**generate_kwargs)

            # Extract images (following example pattern)
            if not outputs or len(outputs) == 0:
                return backend_pb2.Result(success=False, message="No output generated")

            first_output = outputs[0]
            if not hasattr(first_output, "request_output") or not first_output.request_output:
                return backend_pb2.Result(success=False, message="Invalid output structure")

            req_out = first_output.request_output[0]
            if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
                return backend_pb2.Result(success=False, message="No images in output")

            images = req_out.images
            if not images or len(images) == 0:
                return backend_pb2.Result(success=False, message="Empty images list")

            # Save image
            output_image = images[0]
            output_image.save(request.dst)
            return backend_pb2.Result(message="Image generated successfully", success=True)

        except Exception as err:
            print(f"Error generating image: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating image: {err}")

    def GenerateVideo(self, request, context):
        try:
            # Validate model is loaded and is video/diffusion type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["video"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support video generation")

            # Extract parameters
            prompt = request.prompt
            negative_prompt = request.negative_prompt if request.negative_prompt else ""
            width = request.width if request.width > 0 else 1280
            height = request.height if request.height > 0 else 720
            num_frames = request.num_frames if request.num_frames > 0 else 81
            fps = request.fps if request.fps > 0 else 24
            seed = request.seed if request.seed > 0 else None
            guidance_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
            guidance_scale_high = self.options.get("guidance_scale_high")
            num_inference_steps = request.step if request.step > 0 else 40

            # Create generator
            generator = None
            if seed:
                device = detect_device_type()
                generator = torch.Generator(device=device).manual_seed(seed)

            # Handle image input for image-to-video
            pil_image = None
            if request.start_image:
                pil_image = self._load_image(request.start_image)
                if pil_image is None:
                    return backend_pb2.Result(success=False, message=f"Invalid start_image: {request.start_image}")
                pil_image = pil_image.convert("RGB")
                # Resize to target dimensions
                pil_image = pil_image.resize((width, height), Image.Resampling.LANCZOS)

            # Build generate kwargs
            generate_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": height,
                "width": width,
                "generator": generator,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_frames": num_frames,
            }
            if pil_image:
                generate_kwargs["pil_image"] = pil_image
            if guidance_scale_high:
                generate_kwargs["guidance_scale_2"] = guidance_scale_high

            # Call omni.generate()
            frames = self.omni.generate(**generate_kwargs)

            # Extract video frames (following example pattern)
            if isinstance(frames, list) and len(frames) > 0:
                first_item = frames[0]

                if hasattr(first_item, "final_output_type"):
                    if first_item.final_output_type != "image":
                        return backend_pb2.Result(success=False, message=f"Unexpected output type: {first_item.final_output_type}")

                # Pipeline mode: extract from nested request_output
                if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output:
                    if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0:
                        inner_output = first_item.request_output[0]
                        if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"):
                            frames = inner_output.images[0] if inner_output.images else None
                # Diffusion mode: use direct images field
                elif hasattr(first_item, "images") and first_item.images:
                    frames = first_item.images
                else:
                    return backend_pb2.Result(success=False, message="No video frames found")

            if frames is None:
                return backend_pb2.Result(success=False, message="No video frames found in output")

            # Convert frames to numpy array (following example)
            if isinstance(frames, torch.Tensor):
                video_tensor = frames.detach().cpu()
                # Handle different tensor shapes [B, C, F, H, W] or [B, F, H, W, C]
                if video_tensor.dim() == 5:
                    if video_tensor.shape[1] in (3, 4):
                        video_tensor = video_tensor[0].permute(1, 2, 3, 0)
                    else:
                        video_tensor = video_tensor[0]
                elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4):
                    video_tensor = video_tensor.permute(1, 2, 3, 0)
                # Normalize from [-1,1] to [0,1] if float
                if video_tensor.is_floating_point():
                    video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5
                video_array = video_tensor.float().numpy()
            else:
                video_array = frames
                if hasattr(video_array, "shape") and video_array.ndim == 5:
                    video_array = video_array[0]

            # Convert 4D array (frames, H, W, C) to list of frames
            if isinstance(video_array, np.ndarray) and video_array.ndim == 4:
                video_array = list(video_array)

            # Save video
            export_to_video(video_array, request.dst, fps=fps)
            return backend_pb2.Result(message="Video generated successfully", success=True)

        except Exception as err:
            print(f"Error generating video: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")

    def Predict(self, request, context):
        """Non-streaming text generation with multimodal inputs."""
        gen = self._predict(request, context, streaming=False)
        try:
            res = next(gen)
            return res
        except StopIteration:
            return backend_pb2.Reply(message=bytes("", 'utf-8'))

    def PredictStream(self, request, context):
        """Streaming text generation with multimodal inputs."""
        return self._predict(request, context, streaming=True)

    def _predict(self, request, context, streaming=False):
        """Internal method for text generation (streaming and non-streaming)."""
        try:
            # Validate model is loaded and is LLM type
            if not hasattr(self, 'omni'):
                yield backend_pb2.Reply(message=bytes("Model not loaded. Call LoadModel first.", 'utf-8'))
                return
            if self.model_type not in ["llm"]:
                yield backend_pb2.Reply(message=bytes(f"Model type {self.model_type} does not support text generation", 'utf-8'))
                return

            # Extract prompt
            if request.Prompt:
                prompt = request.Prompt
            elif request.Messages and request.UseTokenizerTemplate:
                # Build prompt from messages (simplified - would need tokenizer for full template)
                prompt = ""
                for msg in request.Messages:
                    role = msg.role
                    content = msg.content
                    prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
                prompt += "<|im_start|>assistant\n"
            else:
                yield backend_pb2.Reply(message=bytes("", 'utf-8'))
                return

            # Build multi_modal_data dict
            multi_modal_data = {}

            # Process images
            if request.Images:
                image_data = []
                for img_path in request.Images:
                    img = self._load_image(img_path)
                    if img:
                        # Convert to format expected by vllm
                        from vllm.multimodal.image import convert_image_mode
                        img_data = convert_image_mode(img, "RGB")
                        image_data.append(img_data)
                if image_data:
                    multi_modal_data["image"] = image_data

            # Process videos
            if request.Videos:
                video_data = []
                for video_path in request.Videos:
                    video = self._load_video(video_path)
                    if video is not None:
                        video_data.append(video)
                if video_data:
                    multi_modal_data["video"] = video_data

            # Process audio
            if request.Audios:
                audio_data = []
                for audio_path in request.Audios:
                    audio = self._load_audio(audio_path)
                    if audio is not None:
                        audio_data.append(audio)
                if audio_data:
                    multi_modal_data["audio"] = audio_data

            # Build inputs dict
            inputs = {
                "prompt": prompt,
                "multi_modal_data": multi_modal_data if multi_modal_data else None,
            }

            # Build sampling params
            sampling_params = SamplingParams(
                temperature=request.Temperature if request.Temperature > 0 else 0.7,
                top_p=request.TopP if request.TopP > 0 else 0.9,
                top_k=request.TopK if request.TopK > 0 else -1,
                max_tokens=request.Tokens if request.Tokens > 0 else 200,
                presence_penalty=request.PresencePenalty if request.PresencePenalty != 0 else 0.0,
                frequency_penalty=request.FrequencyPenalty if request.FrequencyPenalty != 0 else 0.0,
                repetition_penalty=request.RepetitionPenalty if request.RepetitionPenalty != 0 else 1.0,
                seed=request.Seed if request.Seed > 0 else None,
                stop=request.StopPrompts if request.StopPrompts else None,
                stop_token_ids=request.StopTokenIds if request.StopTokenIds else None,
                ignore_eos=request.IgnoreEOS,
            )
            sampling_params_list = [sampling_params]

            # Call omni.generate() (returns generator for LLM mode)
            omni_generator = self.omni.generate([inputs], sampling_params_list)

            # Extract text from outputs
            generated_text = ""
            for stage_outputs in omni_generator:
                if stage_outputs.final_output_type == "text":
                    for output in stage_outputs.request_output:
                        text_output = output.outputs[0].text
                        if streaming:
                            # Remove already sent text (vllm concatenates)
                            delta_text = text_output.removeprefix(generated_text)
                            yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8'))
                        generated_text = text_output

            if not streaming:
                yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

        except Exception as err:
            print(f"Error in Predict: {err}", file=sys.stderr)
            traceback.print_exc()
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", encoding='utf-8'))

    def TTS(self, request, context):
        try:
            # Validate model is loaded and is TTS type
            if not hasattr(self, 'omni'):
                return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.")
            if self.model_type not in ["tts"]:
                return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support TTS")

            # Extract parameters
            text = request.text
            language = request.language if request.language else "Auto"
            voice = request.voice if request.voice else None
            task_type = self._detect_tts_task_type()

            # Build prompt with chat template
            # TODO: for now vllm-omni supports only qwen3-tts, so we hardcode it, however, we want to support other models in the future,
            # and we might need to use the chat template here
            prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"

            # Build inputs dict
            inputs = {
                "prompt": prompt,
                "additional_information": {
                    "task_type": [task_type],
                    "text": [text],
                    "language": [language],
                    "max_new_tokens": [2048],
                }
            }

            # Add task-specific fields
            if task_type == "CustomVoice":
                if voice:
                    inputs["additional_information"]["speaker"] = [voice]
                # Add instruct if provided in options
                if "instruct" in self.options:
                    inputs["additional_information"]["instruct"] = [self.options["instruct"]]
            elif task_type == "VoiceDesign":
                if "instruct" in self.options:
                    inputs["additional_information"]["instruct"] = [self.options["instruct"]]
                inputs["additional_information"]["non_streaming_mode"] = [True]
            elif task_type == "Base":
                # Voice cloning requires ref_audio and ref_text
                if "ref_audio" in self.options:
                    inputs["additional_information"]["ref_audio"] = [self.options["ref_audio"]]
                if "ref_text" in self.options:
                    inputs["additional_information"]["ref_text"] = [self.options["ref_text"]]
                if "x_vector_only_mode" in self.options:
                    inputs["additional_information"]["x_vector_only_mode"] = [self.options["x_vector_only_mode"]]

            # Build sampling params
            sampling_params = SamplingParams(
                temperature=0.9,
                top_p=1.0,
                top_k=50,
                max_tokens=2048,
                seed=42,
                detokenize=False,
                repetition_penalty=1.05,
            )
            sampling_params_list = [sampling_params]

            # Call omni.generate()
            omni_generator = self.omni.generate(inputs, sampling_params_list)

            # Extract audio (following TTS example)
            for stage_outputs in omni_generator:
                for output in stage_outputs.request_output:
                    if "audio" in output.multimodal_output:
                        audio_tensor = output.multimodal_output["audio"]
                        audio_samplerate = output.multimodal_output["sr"].item()

                        # Convert to numpy
                        audio_numpy = audio_tensor.float().detach().cpu().numpy()
                        if audio_numpy.ndim > 1:
                            audio_numpy = audio_numpy.flatten()

                        # Save audio file
                        sf.write(request.dst, audio_numpy, samplerate=audio_samplerate, format="WAV")
                        return backend_pb2.Result(message="TTS audio generated successfully", success=True)

            return backend_pb2.Result(success=False, message="No audio output generated")

        except Exception as err:
            print(f"Error generating TTS: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}")

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Signal handlers for graceful shutdown
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
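As a usage illustration, here is a minimal gRPC client sketch against this backend, mirroring the calls made by the test suite further below. It assumes backend.py is already running on localhost:50051 and that the backend_pb2 protobuf stubs have been generated; it is a sketch, not part of the commit.

# Minimal client sketch for the vllm-omni backend (assumes a running backend
# on localhost:50051 and generated backend_pb2 / backend_pb2_grpc stubs).
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    # Health check should answer b'OK'
    print(stub.Health(backend_pb2.HealthMessage()).message)
    # Load an image model, then request a single image
    stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo"))
    result = stub.GenerateImage(backend_pb2.GenerateImageRequest(
        positive_prompt="a cup of coffee on the table",
        dst="/tmp/out.png",
        width=512,
        height=512,
        step=20,
        seed=42,
    ))
    print(result.success, result.message)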
backend/python/vllm-omni/install.sh: 62 lines (new executable file)
@@ -0,0 +1,62 @@
#!/bin/bash
set -e

PYTHON_VERSION="3.12"
PYTHON_PATCH="12"
PY_STANDALONE_TAG="20251120"

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# Handle l4t build profiles (Python 3.12, pip fallback) if needed
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
    PYTHON_VERSION="3.12"
    PYTHON_PATCH="12"
    PY_STANDALONE_TAG="20251120"
fi

if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
    USE_PIP=true
fi

# Install base requirements first
installRequirements

# Install vllm based on build type
if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
    # ROCm
    if [ "x${USE_PIP}" == "xtrue" ]; then
        pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
    else
        uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700
    fi
elif [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "x" ]; then
    # CUDA (default) or CPU
    if [ "x${USE_PIP}" == "xtrue" ]; then
        pip install vllm==0.14.0 --torch-backend=auto
    else
        uv pip install vllm==0.14.0 --torch-backend=auto
    fi
else
    echo "Unsupported build type: ${BUILD_TYPE}" >&2
    exit 1
fi

# Clone and install vllm-omni from source
if [ ! -d vllm-omni ]; then
    git clone https://github.com/vllm-project/vllm-omni.git
fi

cd vllm-omni/

if [ "x${USE_PIP}" == "xtrue" ]; then
    pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e .
else
    uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e .
fi

cd ..
backend/python/vllm-omni/requirements-after.txt: 2 lines (new file)
@@ -0,0 +1,2 @@
diffusers
librosa
backend/python/vllm-omni/requirements-cublas12-after.txt: 1 line (new file)
@@ -0,0 +1 @@
https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
backend/python/vllm-omni/requirements-cublas12.txt: 4 lines (new file)
@@ -0,0 +1,4 @@
accelerate
torch==2.7.0
transformers
bitsandbytes
backend/python/vllm-omni/requirements-hipblas.txt: 5 lines (new file)
@@ -0,0 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.4
accelerate
torch
transformers
bitsandbytes
backend/python/vllm-omni/requirements.txt: 7 lines (new file)
@@ -0,0 +1,7 @@
grpcio==1.76.0
protobuf
certifi
setuptools
pillow
numpy
soundfile
backend/python/vllm-omni/run.sh: 11 lines (new executable file)
@@ -0,0 +1,11 @@
#!/bin/bash

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
backend/python/vllm-omni/test.py: 82 lines (new file)
@@ -0,0 +1,82 @@
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc

class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    This class contains methods to test the startup and shutdown of the gRPC service.
    """
    def setUp(self):
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Use a small image generation model for testing
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_generate_image(self):
        """
        This method tests if image generation works
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo"))
                self.assertTrue(response.success)

                req = backend_pb2.GenerateImageRequest(
                    positive_prompt="a cup of coffee on the table",
                    dst="/tmp/test_output.png",
                    width=512,
                    height=512,
                    step=20,
                    seed=42
                )
                resp = stub.GenerateImage(req)
                self.assertTrue(resp.success)
        except Exception as err:
            print(err)
            self.fail("GenerateImage service failed")
        finally:
            self.tearDown()

if __name__ == "__main__":
    unittest.main()
backend/python/vllm-omni/test.sh: 12 lines (new executable file)
@@ -0,0 +1,12 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests