Files
LocalAI/backend/python/mlx-distributed/coordinator.py
Ettore Di Giacinto a026277ab9 feat(mlx-distributed): add new MLX-distributed backend (#8801)
* feat(mlx-distributed): add new MLX-distributed backend

Add new MLX distributed backend with support for both TCP and RDMA for
model sharding.

This implementation ties in the discovery implementation already in
place, and re-uses the same P2P mechanism for the TCP MLX-distributed
inferencing.

The auto-parallel implementation is inspired by Exo's (whose authors have
been added to the acknowledgements for their great work!)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* expose a CLI to facilitate backend starting

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat: make manual rank0 configurable via model configs

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add missing features from mlx backend

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Apply suggestion from @mudler

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-09 17:29:32 +01:00

105 lines
3.7 KiB
Python

"""
Distributed coordination using MLX distributed primitives.
Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather.
Worker ranks wait in a loop for commands from rank 0.
"""
import json
import struct
import mlx.core as mx
# Command codes broadcast by rank 0 to all ranks. Waiting workers
# contribute CMD_IDLE (0) to the all_sum, so the sum equals whatever
# rank 0 sent.
CMD_IDLE = 0
CMD_GENERATE = 1
CMD_LOAD_MODEL = 2
CMD_SHUTDOWN = -1
class DistributedCoordinator:
    """Coordinate generation across MLX distributed ranks.

    Rank 0 drives the run: commands, token ids, generation parameters,
    and the model name are packed into arrays that are non-zero only on
    rank 0, while every other rank contributes zeros of the same shape.
    ``mx.distributed.all_sum`` then leaves each rank holding exactly the
    rank-0 data — a broadcast emulated with the all-reduce primitive.
    Worker ranks block in :meth:`wait_for_command` until rank 0 acts.
    """

    def __init__(self, group):
        # ``group`` is an MLX distributed group (provides rank()/size()).
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def _allsum_from_rank0(self, values, length, dtype):
        """Return ``values`` broadcast from rank 0 via ``all_sum``.

        Rank 0 contributes ``values``; every other rank contributes a
        zero array of shape ``(length,)``, so the element-wise sum seen
        on every rank equals the rank-0 data.  ``values`` is ignored on
        non-zero ranks and may be ``None`` there.
        """
        if self.rank == 0:
            buf = mx.array(values, dtype=dtype)
        else:
            buf = mx.zeros((length,), dtype=dtype)
        result = mx.distributed.all_sum(buf, group=self.group)
        mx.eval(result)  # force the lazy collective to execute now
        return result

    def broadcast_command(self, cmd, payload_size=0):
        """Broadcast a ``(command, payload_size)`` pair from rank 0.

        Returns the pair as plain ints on every rank.  Workers call this
        with ``CMD_IDLE`` (their ``cmd`` value is never transmitted).
        """
        result = self._allsum_from_rank0([cmd, payload_size], 2, mx.int32)
        return int(result[0].item()), int(result[1].item())

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        Every rank must pass a ``tokens`` sequence of the same length
        (workers typically pass a zero placeholder sized via
        :meth:`broadcast_token_count`).  Returns an int32 mx array.
        """
        return self._allsum_from_rank0(tokens, len(tokens), mx.int32)

    def broadcast_token_count(self, count):
        """Broadcast the token count so workers can size their buffers."""
        result = self._allsum_from_rank0([count], 1, mx.int32)
        return int(result[0].item())

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0 as a dict.

        NOTE(review): ``max_tokens`` round-trips through float32, which
        is only exact up to 2**24 — far above realistic values, but worth
        confirming if the limit is ever raised.
        """
        result = self._allsum_from_rank0(
            [max_tokens, temperature, top_p], 3, mx.float32
        )
        return {
            "max_tokens": int(result[0].item()),
            "temperature": float(result[1].item()),
            "top_p": float(result[2].item()),
        }

    def wait_for_command(self):
        """Worker ranks block here until rank 0 broadcasts a command."""
        return self.broadcast_command(CMD_IDLE, 0)

    def broadcast_model_name(self, model_name=""):
        """Broadcast the model name string from rank 0 to all ranks.

        The name is encoded as int32 Unicode codepoints so it can travel
        via ``all_sum``; the length is broadcast first so workers can
        allocate a matching zero buffer.  Returns the name on every rank
        (empty string when rank 0 sent an empty name).
        """
        if self.rank == 0:
            encoded = [ord(c) for c in model_name]
            # First broadcast the length so workers can size their buffer.
            length = self.broadcast_token_count(len(encoded))
            if length > 0:
                self._allsum_from_rank0(encoded, length, mx.int32)
                return model_name
            return ""
        else:
            length = self.broadcast_token_count(0)
            if length > 0:
                result = self._allsum_from_rank0(None, length, mx.int32)
                return "".join(chr(int(c.item())) for c in result)
            return ""