Files
LocalAI/backend/python/mlx-distributed/sharding.py
Ettore Di Giacinto a026277ab9 feat(mlx-distributed): add new MLX-distributed backend (#8801)
* feat(mlx-distributed): add new MLX-distributed backend

Add new MLX distributed backend with support for both TCP and RDMA for
model sharding.

This implementation ties in the discovery implementation already in
place, and re-uses the same P2P mechanism for the TCP MLX-distributed
inferencing.

The auto-parallel implementation is inspired by Exo's
(whose authors have been added to the acknowledgements for their great work!)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* expose a CLI to facilitate backend starting

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: make manual rank0 configurable via model configs

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add missing features from mlx backend

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Apply suggestion from @mudler

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-09 17:29:32 +01:00

137 lines
4.5 KiB
Python

"""
Auto-parallelism for MLX distributed inference.
Provides pipeline parallelism (Ring backend) by wrapping model layers with
distributed send/recv operations. Ported from exo's auto_parallel.py with
simplifications for LocalAI's use case.
"""
import mlx.core as mx
import mlx.nn as nn
class PipelineFirstLayer(nn.Module):
    """Wraps the first layer on each rank to receive from the previous rank.

    Transparent proxy: attribute access and calls not handled here fall
    through to the wrapped layer, so callers can treat this like the
    original layer.
    """

    def __init__(self, original_layer, rank, group):
        super().__init__()
        # Store the wrapped layer via dict.__setitem__ rather than normal
        # attribute assignment, bypassing nn.Module's attribute handling.
        # NOTE(review): presumably this keeps the wrapped layer out of the
        # module's registered children/parameters — confirm against MLX's
        # dict-backed nn.Module semantics.
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank    # this process's rank in the pipeline group
        self.group = group  # distributed group used for recv

    @property
    def original_layer(self):
        # Read back through dict lookup to match how it was stored.
        return self["_original_layer"]

    def __getattr__(self, name):
        # First try normal nn.Module attribute resolution; on failure,
        # delegate to the wrapped layer so this wrapper stays transparent.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        # Rank 0 is the pipeline head and receives nothing; every other
        # rank replaces its input with the activation sent by rank - 1.
        if self.rank != 0:
            # Force evaluation before and after the recv — the explicit
            # mx.eval calls pin the ordering of the lazy graph around the
            # communication op.
            mx.eval(x)
            # recv_like uses x only as a shape/dtype template for the
            # incoming tensor.
            x = mx.distributed.recv_like(x, self.rank - 1, group=self.group)
            mx.eval(x)
        return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(nn.Module):
    """Wraps the last layer on each rank to send to the next rank.

    After the wrapped layer runs, non-final ranks forward their output to
    rank + 1; then an all_gather distributes the final rank's result so
    every rank returns the same output.
    """

    def __init__(self, original_layer, rank, world_size, group):
        super().__init__()
        # Stored via dict.__setitem__, bypassing nn.Module attribute
        # handling. NOTE(review): presumably keeps the layer out of the
        # module's registered children/parameters — confirm against MLX's
        # dict-backed nn.Module semantics.
        dict.__setitem__(self, "_original_layer", original_layer)
        self.rank = rank              # this process's rank
        self.world_size = world_size  # total number of pipeline ranks
        self.group = group            # distributed group for send/all_gather

    @property
    def original_layer(self):
        return self["_original_layer"]

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped layer so the wrapper
        # is transparent to callers.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self["_original_layer"], name)

    def __call__(self, x, *args, **kwargs):
        output = self.original_layer(x, *args, **kwargs)
        # Force evaluation so the send below operates on a materialized
        # tensor and the lazy-graph ordering is pinned.
        mx.eval(output)
        if self.rank != self.world_size - 1:
            # Forward this rank's activation to the next pipeline stage.
            # The modulo is a no-op here (rank + 1 < world_size in this
            # branch) but mirrors a ring-addressing convention.
            output = mx.distributed.send(
                output, (self.rank + 1) % self.world_size, group=self.group
            )
            mx.eval(output)
        # Gather output from all ranks so every rank has the final result.
        # all_gather concatenates along axis 0 in rank order; slicing the
        # last output.shape[0] rows selects the contribution of the final
        # rank, i.e. the completed pipeline output.
        output = mx.distributed.all_gather(output, group=self.group)[
            -output.shape[0] :
        ]
        mx.eval(output)
        return output
def get_inner_model(model):
    """Get the inner model (model.model or model.transformer).

    Probes the common MLX wrapper attributes in order and, when the match
    itself wraps another module (e.g. language_model.model), returns the
    innermost one.
    """
    for candidate in ("model", "transformer"):
        wrapped = getattr(model, candidate, None)
        if not isinstance(wrapped, nn.Module):
            continue
        # Some models nest one level deeper (e.g. language_model.model).
        nested = getattr(wrapped, "model", None)
        return nested if isinstance(nested, nn.Module) else wrapped
    raise ValueError("Model must have a 'model' or 'transformer' attribute")
def get_layers(inner_model):
    """Get the list of transformer layers.

    Checks the conventional attribute names ('layers', then 'h') and
    returns the first one that is present and non-None.
    """
    for candidate in ("layers", "h"):
        found = getattr(inner_model, candidate, None)
        if found is not None:
            return found
    raise ValueError("Model must have a 'layers' or 'h' attribute")
def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
    """Apply pipeline parallelism to a model.

    Each rank only keeps its slice of layers. The first layer receives from
    the previous rank, and the last layer sends to the next rank.

    Args:
        model: The MLX model (must have model.layers or similar)
        group: The distributed group
        start_layer: First layer index for this rank (auto-computed if None)
        end_layer: Last layer index (exclusive) for this rank (auto-computed if None)

    Returns:
        The same model object, mutated so its inner layer list holds only
        this rank's slice, with the slice's first/last layers wrapped for
        pipeline recv/send.

    Raises:
        ValueError: if the model has no recognizable inner model / layer
            container, or this rank's layer slice is empty.
    """
    rank = group.rank()
    world_size = group.size()
    inner = get_inner_model(model)
    layers = list(get_layers(inner))
    total_layers = len(layers)
    if start_layer is None or end_layer is None:
        # Balanced default split: the first `remainder` ranks each take
        # one extra layer.
        layers_per_rank = total_layers // world_size
        remainder = total_layers % world_size
        default_start = rank * layers_per_rank + min(rank, remainder)
        default_end = default_start + layers_per_rank + (1 if rank < remainder else 0)
        # Only fill in the bounds the caller left unspecified; previously
        # supplying just one of them silently discarded the given value.
        if start_layer is None:
            start_layer = default_start
        if end_layer is None:
            end_layer = default_end
    layers = layers[start_layer:end_layer]
    # Fail loudly instead of hitting an opaque IndexError below when the
    # slice is empty (e.g. more ranks than layers, or bad explicit bounds).
    if not layers:
        raise ValueError(
            f"Rank {rank} has no layers to run "
            f"(start_layer={start_layer}, end_layer={end_layer}, "
            f"total_layers={total_layers})"
        )
    for layer in layers:
        # Materialize this rank's slice before wrapping/communication.
        mx.eval(layer)
    # Wrap first and last layers (a single layer gets both wrappers).
    layers[0] = PipelineFirstLayer(layers[0], rank, group=group)
    layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group)
    # Replace layers on the inner model, mirroring get_layers' attribute probing.
    if hasattr(inner, "layers"):
        inner.layers = layers
    elif hasattr(inner, "h"):
        inner.h = layers
    return model