Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-19 15:40:18 -05:00)
* chore: init
* feat: implement multi-GPU support for Diffusers backend (fixes #8575)

Co-authored-by: localai-bot <localai-bot@users.noreply.github.com>
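In short, the patch threads an optional device_map parameter through every pipeline-loading path in the Diffusers backend: when the incoming gRPC request carries TensorParallelSize > 1, device_map='auto' is forwarded to the underlying from_pretrained()/from_single_file() calls so that accelerate can shard the weights across the available GPUs, and the explicit pipe.to(device) move is skipped because placement has already happened. A minimal sketch of that pattern in isolation (the model id, helper name, and dtype below are illustrative, not from the patch, and diffusers may restrict which device_map strategies a given pipeline class accepts):

    import torch
    from diffusers import DiffusionPipeline

    def load_pipeline_sketch(model_id: str, tensor_parallel_size: int = 1):
        # Mirror the patch's gate: only ask for automatic sharding when
        # more than one GPU is requested; otherwise leave device_map unset.
        device_map = "auto" if tensor_parallel_size > 1 else None

        load_kwargs = {"torch_dtype": torch.float16}
        if device_map:
            load_kwargs["device_map"] = device_map

        pipe = DiffusionPipeline.from_pretrained(model_id, **load_kwargs)

        # With device_map set, accelerate has already placed the weights;
        # only move the pipeline manually in the single-device case.
        if device_map is None and torch.cuda.is_available():
            pipe = pipe.to("cuda")
        return pipe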
@@ -196,7 +196,7 @@ def get_scheduler(name: str, config: dict = {}):
 # Implement the BackendServicer class with the service methods
 class BackendServicer(backend_pb2_grpc.BackendServicer):
 
-    def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant):
+    def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant, device_map=None):
         """
         Load a diffusers pipeline dynamically using the dynamic loader.
 
@@ -210,6 +210,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             fromSingleFile: Whether to use from_single_file() vs from_pretrained()
             torchType: The torch dtype to use
             variant: Model variant (e.g., "fp16")
+            device_map: Device mapping strategy (e.g., "auto" for multi-GPU)
 
         Returns:
             The loaded pipeline instance
@@ -231,14 +232,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             dtype = torch.bfloat16
             bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
 
-            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype, device_map=device_map)
             quantize(transformer, weights=qfloat8)
             freeze(transformer)
-            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, device_map=device_map)
             quantize(text_encoder_2, weights=qfloat8)
             freeze(text_encoder_2)
 
-            pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+            pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype, device_map=device_map)
             pipe.transformer = transformer
             pipe.text_encoder_2 = text_encoder_2
 
@@ -251,13 +252,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             vae = AutoencoderKLWan.from_pretrained(
                 request.Model,
                 subfolder="vae",
-                torch_dtype=torch.float32
+                torch_dtype=torch.float32,
+                device_map=device_map
             )
             pipe = load_diffusers_pipeline(
                 class_name="WanPipeline",
                 model_id=request.Model,
                 vae=vae,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
             self.txt2vid = True
             return pipe
@@ -267,13 +270,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             vae = AutoencoderKLWan.from_pretrained(
                 request.Model,
                 subfolder="vae",
-                torch_dtype=torch.float32
+                torch_dtype=torch.float32,
+                device_map=device_map
             )
             pipe = load_diffusers_pipeline(
                 class_name="WanImageToVideoPipeline",
                 model_id=request.Model,
                 vae=vae,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
             self.img2vid = True
             return pipe
@@ -284,7 +289,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 class_name="SanaPipeline",
                 model_id=request.Model,
                 variant="bf16",
-                torch_dtype=torch.bfloat16
+                torch_dtype=torch.bfloat16,
+                device_map=device_map
             )
             pipe.vae.to(torch.bfloat16)
             pipe.text_encoder.to(torch.bfloat16)
@@ -296,7 +302,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             pipe = load_diffusers_pipeline(
                 class_name="DiffusionPipeline",
                 model_id=request.Model,
-                torch_dtype=torchType
+                torch_dtype=torchType,
+                device_map=device_map
             )
             return pipe
 
@@ -307,7 +314,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 class_name="StableVideoDiffusionPipeline",
                 model_id=request.Model,
                 torch_dtype=torchType,
-                variant=variant
+                variant=variant,
+                device_map=device_map
             )
             if not DISABLE_CPU_OFFLOAD:
                 pipe.enable_model_cpu_offload()
@@ -331,6 +339,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         modelFile,
                         config=request.Model,  # Use request.Model as the config/model_id
                         subfolder="transformer",
+                        device_map=device_map,
                         **transformer_kwargs,
                     )
 
@@ -340,6 +349,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=request.Model,
                         transformer=transformer,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
                 else:
                     # Single file but not GGUF - use standard single file loading
@@ -348,6 +358,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=modelFile,
                         from_single_file=True,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
             else:
                 # Standard loading from pretrained
@@ -355,7 +366,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     class_name="LTX2ImageToVideoPipeline",
                     model_id=request.Model,
                     torch_dtype=torchType,
-                    variant=variant
+                    variant=variant,
+                    device_map=device_map
                 )
 
             if not DISABLE_CPU_OFFLOAD:
@@ -380,6 +392,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         modelFile,
                         config=request.Model,  # Use request.Model as the config/model_id
                         subfolder="transformer",
+                        device_map=device_map,
                         **transformer_kwargs,
                     )
 
@@ -389,6 +402,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=request.Model,
                         transformer=transformer,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
                 else:
                     # Single file but not GGUF - use standard single file loading
@@ -397,6 +411,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         model_id=modelFile,
                         from_single_file=True,
                         torch_dtype=torchType,
+                        device_map=device_map,
                     )
             else:
                 # Standard loading from pretrained
@@ -404,7 +419,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     class_name="LTX2Pipeline",
                     model_id=request.Model,
                     torch_dtype=torchType,
-                    variant=variant
+                    variant=variant,
+                    device_map=device_map
                 )
 
             if not DISABLE_CPU_OFFLOAD:
@@ -427,6 +443,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if not fromSingleFile:
            load_kwargs["use_safetensors"] = SAFETENSORS
 
+        # Add device_map for multi-GPU support (when TensorParallelSize > 1)
+        if device_map:
+            load_kwargs["device_map"] = device_map
+
         # Determine pipeline class name - default to AutoPipelineForText2Image
         effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image"
 
@@ -529,6 +549,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)
 
+        # Determine device_map for multi-GPU support based on TensorParallelSize
+        # When TensorParallelSize > 1, use device_map='auto' to distribute model across GPUs
+        device_map = None
+        if hasattr(request, 'TensorParallelSize') and request.TensorParallelSize > 1:
+            device_map = "auto"
+            print(f"LoadModel: Multi-GPU mode enabled with TensorParallelSize={request.TensorParallelSize}, using device_map='auto'", file=sys.stderr)
+
         # Load pipeline using dynamic loader
         # Special cases that require custom initialization are handled first
         self.pipe = self._load_pipeline(
@@ -536,7 +563,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             modelFile=modelFile,
             fromSingleFile=fromSingleFile,
             torchType=torchType,
-            variant=variant
+            variant=variant,
+            device_map=device_map
         )
 
         print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)
@@ -561,7 +589,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         if request.ControlNet:
             self.controlnet = ControlNetModel.from_pretrained(
-                request.ControlNet, torch_dtype=torchType, variant=variant
+                request.ControlNet, torch_dtype=torchType, variant=variant, device_map=device_map
             )
             self.pipe.controlnet = self.controlnet
         else:
@@ -600,7 +628,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
             self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
 
-        if device != "cpu":
+        # Only move pipeline to device if NOT using device_map
+        # device_map handles device placement automatically
+        if device_map is None and device != "cpu":
             self.pipe.to(device)
             if self.controlnet:
                 self.controlnet.to(device)
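A note on the last hunk: the pipe.to(device) move is now gated on device_map is None, since with a device map accelerate has already distributed the shards across devices and moving the whole pipeline afterwards would undo, or conflict with, that placement. With the default device_map=None (TensorParallelSize absent or equal to 1), the backend follows the same code paths as before.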