fix: multi-GPU support for Diffusers (Issue #8575) (#8605)

* chore: init

* feat: implement multi-GPU support for Diffusers backend (fixes #8575)

---------

Co-authored-by: localai-bot <localai-bot@users.noreply.github.com>
Author: LocalAI [bot]
Date: 2026-02-19 21:35:58 +01:00 (committed by GitHub)
Parent: 76fba02e56
Commit: e555057f8b

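At a glance, the change threads a new device_map argument through every pipeline-loading path and derives it from the request's TensorParallelSize. The sketch below is illustrative only, not LocalAI code: the model ID and the standalone-script scaffolding are assumptions, and support for device_map="auto" on full pipelines varies across diffusers releases (some accept only per-model sharding or "balanced" placement), so treat it as a conceptual mirror of the patched LoadModel logic.

# Illustrative sketch of the patch's decision logic (assumptions noted above).
import torch
from diffusers import DiffusionPipeline

tensor_parallel_size = 2  # stands in for request.TensorParallelSize

# Mirrors the patched LoadModel: shard across GPUs only when asked to.
device_map = "auto" if tensor_parallel_size > 1 else None

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # hypothetical example model
    torch_dtype=torch.float16,
    device_map=device_map,  # None -> everything stays on one device
)

# Single-device path: the caller handles placement, matching the
# `device_map is None and device != "cpu"` guard added by the diff.
if device_map is None and torch.cuda.is_available():
    pipe.to("cuda")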

@@ -196,7 +196,7 @@ def get_scheduler(name: str, config: dict = {}):
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
- def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
+ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant, device_map=None):
"""
Load a diffusers pipeline dynamically using the dynamic loader.
@@ -210,6 +210,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
fromSingleFile: Whether to use from_single_file() vs from_pretrained()
torchType: The torch dtype to use
variant: Model variant (e.g., "fp16")
+ device_map: Device mapping strategy (e.g., "auto" for multi-GPU)
Returns:
The loaded pipeline instance
@@ -231,14 +232,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
dtype = torch.bfloat16
bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
- transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+ transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype, device_map=device_map)
quantize(transformer, weights=qfloat8)
freeze(transformer)
- text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+ text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, device_map=device_map)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
- pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+ pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype, device_map=device_map)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
@@ -251,13 +252,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
- torch_dtype=torch.float32
+ torch_dtype=torch.float32,
+ device_map=device_map
)
pipe = load_diffusers_pipeline(
class_name="WanPipeline",
model_id=request.Model,
vae=vae,
- torch_dtype=torchType
+ torch_dtype=torchType,
+ device_map=device_map
)
self.txt2vid = True
return pipe
@@ -267,13 +270,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
vae = AutoencoderKLWan.from_pretrained(
request.Model,
subfolder="vae",
- torch_dtype=torch.float32
+ torch_dtype=torch.float32,
+ device_map=device_map
)
pipe = load_diffusers_pipeline(
class_name="WanImageToVideoPipeline",
model_id=request.Model,
vae=vae,
- torch_dtype=torchType
+ torch_dtype=torchType,
+ device_map=device_map
)
self.img2vid = True
return pipe
@@ -284,7 +289,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
class_name="SanaPipeline",
model_id=request.Model,
variant="bf16",
- torch_dtype=torch.bfloat16
+ torch_dtype=torch.bfloat16,
+ device_map=device_map
)
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
@@ -296,7 +302,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
pipe = load_diffusers_pipeline(
class_name="DiffusionPipeline",
model_id=request.Model,
- torch_dtype=torchType
+ torch_dtype=torchType,
+ device_map=device_map
)
return pipe
@@ -307,7 +314,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
class_name="StableVideoDiffusionPipeline",
model_id=request.Model,
torch_dtype=torchType,
- variant=variant
+ variant=variant,
+ device_map=device_map
)
if not DISABLE_CPU_OFFLOAD:
pipe.enable_model_cpu_offload()
@@ -331,6 +339,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFile,
config=request.Model, # Use request.Model as the config/model_id
subfolder="transformer",
+ device_map=device_map,
**transformer_kwargs,
)
@@ -340,6 +349,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_id=request.Model,
transformer=transformer,
torch_dtype=torchType,
+ device_map=device_map,
)
else:
# Single file but not GGUF - use standard single file loading
@@ -348,6 +358,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_id=modelFile,
from_single_file=True,
torch_dtype=torchType,
+ device_map=device_map,
)
else:
# Standard loading from pretrained
@@ -355,7 +366,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
class_name="LTX2ImageToVideoPipeline",
model_id=request.Model,
torch_dtype=torchType,
- variant=variant
+ variant=variant,
+ device_map=device_map
)
if not DISABLE_CPU_OFFLOAD:
@@ -380,6 +392,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFile,
config=request.Model, # Use request.Model as the config/model_id
subfolder="transformer",
+ device_map=device_map,
**transformer_kwargs,
)
@@ -389,6 +402,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_id=request.Model,
transformer=transformer,
torch_dtype=torchType,
+ device_map=device_map,
)
else:
# Single file but not GGUF - use standard single file loading
@@ -397,6 +411,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
model_id=modelFile,
from_single_file=True,
torch_dtype=torchType,
+ device_map=device_map,
)
else:
# Standard loading from pretrained
@@ -404,7 +419,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
class_name="LTX2Pipeline",
model_id=request.Model,
torch_dtype=torchType,
- variant=variant
+ variant=variant,
+ device_map=device_map
)
if not DISABLE_CPU_OFFLOAD:
@@ -427,6 +443,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if not fromSingleFile:
load_kwargs["use_safetensors"] = SAFETENSORS
+ # Add device_map for multi-GPU support (when TensorParallelSize > 1)
+ if device_map:
+     load_kwargs["device_map"] = device_map
# Determine pipeline class name - default to AutoPipelineForText2Image
effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
@@ -529,6 +549,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
+ # Determine device_map for multi-GPU support based on TensorParallelSize
+ # When TensorParallelSize > 1, use device_map='auto' to distribute model across GPUs
+ device_map = None
+ if hasattr(request, 'TensorParallelSize') and request.TensorParallelSize > 1:
+     device_map = "auto"
+     print(f"LoadModel: Multi-GPU mode enabled with TensorParallelSize={request.TensorParallelSize}, using device_map='auto'", file=sys.stderr)
# Load pipeline using dynamic loader
# Special cases that require custom initialization are handled first
self.pipe = self._load_pipeline(
@@ -536,7 +563,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFile=modelFile,
fromSingleFile=fromSingleFile,
torchType=torchType,
- variant=variant
+ variant=variant,
+ device_map=device_map
)
print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
@@ -561,7 +589,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.ControlNet:
self.controlnet = ControlNetModel.from_pretrained(
- request.ControlNet, torch_dtype=torchType, variant=variant
+ request.ControlNet, torch_dtype=torchType, variant=variant, device_map=device_map
)
self.pipe.controlnet = self.controlnet
else:
@@ -600,7 +628,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
if device != "cpu":
# Only move pipeline to device if NOT using device_map
# device_map handles device placement automatically
if device_map is None and device != "cpu":
self.pipe.to(device)
if self.controlnet:
self.controlnet.to(device)
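The final hunk is the one that matters most in practice: calling .to(device) on an Accelerate-sharded pipeline would try to pull every shard onto a single GPU, defeating the distribution or running out of memory, which is why the guard skips it whenever device_map is set. A quick post-load sanity check, hedged since the hf_device_map attribute depends on the diffusers version and is an assumption here, not something this patch adds:

# Post-load check (assumes the `pipe` from the earlier sketch and a recent
# diffusers; `hf_device_map` may be absent on older releases).
if getattr(pipe, "hf_device_map", None):
    # e.g. {'transformer': 0, 'text_encoder': 1, 'vae': 'cpu', ...}
    print("sharded placement:", pipe.hf_device_map)
else:
    print("single-device pipeline on:", pipe.device)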