mlx: update to 0.30.1 and align coordinator naming with MLX conventions

The Jaccl distributed backend requires MLX 0.30.1+, which includes
RDMA over Thunderbolt support. With the previous minimum version
(0.29.3), initializing the "jaccl" backend fails at runtime with "The
only valid values for backend are 'any', 'mpi' and 'ring' but 'jaccl'
was provided."

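Bump the MLX dependency to >=0.30.1 and rename ibv_coordinators to
jaccl_coordinators to match MLX's naming conventions. The rename also
applies to the environment variable, which changes from
MLX_IBV_COORDINATOR to MLX_JACCL_COORDINATOR. The ibv_devices field and
the MLX_IBV_DEVICES environment variable keep their existing names.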
This commit is contained in:
Jake Hillion
2025-12-23 17:40:11 +00:00
parent 9afc1043ef
commit f186cdb8d5
8 changed files with 34 additions and 28 deletions

View File

@@ -29,7 +29,7 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.29.3",
"mlx>=0.30.1",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -8,8 +8,8 @@ from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
@@ -118,7 +118,7 @@ def place_instance(
selected_cycle,
cycle_digraph,
)
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
@@ -127,7 +127,7 @@ def place_instance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)

View File

@@ -269,12 +269,12 @@ def _find_interface_name_for_ip(
return None
def get_mlx_ibv_coordinators(
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX IBV (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.

View File

@@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
assert len(matrix) == 3
@@ -459,10 +459,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.ibv_coordinators) == 3
assert len(instance.jaccl_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -5,7 +5,7 @@ import pytest
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
@@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph(
assert expected_host in hosts
def test_get_mlx_ibv_coordinators(
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
@@ -357,7 +357,7 @@ def test_get_mlx_ibv_coordinators(
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_ibv_coordinators(
coordinators = get_mlx_jaccl_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
)

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
jaccl_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -129,7 +129,7 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
@@ -138,13 +138,13 @@ def mlx_distributed_init(
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")

26
uv.lock generated
View File

@@ -374,7 +374,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", specifier = ">=0.29.3" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "protobuf", specifier = ">=6.32.0" },
@@ -783,16 +783,22 @@ wheels = [
[[package]]
name = "mlx"
version = "0.29.3"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/a2/078152b45aa8a23949a1b09601d0044f8bb4ab85e909e4475a440c21aaea/mlx-0.29.3-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:d59eccf6a1e1e131becc5a3910504507862da3a4e9b7bd9e73a625515d767844", size = 549585, upload-time = "2025-10-17T19:17:01.872Z" },
{ url = "https://files.pythonhosted.org/packages/ae/bb/869eaac4efaae033c13db5fddd6a8907b5d667d135a35a2e482b1af402ee/mlx-0.29.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6642aa0a6dc2242c024fb8274d00631a7e7ffbdcef26148afd299b877c1e6a4a", size = 549586, upload-time = "2025-10-17T19:16:57.844Z" },
{ url = "https://files.pythonhosted.org/packages/ad/76/196c248c2b2a471f795356564ad1d7dc40284160c8b66370ffadfd991fa1/mlx-0.29.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ec0aef311fab10cb5f2c274afa6edf6c482636096a5f7886aba43676454aa462", size = 549586, upload-time = "2025-10-17T19:16:39.912Z" },
{ url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" },
{ url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" },
{ url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" },
{ url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" },
{ url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" },
{ url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" },
{ url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" },
{ url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" },
{ url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" },
{ url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" },
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[[package]]
@@ -814,12 +820,12 @@ wheels = [
[[package]]
name = "mlx-metal"
version = "0.29.3"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/41/95/a00054a006df82bb1b5b8f666ae44a676b259146fadbff90fe654309fefc/mlx_metal-0.29.3-py3-none-macosx_13_0_arm64.whl", hash = "sha256:27b5a4d905202a71e84d9fd559ea0236813f6f960ef494e5cafe9c45df4c9d7c", size = 36817352, upload-time = "2025-10-17T19:19:25.801Z" },
{ url = "https://files.pythonhosted.org/packages/c0/d8/5ee91eac16dfcf0334103120b47d4abd8c890ccc0d73d3eee4770ce8810f/mlx_metal-0.29.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:f426d4b67f96b4d6f0ed50d5992933595aadb370dc3e9ed2410bafbc16229882", size = 36555573, upload-time = "2025-10-17T19:18:42.098Z" },
{ url = "https://files.pythonhosted.org/packages/cd/9a/39b7ecdf21cf2a39ced8d7933eed65c6cb38295cadfd0907dd1abd4d1ded/mlx_metal-0.29.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:106616f7f825851043c53d3dc186965c003985da9cbb6e5c034f35108fc1fc27", size = 36549163, upload-time = "2025-10-17T19:18:37.701Z" },
{ url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" },
{ url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" },
{ url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" },
]
[[package]]