mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
mlx: update to 0.30.1 and align coordinator naming with MLX conventions
The Jaccl distributed backend requires MLX 0.30.1+, which includes the RDMA over Thunderbolt support. The previous minimum version (0.29.3) would fail at runtime with "The only valid values for backend are 'any', 'mpi' and 'ring' but 'jaccl' was provided." Bump MLX dependency to >=0.30.1 and rename ibv_coordinators to jaccl_coordinators to match MLX's naming conventions. This includes the environment variable change from MLX_IBV_COORDINATOR to MLX_JACCL_COORDINATOR.
This commit is contained in:
@@ -29,7 +29,7 @@ dependencies = [
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"bidict>=0.23.1",
|
||||
"mlx>=0.29.3",
|
||||
"mlx>=0.30.1",
|
||||
"mlx-lm>=0.28.3",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
|
||||
@@ -8,8 +8,8 @@ from loguru import logger
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_coordinators,
|
||||
get_mlx_ibv_devices_matrix,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -118,7 +118,7 @@ def place_instance(
|
||||
selected_cycle,
|
||||
cycle_digraph,
|
||||
)
|
||||
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
|
||||
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
||||
selected_cycle,
|
||||
coordinator_port=random_ephemeral_port(),
|
||||
cycle_digraph=cycle_digraph,
|
||||
@@ -127,7 +127,7 @@ def place_instance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
ibv_coordinators=mlx_ibv_coordinators,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
|
||||
|
||||
@@ -269,12 +269,12 @@ def _find_interface_name_for_ip(
|
||||
return None
|
||||
|
||||
|
||||
def get_mlx_ibv_coordinators(
|
||||
def get_mlx_jaccl_coordinators(
|
||||
selected_cycle: list[NodeInfo],
|
||||
coordinator_port: int,
|
||||
cycle_digraph: Topology,
|
||||
) -> dict[NodeId, str]:
|
||||
"""Get the coordinator addresses for MLX IBV (rank 0 device).
|
||||
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
|
||||
|
||||
Select an IP address that each node can reach for the rank 0 node. Returns
|
||||
address in format "X.X.X.X:PORT" per node.
|
||||
|
||||
@@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.ibv_coordinators is not None
|
||||
assert instance.jaccl_coordinators is not None
|
||||
|
||||
matrix = instance.ibv_devices
|
||||
assert len(matrix) == 3
|
||||
@@ -459,10 +459,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.ibv_coordinators) == 3
|
||||
assert len(instance.jaccl_coordinators) == 3
|
||||
for node_id in assigned_nodes:
|
||||
assert node_id in instance.ibv_coordinators
|
||||
coordinator = instance.ibv_coordinators[node_id]
|
||||
assert node_id in instance.jaccl_coordinators
|
||||
coordinator = instance.jaccl_coordinators[node_id]
|
||||
assert ":" in coordinator
|
||||
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
|
||||
if node_id == assigned_nodes[0]:
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_coordinators,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph(
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_ibv_coordinators(
|
||||
def test_get_mlx_jaccl_coordinators(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
@@ -357,7 +357,7 @@ def test_get_mlx_ibv_coordinators(
|
||||
cycle = [node_a, node_b, node_c]
|
||||
|
||||
# act
|
||||
coordinators = get_mlx_ibv_coordinators(
|
||||
coordinators = get_mlx_jaccl_coordinators(
|
||||
cycle, coordinator_port=5000, cycle_digraph=topology
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
ibv_devices: list[list[str | None]]
|
||||
ibv_coordinators: dict[NodeId, str]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
|
||||
@@ -129,7 +129,7 @@ def mlx_distributed_init(
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
|
||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
@@ -138,13 +138,13 @@ def mlx_distributed_init(
|
||||
with open(devices_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
|
||||
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
26
uv.lock
generated
26
uv.lock
generated
@@ -374,7 +374,7 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.33.4" },
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mlx", specifier = ">=0.29.3" },
|
||||
{ name = "mlx", specifier = ">=0.30.1" },
|
||||
{ name = "mlx-lm", specifier = ">=0.28.3" },
|
||||
{ name = "networkx", specifier = ">=3.5" },
|
||||
{ name = "protobuf", specifier = ">=6.32.0" },
|
||||
@@ -783,16 +783,22 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.29.3"
|
||||
version = "0.30.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/a2/078152b45aa8a23949a1b09601d0044f8bb4ab85e909e4475a440c21aaea/mlx-0.29.3-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:d59eccf6a1e1e131becc5a3910504507862da3a4e9b7bd9e73a625515d767844", size = 549585, upload-time = "2025-10-17T19:17:01.872Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/bb/869eaac4efaae033c13db5fddd6a8907b5d667d135a35a2e482b1af402ee/mlx-0.29.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6642aa0a6dc2242c024fb8274d00631a7e7ffbdcef26148afd299b877c1e6a4a", size = 549586, upload-time = "2025-10-17T19:16:57.844Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/76/196c248c2b2a471f795356564ad1d7dc40284160c8b66370ffadfd991fa1/mlx-0.29.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ec0aef311fab10cb5f2c274afa6edf6c482636096a5f7886aba43676454aa462", size = 549586, upload-time = "2025-10-17T19:16:39.912Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -814,12 +820,12 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.29.3"
|
||||
version = "0.30.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/41/95/a00054a006df82bb1b5b8f666ae44a676b259146fadbff90fe654309fefc/mlx_metal-0.29.3-py3-none-macosx_13_0_arm64.whl", hash = "sha256:27b5a4d905202a71e84d9fd559ea0236813f6f960ef494e5cafe9c45df4c9d7c", size = 36817352, upload-time = "2025-10-17T19:19:25.801Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/d8/5ee91eac16dfcf0334103120b47d4abd8c890ccc0d73d3eee4770ce8810f/mlx_metal-0.29.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:f426d4b67f96b4d6f0ed50d5992933595aadb370dc3e9ed2410bafbc16229882", size = 36555573, upload-time = "2025-10-17T19:18:42.098Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/9a/39b7ecdf21cf2a39ced8d7933eed65c6cb38295cadfd0907dd1abd4d1ded/mlx_metal-0.29.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:106616f7f825851043c53d3dc186965c003985da9cbb6e5c034f35108fc1fc27", size = 36549163, upload-time = "2025-10-17T19:18:37.701Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user