mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-01 13:42:20 -04:00
* feat(mlx-distributed): add new MLX-distributed backend Add new MLX distributed backend with support for both TCP and RDMA for model sharding. This implementation ties in the discovery implementation already in place, and re-uses the same P2P mechanism for the TCP MLX-distributed inferencing. The Auto-parallel implementation is inspired by Exo's ones (who have been added to acknowledgement for the great work!) Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * expose a CLI to facilitate backend starting Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat: make manual rank0 configurable via model configs Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Add missing features from mlx backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Apply suggestion from @mudler Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
105 lines
3.7 KiB
Python
105 lines
3.7 KiB
Python
"""
|
|
Distributed coordination using MLX distributed primitives.
|
|
|
|
Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather.
|
|
Worker ranks wait in a loop for commands from rank 0.
|
|
"""
|
|
import json
|
|
import struct
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
# Command opcodes broadcast from rank 0 to worker ranks via
# DistributedCoordinator.broadcast_command. CMD_IDLE is what workers
# contribute while polling in wait_for_command; the remaining opcodes
# (generate / load-model / shutdown) are presumably dispatched by the
# serving loop elsewhere in the backend — confirm against the caller.
CMD_IDLE = 0
CMD_GENERATE = 1
CMD_LOAD_MODEL = 2
CMD_SHUTDOWN = -1
class DistributedCoordinator:
    """Coordinates command and data broadcasts across MLX distributed ranks.

    Rank 0 is the driver: for every broadcast it supplies the real values
    while all other ranks contribute zeros of the same shape, so
    ``mx.distributed.all_sum`` leaves every rank holding rank 0's data.
    Worker ranks block in :meth:`wait_for_command` until rank 0 issues a
    command.
    """

    def __init__(self, group):
        """Cache the MLX distributed group and this process's rank/size.

        Args:
            group: An MLX distributed group exposing ``rank()`` and ``size()``.
        """
        self.group = group
        self.rank = group.rank()
        self.world_size = group.size()

    def _bcast_from_rank0(self, values, length, dtype):
        """Broadcast a 1-D array from rank 0 using the zeros + all_sum trick.

        Rank 0 contributes ``values``; every other rank contributes zeros of
        shape ``(length,)``, so the element-wise sum equals rank 0's data on
        all ranks. This factors out the pattern shared by every public
        ``broadcast_*`` method.

        Args:
            values: Sequence of numbers; only read on rank 0 (may be ``None``
                on other ranks).
            length: Number of elements every rank must contribute.
            dtype: MLX dtype of the broadcast array.

        Returns:
            The broadcast ``mx.array``, identical on every rank.
        """
        if self.rank == 0:
            arr = mx.array(values, dtype=dtype)
        else:
            arr = mx.zeros((length,), dtype=dtype)
        result = mx.distributed.all_sum(arr, group=self.group)
        mx.eval(result)  # force the lazy collective to actually execute
        return result

    def broadcast_command(self, cmd, payload_size=0):
        """Rank 0 broadcasts a command to all ranks.

        Uses all_sum with only rank 0 providing non-zero values so every
        rank receives the same command array.

        Returns:
            Tuple ``(cmd, payload_size)`` as plain ints, identical on all ranks.
        """
        result = self._bcast_from_rank0([cmd, payload_size], 2, mx.int32)
        return int(result[0].item()), int(result[1].item())

    def broadcast_tokens(self, tokens):
        """Broadcast input token ids from rank 0 to all ranks.

        Rank 0 provides the real token array; other ranks provide zeros of
        the same shape, so callers on non-zero ranks must pass a ``tokens``
        buffer of the correct length (typically sized via
        :meth:`broadcast_token_count`).

        Returns:
            ``mx.array`` of int32 token ids, identical on every rank.
        """
        return self._bcast_from_rank0(tokens, len(tokens), mx.int32)

    def broadcast_token_count(self, count):
        """Broadcast the number of tokens so workers can prepare a buffer."""
        result = self._bcast_from_rank0([count], 1, mx.int32)
        return int(result[0].item())

    def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0):
        """Broadcast generation parameters from rank 0.

        Note: all three values travel in a single float32 array, so
        ``max_tokens`` is only exact while it fits a float32 mantissa
        (< 2**24), which holds for realistic token budgets.

        Returns:
            Dict with ``max_tokens`` (int), ``temperature`` and ``top_p``
            (floats), identical on every rank.
        """
        result = self._bcast_from_rank0(
            [max_tokens, temperature, top_p], 3, mx.float32
        )
        return {
            "max_tokens": int(result[0].item()),
            "temperature": float(result[1].item()),
            "top_p": float(result[2].item()),
        }

    def wait_for_command(self):
        """Worker ranks block here until rank 0 broadcasts a command."""
        # Workers contribute CMD_IDLE (zeros), so the sum is rank 0's command.
        return self.broadcast_command(CMD_IDLE, 0)

    def broadcast_model_name(self, model_name=""):
        """Broadcast model name string from rank 0 to all ranks.

        Encodes the model name as int32 codepoints so it can travel via
        all_sum; the length is broadcast first so workers can size their
        zero buffers.

        Returns:
            The model name string (empty string when rank 0 sent no name).
        """
        if self.rank == 0:
            encoded = [ord(c) for c in model_name]
            # First broadcast the length so workers can allocate.
            length = self.broadcast_token_count(len(encoded))
            if length > 0:
                # Rank 0 must still participate in the collective, but it
                # already knows the name and discards the result.
                self._bcast_from_rank0(encoded, length, mx.int32)
                return model_name
            return ""
        length = self.broadcast_token_count(0)
        if length > 0:
            result = self._bcast_from_rank0(None, length, mx.int32)
            return "".join(chr(int(c.item())) for c in result)
        return ""