#!/usr/bin/env python3
"""
LocalAI Diffusers Backend

This backend provides gRPC access to diffusers pipelines with dynamic pipeline loading.
New pipelines added to diffusers become available automatically without code changes.
"""
from concurrent import futures
import traceback
import argparse
from collections import defaultdict
from enum import Enum
import signal
import sys
import time
import os

from PIL import Image
import torch

import backend_pb2
import backend_pb2_grpc

import grpc

# Import dynamic loader for pipeline discovery
from diffusers_dynamic_loader import (
    get_pipeline_registry,
    resolve_pipeline_class,
    get_available_pipelines,
    load_diffusers_pipeline,
)

# Import specific items still needed for special cases and safety checker
from diffusers import DiffusionPipeline, ControlNetModel
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKLWan
from diffusers.pipelines.stable_diffusion import safety_checker
from diffusers.utils import load_image, export_to_video
from compel import Compel, ReturnedEmbeddingsType
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel
from safetensors.torch import load_file

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL = os.environ.get("COMPEL", "0") == "1"
XPU = os.environ.get("XPU", "0") == "1"
CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
# CHUNK_SIZE, FPS and FRAMES are numeric, so convert them once here instead of
# passing raw strings to the pipelines below
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "8"))
FPS = int(os.environ.get("FPS", "7"))
DISABLE_CPU_OFFLOAD = os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
FRAMES = int(os.environ.get("FRAMES", "64"))

if XPU:
    print(torch.xpu.get_device_name(0))

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
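
# Illustrative example (not from the original source): the settings above can be
# overridden via the environment when launching the backend, e.g.
#   FPS=12 CHUNK_SIZE=16 DISABLE_CPU_OFFLOAD=1 python3 backend.py --addr localhost:50051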

# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
def sc(self, clip_input, images): return images, [False for i in images]


# Replace StableDiffusionSafetyChecker.forward so that, when called, it just returns
# the images unchanged and an array of False values (i.e. nothing is flagged as NSFW)
safety_checker.StableDiffusionSafetyChecker.forward = sc

from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UniPCMultistepScheduler,
)


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
# Credits to https://github.com/neggles
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
class DiffusionScheduler(str, Enum):
    ddim = "ddim"  # DDIM
    pndm = "pndm"  # PNDM
    heun = "heun"  # Heun
    unipc = "unipc"  # UniPC
    euler = "euler"  # Euler
    euler_a = "euler_a"  # Euler a

    lms = "lms"  # LMS
    k_lms = "k_lms"  # LMS Karras

    dpm_2 = "dpm_2"  # DPM2
    k_dpm_2 = "k_dpm_2"  # DPM2 Karras

    dpm_2_a = "dpm_2_a"  # DPM2 a
    k_dpm_2_a = "k_dpm_2_a"  # DPM2 a Karras

    dpmpp_2m = "dpmpp_2m"  # DPM++ 2M
    k_dpmpp_2m = "k_dpmpp_2m"  # DPM++ 2M Karras

    dpmpp_sde = "dpmpp_sde"  # DPM++ SDE
    k_dpmpp_sde = "k_dpmpp_sde"  # DPM++ SDE Karras

    dpmpp_2m_sde = "dpmpp_2m_sde"  # DPM++ 2M SDE
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"  # DPM++ 2M SDE Karras


def get_scheduler(name: str, config: dict = None):
    # work on a copy so we never mutate the caller's (possibly frozen) scheduler config,
    # and avoid a mutable default argument
    config = dict(config) if config else {}

    is_karras = name.startswith("k_")
    if is_karras:
        # strip the k_ prefix and add the karras sigma flag to config
        name = name[len("k_"):]
        config["use_karras_sigmas"] = True

    if name == DiffusionScheduler.ddim:
        sched_class = DDIMScheduler
    elif name == DiffusionScheduler.pndm:
        sched_class = PNDMScheduler
    elif name == DiffusionScheduler.heun:
        sched_class = HeunDiscreteScheduler
    elif name == DiffusionScheduler.unipc:
        sched_class = UniPCMultistepScheduler
    elif name == DiffusionScheduler.euler:
        sched_class = EulerDiscreteScheduler
    elif name == DiffusionScheduler.euler_a:
        sched_class = EulerAncestralDiscreteScheduler
    elif name == DiffusionScheduler.lms:
        sched_class = LMSDiscreteScheduler
    elif name == DiffusionScheduler.dpm_2:
        # Equivalent to DPM2 in K-Diffusion
        sched_class = KDPM2DiscreteScheduler
    elif name == DiffusionScheduler.dpm_2_a:
        # Equivalent to "DPM2 a" in K-Diffusion
        sched_class = KDPM2AncestralDiscreteScheduler
    elif name == DiffusionScheduler.dpmpp_2m:
        # Equivalent to "DPM++ 2M" in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "dpmsolver++"
        config["solver_order"] = 2
    elif name == DiffusionScheduler.dpmpp_sde:
        # Equivalent to "DPM++ SDE" in K-Diffusion
        sched_class = DPMSolverSinglestepScheduler
    elif name == DiffusionScheduler.dpmpp_2m_sde:
        # Equivalent to "DPM++ 2M SDE" in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "sde-dpmsolver++"
    else:
        raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")

    return sched_class.from_config(config)
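
# Illustrative usage (not part of the backend flow): with a loaded pipeline `pipe`,
#   pipe.scheduler = get_scheduler("k_dpmpp_2m", pipe.scheduler.config)
# selects DPMSolverMultistepScheduler with use_karras_sigmas=True, matching the
# "DPM++ 2M Karras" naming used by A1111-style frontends (see the mapping notes above).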


# BackendServicer implements the gRPC service methods exposed by this backend
class BackendServicer(backend_pb2_grpc.BackendServicer):
    def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
        """
        Load a diffusers pipeline dynamically using the dynamic loader.

        This method uses load_diffusers_pipeline() for most pipelines, falling back
        to explicit handling only for pipelines requiring custom initialization
        (e.g., quantization, special VAE handling).

        Args:
            request: The gRPC request containing pipeline configuration
            modelFile: Path to the model file (for single file loading)
            fromSingleFile: Whether to use from_single_file() vs from_pretrained()
            torchType: The torch dtype to use
            variant: Model variant (e.g., "fp16")

        Returns:
            The loaded pipeline instance
        """
        pipeline_type = request.PipelineType

        # Handle IMG2IMG request flag with default pipeline
        if request.IMG2IMG and pipeline_type == "":
            pipeline_type = "StableDiffusionImg2ImgPipeline"

        # ================================================================
        # Special cases requiring custom initialization logic
        # Only handle pipelines that truly need custom code (quantization,
        # special VAE handling, etc.). All other pipelines use dynamic loading.
        # ================================================================

        # FluxTransformer2DModel - requires quantization and custom transformer loading
        if pipeline_type == "FluxTransformer2DModel":
            dtype = torch.bfloat16
            bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")

            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
            quantize(transformer, weights=qfloat8)
            freeze(transformer)
            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
            quantize(text_encoder_2, weights=qfloat8)
            freeze(text_encoder_2)

            pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
            pipe.transformer = transformer
            pipe.text_encoder_2 = text_encoder_2

            if request.LowVRAM:
                pipe.enable_model_cpu_offload()
            return pipe

        # WanPipeline - requires special VAE with float32 dtype
        if pipeline_type == "WanPipeline":
            vae = AutoencoderKLWan.from_pretrained(
                request.Model,
                subfolder="vae",
                torch_dtype=torch.float32
            )
            pipe = load_diffusers_pipeline(
                class_name="WanPipeline",
                model_id=request.Model,
                vae=vae,
                torch_dtype=torchType
            )
            self.txt2vid = True
            return pipe

        # WanImageToVideoPipeline - requires special VAE with float32 dtype
        if pipeline_type == "WanImageToVideoPipeline":
            vae = AutoencoderKLWan.from_pretrained(
                request.Model,
                subfolder="vae",
                torch_dtype=torch.float32
            )
            pipe = load_diffusers_pipeline(
                class_name="WanImageToVideoPipeline",
                model_id=request.Model,
                vae=vae,
                torch_dtype=torchType
            )
            self.img2vid = True
            return pipe

        # SanaPipeline - requires special VAE and text encoder dtype conversion
        if pipeline_type == "SanaPipeline":
            pipe = load_diffusers_pipeline(
                class_name="SanaPipeline",
                model_id=request.Model,
                variant="bf16",
                torch_dtype=torch.bfloat16
            )
            pipe.vae.to(torch.bfloat16)
            pipe.text_encoder.to(torch.bfloat16)
            return pipe

        # VideoDiffusionPipeline - alias for DiffusionPipeline with txt2vid flag
        if pipeline_type == "VideoDiffusionPipeline":
            self.txt2vid = True
            pipe = load_diffusers_pipeline(
                class_name="DiffusionPipeline",
                model_id=request.Model,
                torch_dtype=torchType
            )
            return pipe

        # StableVideoDiffusionPipeline - needs img2vid flag and CPU offload
        if pipeline_type == "StableVideoDiffusionPipeline":
            self.img2vid = True
            pipe = load_diffusers_pipeline(
                class_name="StableVideoDiffusionPipeline",
                model_id=request.Model,
                torch_dtype=torchType,
                variant=variant
            )
            if not DISABLE_CPU_OFFLOAD:
                pipe.enable_model_cpu_offload()
            return pipe

        # ================================================================
        # Dynamic pipeline loading - the default path for most pipelines
        # Uses the dynamic loader to instantiate any pipeline by class name
        # ================================================================

        # Build kwargs for dynamic loading
        load_kwargs = {"torch_dtype": torchType}

        # Add variant if not loading from single file
        if not fromSingleFile and variant:
            load_kwargs["variant"] = variant

        # Add use_safetensors for from_pretrained
        if not fromSingleFile:
            load_kwargs["use_safetensors"] = SAFETENSORS

        # Determine pipeline class name - default to AutoPipelineForText2Image
        effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
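        # Illustrative example (hypothetical request values): a request with
        # PipelineType "StableDiffusionXLPipeline" and a hub model id ends up here as
        #   load_diffusers_pipeline(class_name="StableDiffusionXLPipeline",
        #                           model_id=request.Model, from_single_file=False,
        #                           torch_dtype=torchType, variant=variant,
        #                           use_safetensors=SAFETENSORS)
        # while an empty PipelineType falls back to AutoPipelineForText2Image.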

        # Use dynamic loader for all pipelines
        try:
            pipe = load_diffusers_pipeline(
                class_name=effective_pipeline_type,
                model_id=modelFile if fromSingleFile else request.Model,
                from_single_file=fromSingleFile,
                **load_kwargs
            )
        except Exception as e:
            # Provide helpful error with available pipelines
            available = get_available_pipelines()
            raise ValueError(
                f"Failed to load pipeline '{effective_pipeline_type}': {e}\n"
                f"Available pipelines: {', '.join(available[:30])}..."
            ) from e

        # Apply LowVRAM optimization if supported and requested
        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'):
            pipe.enable_model_cpu_offload()

        return pipe

    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        try:
            print(f"Loading model {request.Model}...", file=sys.stderr)
            print(f"Request {request}", file=sys.stderr)
            torchType = torch.float32
            variant = None

            if request.F16Memory:
                torchType = torch.float16
                variant = "fp16"

            options = request.Options

            # empty dict
            self.options = {}

            # The options are a list of strings in the form "optname:optvalue".
            # We store all of them in a dict so they can be reused later when
            # generating images
            for opt in options:
                if ":" not in opt:
                    continue
                # split only on the first ":" so values may themselves contain colons
                key, value = opt.split(":", 1)
                # if value is a number, convert it to the appropriate type;
                # check int before float so integer-valued options stay integers
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"
                self.options[key] = value
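
            # For example, an Options entry "guidance_scale:7.5" is stored as the float 7.5,
            # "num_frames:81" as the int 81, and "low_cpu_mem_usage:true" as the boolean True
            # (these option names are only illustrative); everything else stays a plain string.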

            # If "torch_dtype" is present in the options, map it to the corresponding torch dtype
            if "torch_dtype" in self.options:
                if self.options["torch_dtype"] == "fp16":
                    torchType = torch.float16
                elif self.options["torch_dtype"] == "bf16":
                    torchType = torch.bfloat16
                elif self.options["torch_dtype"] == "fp32":
                    torchType = torch.float32
                # remove it from options so it is not forwarded to the pipeline call
                del self.options["torch_dtype"]

            print(f"Options: {self.options}", file=sys.stderr)

            local = False
            modelFile = request.Model

            self.cfg_scale = 7
            self.PipelineType = request.PipelineType

            if request.CFGScale != 0:
                self.cfg_scale = request.CFGScale

            clipmodel = "Lykon/dreamshaper-8"
            if request.CLIPModel != "":
                clipmodel = request.CLIPModel
            clipsubfolder = "text_encoder"
            if request.CLIPSubfolder != "":
                clipsubfolder = request.CLIPSubfolder

            # Check if ModelFile exists
            if request.ModelFile != "":
                if os.path.exists(request.ModelFile):
                    local = True
                    modelFile = request.ModelFile

            fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
            self.img2vid = False
            self.txt2vid = False

            # Load pipeline using dynamic loader
            # Special cases that require custom initialization are handled first
            self.pipe = self._load_pipeline(
                request=request,
                modelFile=modelFile,
                fromSingleFile=fromSingleFile,
                torchType=torchType,
                variant=variant
            )

            if CLIPSKIP and request.CLIPSkip != 0:
                self.clip_skip = request.CLIPSkip
            else:
                self.clip_skip = 0

            # TODO: torch_dtype should be chosen per device (float16 on GPU, float32 on CPU)
            if request.SchedulerType != "":
                self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)

            if COMPEL:
                self.compel = Compel(
                    tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
                    text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
                    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                    requires_pooled=[False, True]
                )

            if request.ControlNet:
                self.controlnet = ControlNetModel.from_pretrained(
                    request.ControlNet, torch_dtype=torchType, variant=variant
                )
                self.pipe.controlnet = self.controlnet
            else:
                self.controlnet = None

            if request.LoraAdapter and not os.path.isabs(request.LoraAdapter):
                # make the LoRA adapter path absolute by joining it with the model path
                request.LoraAdapter = os.path.join(request.ModelPath, request.LoraAdapter)

            device = "cpu" if not request.CUDA else "cuda"
            if XPU:
                device = "xpu"
            mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
            if mps_available:
                device = "mps"
            self.device = device
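            # Device selection above prefers CUDA when requested, then Intel XPU, with
            # Apple MPS taking precedence whenever it is available; otherwise the CPU is used.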
            if request.LoraAdapter:
                # Check if it's a local file and not a directory (we load LoRA differently for a safetensors file)
                if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
                    self.pipe.load_lora_weights(request.LoraAdapter)
                else:
                    self.pipe.unet.load_attn_procs(request.LoraAdapter)
            if len(request.LoraAdapters) > 0:
                i = 0
                adapters_name = []
                adapters_weights = []
                for adapter in request.LoraAdapters:
                    if not os.path.isabs(adapter):
                        adapter = os.path.join(request.ModelPath, adapter)
                    self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
                    adapters_name.append(f"adapter_{i}")
                    i += 1

                for adapters_weight in request.LoraScales:
                    adapters_weights.append(adapters_weight)

                self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)

            if device != "cpu":
                self.pipe.to(device)
                if self.controlnet:
                    self.controlnet.to(device)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    # https://github.com/huggingface/diffusers/issues/3064
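    # The helper below manually merges LoRA weights from a .safetensors state dict into the
    # text encoder / UNet: for every targeted layer it applies
    #   W += multiplier * alpha * (lora_up @ lora_down)
    # (with extra squeeze/unsqueeze handling for conv weights). Note that LoadModel above
    # uses diffusers' own load_lora_weights()/load_attn_procs(); this helper is a lower-level
    # alternative kept from the linked issue.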
    def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
        LORA_PREFIX_UNET = "lora_unet"
        LORA_PREFIX_TEXT_ENCODER = "lora_te"
        # load LoRA weight from .safetensors
        state_dict = load_file(checkpoint_path, device=device)

        updates = defaultdict(dict)
        for key, value in state_dict.items():
            # keys usually look like
            # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
            layer, elem = key.split('.', 1)
            updates[layer][elem] = value

        # directly update weight in diffusers model
        for layer, elems in updates.items():

            if "text" in layer:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                curr_layer = self.pipe.text_encoder
            else:
                layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
                curr_layer = self.pipe.unet

            # find the target layer by walking attribute names, re-joining name parts
            # with "_" when a lookup fails
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if len(layer_infos) > 0:
                        temp_name = layer_infos.pop(0)
                    elif len(layer_infos) == 0:
                        break
                except Exception:
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)

            # get elements for this layer
            weight_up = elems['lora_up.weight'].to(dtype)
            weight_down = elems['lora_down.weight'].to(dtype)
            alpha = elems['alpha'] if 'alpha' in elems else None
            if alpha:
                alpha = alpha.item() / weight_up.shape[1]
            else:
                alpha = 1.0

            # update weight
            if len(weight_up.shape) == 4:
                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            else:
                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

    def GenerateImage(self, request, context):

        prompt = request.positive_prompt

        steps = 1

        if request.step != 0:
            steps = request.step

        # create a dictionary of values for the parameters
        options = {
            "num_inference_steps": steps,
        }

        if hasattr(request, 'negative_prompt') and request.negative_prompt != "":
            options["negative_prompt"] = request.negative_prompt

        # Handle image source: prioritize RefImages over request.src
        image_src = None
        if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
            # Use the first reference image if available
            image_src = request.ref_images[0]
            print(f"Using reference image: {image_src}", file=sys.stderr)
        elif request.src != "":
            # Fall back to request.src if no ref_images
            image_src = request.src
            print(f"Using source image: {image_src}", file=sys.stderr)
        else:
            print("No image source provided", file=sys.stderr)

        if image_src and not self.controlnet and not self.img2vid:
            image = Image.open(image_src)
            options["image"] = image
        elif self.controlnet and image_src:
            pose_image = load_image(image_src)
            options["image"] = pose_image

        if CLIPSKIP and self.clip_skip != 0:
            options["clip_skip"] = self.clip_skip

        # start from the request-derived options computed above...
        kwargs = {key: value for key, value in options.items() if value is not None}

        # ...then populate kwargs from self.options so user-provided options take precedence
        kwargs.update(self.options)

        # Set seed
        if request.seed > 0:
            kwargs["generator"] = torch.Generator(device=self.device).manual_seed(
                request.seed
            )

        if self.PipelineType == "FluxPipeline":
            kwargs["max_sequence_length"] = 256

        if request.width:
            kwargs["width"] = request.width

        if request.height:
            kwargs["height"] = request.height

        if self.PipelineType == "FluxTransformer2DModel":
            kwargs["output_type"] = "pil"
            kwargs["generator"] = torch.Generator("cpu").manual_seed(0)

        if self.img2vid:
            # Load the conditioning image
            if image_src:
                image = load_image(image_src)
            else:
                # Fallback to request.src for img2vid if no ref_images
                image = load_image(request.src)
            image = image.resize((1024, 576))

            generator = torch.manual_seed(request.seed)
            frames = self.pipe(image, guidance_scale=self.cfg_scale, decode_chunk_size=CHUNK_SIZE, generator=generator).frames[0]
            export_to_video(frames, request.dst, fps=FPS)
            return backend_pb2.Result(message="Media generated successfully", success=True)

        if self.txt2vid:
            video_frames = self.pipe(prompt, guidance_scale=self.cfg_scale, num_inference_steps=steps, num_frames=int(FRAMES)).frames
            export_to_video(video_frames, request.dst)
            return backend_pb2.Result(message="Media generated successfully", success=True)

        print(f"Generating image with {kwargs=}", file=sys.stderr)
        image = None
        if COMPEL:
            conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
            kwargs["prompt_embeds"] = conditioning
            kwargs["pooled_prompt_embeds"] = pooled
            # pass the kwargs dictionary to the self.pipe method
            image = self.pipe(
                guidance_scale=self.cfg_scale,
                **kwargs
            ).images[0]
        else:
            # pass the kwargs dictionary to the self.pipe method
            image = self.pipe(
                prompt,
                guidance_scale=self.cfg_scale,
                **kwargs
            ).images[0]

        # save the result
        image.save(request.dst)

        return backend_pb2.Result(message="Media generated", success=True)

    def GenerateVideo(self, request, context):
        try:
            prompt = request.prompt
            if not prompt:
                return backend_pb2.Result(success=False, message="No prompt provided for video generation")

            # Set default values from request or use defaults
            num_frames = request.num_frames if request.num_frames > 0 else 81
            fps = request.fps if request.fps > 0 else 16
            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
            num_inference_steps = request.step if request.step > 0 else 40

            # Prepare generation parameters
            kwargs = {
                "prompt": prompt,
                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
                "height": request.height if request.height > 0 else 720,
                "width": request.width if request.width > 0 else 1280,
                "num_frames": num_frames,
                "guidance_scale": cfg_scale,
                "num_inference_steps": num_inference_steps,
            }

            # Add custom options from self.options (including guidance_scale_2 if specified)
            kwargs.update(self.options)

            # Set seed if provided
            if request.seed > 0:
                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)

            # Handle start and end images for video generation
            if request.start_image:
                kwargs["start_image"] = load_image(request.start_image)
            if request.end_image:
                kwargs["end_image"] = load_image(request.end_image)

            print(f"Generating video with {kwargs=}", file=sys.stderr)

            # Generate video frames based on pipeline type
            if self.PipelineType == "WanPipeline":
                # WAN2.2 text-to-video generation
                output = self.pipe(**kwargs)
                frames = output.frames[0]  # WAN2.2 returns frames in this format
            elif self.PipelineType == "WanImageToVideoPipeline":
                # WAN2.2 image-to-video generation
                if request.start_image:
                    # Load and resize the input image according to WAN2.2 requirements
                    image = load_image(request.start_image)
                    # Use request dimensions or defaults, but respect WAN2.2 constraints
                    request_height = request.height if request.height > 0 else 480
                    request_width = request.width if request.width > 0 else 832
                    max_area = request_height * request_width
                    aspect_ratio = image.height / image.width
                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
                    image = image.resize((width, height))
                    kwargs["image"] = image
                    kwargs["height"] = height
                    kwargs["width"] = width

                output = self.pipe(**kwargs)
                frames = output.frames[0]
            elif self.img2vid:
                # Generic image-to-video generation
                if request.start_image:
                    image = load_image(request.start_image)
                    image = image.resize((request.width if request.width > 0 else 1024,
                                          request.height if request.height > 0 else 576))
                    kwargs["image"] = image

                output = self.pipe(**kwargs)
                frames = output.frames[0]
            elif self.txt2vid:
                # Generic text-to-video generation
                output = self.pipe(**kwargs)
                frames = output.frames[0]
            else:
                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")

            # Export video
            export_to_video(frames, request.dst, fps=fps)

            return backend_pb2.Result(message="Video generated successfully", success=True)

        except Exception as err:
            print(f"Error generating video: {err}", file=sys.stderr)
            traceback.print_exc()
            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
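    # Example invocation (illustrative): python3 backend.py --addr 0.0.0.0:50051
    # LocalAI normally starts this process itself and passes the bind address it expects.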