diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py
index b30f31c0c..979731e55 100755
--- a/backend/python/diffusers/backend.py
+++ b/backend/python/diffusers/backend.py
@@ -196,7 +196,7 @@ def get_scheduler(name: str, config: dict = {}):
 
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
-    def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
+    def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant, device_map=None):
         """
         Load a diffusers pipeline dynamically using the dynamic loader.
 
@@ -210,6 +210,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             fromSingleFile: Whether to use from_single_file() vs from_pretrained()
             torchType: The torch dtype to use
             variant: Model variant (e.g., "fp16")
+            device_map: Device mapping strategy (e.g., "auto" for multi-GPU)
 
         Returns:
             The loaded pipeline instance
@@ -231,14 +232,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             dtype = torch.bfloat16
             bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
 
-            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype, device_map=device_map)
             quantize(transformer, weights=qfloat8)
             freeze(transformer)
 
-            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, device_map=device_map)
             quantize(text_encoder_2, weights=qfloat8)
             freeze(text_encoder_2)
 
-            pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+            pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype, device_map=device_map)
             pipe.transformer = transformer
             pipe.text_encoder_2 = text_encoder_2
 
@@ -251,13 +252,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             vae = AutoencoderKLWan.from_pretrained(
                 request.Model,
                 subfolder="vae",
-                torch_dtype=torch.float32
+                torch_dtype=torch.float32,
+                device_map=device_map
             )
             pipe = load_diffusers_pipeline(
                 class_name="WanPipeline",
                 model_id=request.Model,
                 vae=vae,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
             self.txt2vid = True
             return pipe
@@ -267,13 +270,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             vae = AutoencoderKLWan.from_pretrained(
                 request.Model,
                 subfolder="vae",
-                torch_dtype=torch.float32
+                torch_dtype=torch.float32,
+                device_map=device_map
             )
             pipe = load_diffusers_pipeline(
                 class_name="WanImageToVideoPipeline",
                 model_id=request.Model,
                 vae=vae,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
             self.img2vid = True
             return pipe
@@ -284,7 +289,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 class_name="SanaPipeline",
                 model_id=request.Model,
                 variant="bf16",
-                torch_dtype=torch.bfloat16
+                torch_dtype=torch.bfloat16,
+                device_map=device_map
             )
             pipe.vae.to(torch.bfloat16)
             pipe.text_encoder.to(torch.bfloat16)
@@ -296,7 +302,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             pipe = load_diffusers_pipeline(
                 class_name="DiffusionPipeline",
                 model_id=request.Model,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
 
             return pipe
@@ -307,7 +314,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             pipe = load_diffusers_pipeline(
                 class_name="StableVideoDiffusionPipeline",
                 model_id=request.Model,
                 torch_dtype=torchType,
-                variant=variant
+                variant=variant,
+                device_map=device_map
             )
             if not DISABLE_CPU_OFFLOAD:
                 pipe.enable_model_cpu_offload()
@@ -331,6 +339,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         modelFile,
                         config=request.Model,  # Use request.Model as the config/model_id
                         subfolder="transformer",
+                        device_map=device_map,
                         **transformer_kwargs,
                     )
 
@@ -340,6 +349,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=request.Model,
                         transformer=transformer,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
                 else:
                     # Single file but not GGUF - use standard single file loading
@@ -348,6 +358,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=modelFile,
                         from_single_file=True,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
             else:
                 # Standard loading from pretrained
@@ -355,7 +366,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     class_name="LTX2ImageToVideoPipeline",
                     model_id=request.Model,
                     torch_dtype=torchType,
-                    variant=variant
+                    variant=variant,
+                    device_map=device_map
                 )
 
             if not DISABLE_CPU_OFFLOAD:
@@ -380,6 +392,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         modelFile,
                         config=request.Model,  # Use request.Model as the config/model_id
                         subfolder="transformer",
+                        device_map=device_map,
                         **transformer_kwargs,
                     )
 
@@ -389,6 +402,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=request.Model,
                         transformer=transformer,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
                 else:
                     # Single file but not GGUF - use standard single file loading
@@ -397,6 +411,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=modelFile,
                         from_single_file=True,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
             else:
                 # Standard loading from pretrained
@@ -404,7 +419,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     class_name="LTX2Pipeline",
                     model_id=request.Model,
                     torch_dtype=torchType,
-                    variant=variant
+                    variant=variant,
+                    device_map=device_map
                 )
 
             if not DISABLE_CPU_OFFLOAD:
@@ -427,6 +443,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         if not fromSingleFile:
             load_kwargs["use_safetensors"] = SAFETENSORS
 
+        # Add device_map for multi-GPU support (when TensorParallelSize > 1)
+        if device_map:
+            load_kwargs["device_map"] = device_map
+
         # Determine pipeline class name - default to AutoPipelineForText2Image
        effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
@@ -529,6 +549,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
 
+        # Determine device_map for multi-GPU support based on TensorParallelSize
+        # When TensorParallelSize > 1, use device_map='auto' to distribute model across GPUs
+        device_map = None
+        if hasattr(request, 'TensorParallelSize') and request.TensorParallelSize > 1:
+            device_map = "auto"
+            print(f"LoadModel: Multi-GPU mode enabled with TensorParallelSize={request.TensorParallelSize}, using device_map='auto'", file=sys.stderr)
+
         # Load pipeline using dynamic loader
         # Special cases that require custom initialization are handled first
         self.pipe = self._load_pipeline(
@@ -536,7 +563,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             modelFile=modelFile,
             fromSingleFile=fromSingleFile,
             torchType=torchType,
-            variant=variant
+            variant=variant,
+            device_map=device_map
         )
 
         print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
@@ -561,7 +589,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         if request.ControlNet:
             self.controlnet = ControlNetModel.from_pretrained(
-                request.ControlNet, torch_dtype=torchType, variant=variant
+                request.ControlNet, torch_dtype=torchType, variant=variant, device_map=device_map
            )
             self.pipe.controlnet = self.controlnet
         else:
@@ -600,7 +628,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
             self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
 
-        if device != "cpu":
+        # Only move pipeline to device if NOT using device_map
+        # device_map handles device placement automatically
+        if device_map is None and device != "cpu":
             self.pipe.to(device)
             if self.controlnet:
                 self.controlnet.to(device)