From f186cdb8d58d7d7e641253dec43589af78be8a74 Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Tue, 23 Dec 2025 17:40:11 +0000 Subject: [PATCH] mlx: update to 0.30.1 and align coordinator naming with MLX conventions The Jaccl distributed backend requires MLX 0.30.1+, which includes the RDMA over Thunderbolt support. The previous minimum version (0.29.3) would fail at runtime with "The only valid values for backend are 'any', 'mpi' and 'ring' but 'jaccl' was provided." Bump MLX dependency to >=0.30.1 and rename ibv_coordinators to jaccl_coordinators to match MLX's naming conventions. This includes the environment variable change from MLX_IBV_COORDINATOR to MLX_JACCL_COORDINATOR. --- pyproject.toml | 2 +- src/exo/master/placement.py | 6 ++--- src/exo/master/placement_utils.py | 4 +-- src/exo/master/tests/test_placement.py | 8 +++--- src/exo/master/tests/test_placement_utils.py | 6 ++--- src/exo/shared/types/worker/instances.py | 2 +- src/exo/worker/engines/mlx/utils_mlx.py | 8 +++--- uv.lock | 26 ++++++++++++-------- 8 files changed, 34 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3c007f09..7393cb30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "exo_pyo3_bindings", # rust bindings "anyio==4.11.0", "bidict>=0.23.1", - "mlx>=0.29.3", + "mlx>=0.30.1", "mlx-lm>=0.28.3", "tiktoken>=0.12.0", # required for kimi k2 tokenizer "hypercorn>=0.18.0", diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index e580c254..68567b87 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -8,8 +8,8 @@ from loguru import logger from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, - get_mlx_ibv_coordinators, get_mlx_ibv_devices_matrix, + get_mlx_jaccl_coordinators, get_shard_assignments, get_smallest_cycles, ) @@ -118,7 +118,7 @@ def place_instance( selected_cycle, cycle_digraph, ) - mlx_ibv_coordinators = get_mlx_ibv_coordinators( + mlx_jaccl_coordinators = get_mlx_jaccl_coordinators( selected_cycle, coordinator_port=random_ephemeral_port(), cycle_digraph=cycle_digraph, @@ -127,7 +127,7 @@ def place_instance( instance_id=instance_id, shard_assignments=shard_assignments, ibv_devices=mlx_ibv_devices, - ibv_coordinators=mlx_ibv_coordinators, + jaccl_coordinators=mlx_jaccl_coordinators, ) case InstanceMeta.MlxRing: hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 24461b42..f4acfb52 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -269,12 +269,12 @@ def _find_interface_name_for_ip( return None -def get_mlx_ibv_coordinators( +def get_mlx_jaccl_coordinators( selected_cycle: list[NodeInfo], coordinator_port: int, cycle_digraph: Topology, ) -> dict[NodeId, str]: - """Get the coordinator addresses for MLX IBV (rank 0 device). + """Get the coordinator addresses for MLX Jaccl (rank 0 device). Select an IP address that each node can reach for the rank 0 node. Returns address in format "X.X.X.X:PORT" per node. diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index c688e8ff..9d410275 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix( assert isinstance(instance, MlxJacclInstance) assert instance.ibv_devices is not None - assert instance.ibv_coordinators is not None + assert instance.jaccl_coordinators is not None matrix = instance.ibv_devices assert len(matrix) == 3 @@ -459,10 +459,10 @@ def test_tensor_rdma_backend_connectivity_matrix( assert matrix[idx_c][idx_a] == "rdma_en3" # Verify coordinators are set for all nodes - assert len(instance.ibv_coordinators) == 3 + assert len(instance.jaccl_coordinators) == 3 for node_id in assigned_nodes: - assert node_id in instance.ibv_coordinators - coordinator = instance.ibv_coordinators[node_id] + assert node_id in instance.jaccl_coordinators + coordinator = instance.jaccl_coordinators[node_id] assert ":" in coordinator # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs if node_id == assigned_nodes[0]: diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index ff6de72c..7a0d33be 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -5,7 +5,7 @@ import pytest from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, - get_mlx_ibv_coordinators, + get_mlx_jaccl_coordinators, get_shard_assignments, get_smallest_cycles, ) @@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph( assert expected_host in hosts -def test_get_mlx_ibv_coordinators( +def test_get_mlx_jaccl_coordinators( topology: Topology, create_node: Callable[[int, NodeId | None], NodeInfo], create_connection: Callable[[NodeId, NodeId, int | None], Connection], @@ -357,7 +357,7 @@ def test_get_mlx_ibv_coordinators( cycle = [node_a, node_b, node_c] # act - coordinators = get_mlx_ibv_coordinators( + coordinators = get_mlx_jaccl_coordinators( cycle, coordinator_port=5000, cycle_digraph=topology ) diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index ea8e7887..6bb1a75e 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance): class MlxJacclInstance(BaseInstance): ibv_devices: list[list[str | None]] - ibv_coordinators: dict[NodeId, str] + jaccl_coordinators: dict[NodeId, str] # TODO: Single node instance diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 19d565ca..71114282 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -129,7 +129,7 @@ def mlx_distributed_init( group = mx.distributed.init(backend="ring", strict=True) case MlxJacclInstance( - ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators + ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators ): # Use RDMA connectivity matrix devices_file = f"./hosts_{rank}.json" @@ -138,13 +138,13 @@ def mlx_distributed_init( with open(devices_file, "w") as f: _ = f.write(ibv_devices_json) - ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id] + jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id] logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}") - logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}") + logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}") os.environ["MLX_IBV_DEVICES"] = devices_file os.environ["MLX_RANK"] = str(rank) - os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator + os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator group = mx.distributed.init(backend="jaccl", strict=True) logger.info(f"Rank {rank} mlx distributed initialization complete") diff --git a/uv.lock b/uv.lock index 50884363..6c2bf775 100644 --- a/uv.lock +++ b/uv.lock @@ -374,7 +374,7 @@ requires-dist = [ { name = "huggingface-hub", specifier = ">=0.33.4" }, { name = "hypercorn", specifier = ">=0.18.0" }, { name = "loguru", specifier = ">=0.7.3" }, - { name = "mlx", specifier = ">=0.29.3" }, + { name = "mlx", specifier = ">=0.30.1" }, { name = "mlx-lm", specifier = ">=0.28.3" }, { name = "networkx", specifier = ">=3.5" }, { name = "protobuf", specifier = ">=6.32.0" }, @@ -783,16 +783,22 @@ wheels = [ [[package]] name = "mlx" -version = "0.29.3" +version = "0.30.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mlx-metal", marker = "sys_platform == 'darwin'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/a2/078152b45aa8a23949a1b09601d0044f8bb4ab85e909e4475a440c21aaea/mlx-0.29.3-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:d59eccf6a1e1e131becc5a3910504507862da3a4e9b7bd9e73a625515d767844", size = 549585, upload-time = "2025-10-17T19:17:01.872Z" }, - { url = "https://files.pythonhosted.org/packages/ae/bb/869eaac4efaae033c13db5fddd6a8907b5d667d135a35a2e482b1af402ee/mlx-0.29.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6642aa0a6dc2242c024fb8274d00631a7e7ffbdcef26148afd299b877c1e6a4a", size = 549586, upload-time = "2025-10-17T19:16:57.844Z" }, - { url = "https://files.pythonhosted.org/packages/ad/76/196c248c2b2a471f795356564ad1d7dc40284160c8b66370ffadfd991fa1/mlx-0.29.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ec0aef311fab10cb5f2c274afa6edf6c482636096a5f7886aba43676454aa462", size = 549586, upload-time = "2025-10-17T19:16:39.912Z" }, - { url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" }, + { url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" }, + { url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" }, + { url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" }, + { url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" }, + { url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" }, + { url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" }, + { url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" }, + { url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" }, + { url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" }, + { url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" }, ] [[package]] @@ -814,12 +820,12 @@ wheels = [ [[package]] name = "mlx-metal" -version = "0.29.3" +version = "0.30.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/95/a00054a006df82bb1b5b8f666ae44a676b259146fadbff90fe654309fefc/mlx_metal-0.29.3-py3-none-macosx_13_0_arm64.whl", hash = "sha256:27b5a4d905202a71e84d9fd559ea0236813f6f960ef494e5cafe9c45df4c9d7c", size = 36817352, upload-time = "2025-10-17T19:19:25.801Z" }, - { url = "https://files.pythonhosted.org/packages/c0/d8/5ee91eac16dfcf0334103120b47d4abd8c890ccc0d73d3eee4770ce8810f/mlx_metal-0.29.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:f426d4b67f96b4d6f0ed50d5992933595aadb370dc3e9ed2410bafbc16229882", size = 36555573, upload-time = "2025-10-17T19:18:42.098Z" }, - { url = "https://files.pythonhosted.org/packages/cd/9a/39b7ecdf21cf2a39ced8d7933eed65c6cb38295cadfd0907dd1abd4d1ded/mlx_metal-0.29.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:106616f7f825851043c53d3dc186965c003985da9cbb6e5c034f35108fc1fc27", size = 36549163, upload-time = "2025-10-17T19:18:37.701Z" }, + { url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" }, + { url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" }, + { url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" }, ] [[package]]