diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index b3f51222..3805ed81 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -9,8 +9,8 @@ from exo.master.placement_utils import ( NodeWithProfile, filter_cycles_by_memory, get_hosts_from_subgraph, - get_mlx_ibv_coordinators, - get_mlx_ibv_devices_matrix, + get_mlx_jaccl_coordinators, + get_mlx_jaccl_devices_matrix, get_shard_assignments, get_smallest_cycles, ) @@ -59,9 +59,7 @@ def place_instance( logger.info("finding cycles:") cycles = topology.get_cycles() + [[node] for node in all_nodes] logger.info(cycles) - candidate_cycles = list( - filter(lambda it: len(it) >= command.min_nodes, cycles) - ) + candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles)) cycles_with_sufficient_memory = filter_cycles_by_memory( candidate_cycles, node_profiles, command.model_meta.storage_size ) @@ -108,7 +106,7 @@ def place_instance( if len(selected_cycle) == 1: logger.warning( - "You have likely selected ibv for a single node instance; falling back to MlxRing" + "You have likely selected jaccl for a single node instance; falling back to MlxRing" ) command.instance_meta = InstanceMeta.MlxRing @@ -116,20 +114,19 @@ def place_instance( # TODO: Single node instances match command.instance_meta: case InstanceMeta.MlxJaccl: - mlx_ibv_devices = get_mlx_ibv_devices_matrix( - [node.node_id for node in selected_cycle], + mlx_jaccl_devices = get_mlx_jaccl_devices_matrix( cycle_digraph, ) - mlx_ibv_coordinators = get_mlx_ibv_coordinators( - [node.node_id for node in selected_cycle], + mlx_jaccl_coordinators = get_mlx_jaccl_coordinators( + coordinator=selected_cycle[0].node_id, coordinator_port=random_ephemeral_port(), cycle_digraph=cycle_digraph, ) target_instances[instance_id] = MlxJacclInstance( instance_id=instance_id, shard_assignments=shard_assignments, - ibv_devices=mlx_ibv_devices, - ibv_coordinators=mlx_ibv_coordinators, + jaccl_devices=mlx_jaccl_devices, + jaccl_coordinators=mlx_jaccl_coordinators, ) case InstanceMeta.MlxRing: hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index b426ab7e..63fa1dc2 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -195,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]: return hosts -def get_mlx_ibv_devices_matrix( - selected_cycle: list[NodeId], +def get_mlx_jaccl_devices_matrix( cycle_digraph: Topology, ) -> list[list[str | None]]: """Build connectivity matrix mapping device i to device j via RDMA interface names. @@ -205,6 +204,7 @@ def get_mlx_ibv_devices_matrix( to device j, or None if no connection exists or no interface name is found. Diagonal elements are always None. """ + selected_cycle = list(cycle_digraph.list_nodes()) num_nodes = len(selected_cycle) matrix: list[list[str | None]] = [ [None for _ in range(num_nodes)] for _ in range(num_nodes) @@ -234,29 +234,29 @@ def _find_connection_ip( yield connection.sink_multiaddr.ip_address -def get_mlx_ibv_coordinators( - selected_cycle: list[NodeId], +def get_mlx_jaccl_coordinators( + coordinator: NodeId, coordinator_port: int, cycle_digraph: Topology, ) -> dict[NodeId, str]: - """Get the coordinator addresses for MLX IBV (rank 0 device). + """Get the coordinator addresses for MLX JACCL (rank 0 device). Select an IP address that each node can reach for the rank 0 node. Returns address in format "X.X.X.X:PORT" per node. """ - rank_0_node = selected_cycle[0] - logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node}") + selected_cycle = list(cycle_digraph.list_nodes()) + logger.info(f"Selecting coordinator: {coordinator}") def get_ip_for_node(n: NodeId) -> str: - if n == rank_0_node: + if n == coordinator: return "0.0.0.0" - for ip in _find_connection_ip(n, rank_0_node, cycle_digraph): + for ip in _find_connection_ip(n, coordinator, cycle_digraph): return ip logger.warning( - f"Failed to find directly connected ip between {n} and {rank_0_node}" + f"Failed to find directly connected ip between {n} and {coordinator}" ) - raise ValueError("Current ibv backend requires all-to-all rdma connections") + raise ValueError("Current jaccl backend requires all-to-all rdma connections") return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle} diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index efc556e9..e2211edb 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -244,7 +244,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory( node_id_c = NodeId() node_id_d = NodeId() - profiles = { node_id_a: create_node_profile(500), node_id_b: create_node_profile(600), @@ -267,7 +266,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory( logger.info(list(topology.list_connections())) - cic = place_instance_command( model_meta=model_meta, ) @@ -280,7 +278,9 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory( instance = list(placements.values())[0] assigned_nodes = set(instance.shard_assignments.node_to_runner.keys()) - assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set((node_id_c, node_id_d)) + assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set( + (node_id_c, node_id_d) + ) def test_tensor_rdma_backend_connectivity_matrix( @@ -335,10 +335,10 @@ def test_tensor_rdma_backend_connectivity_matrix( assert isinstance(instance, MlxJacclInstance) - assert instance.ibv_devices is not None - assert instance.ibv_coordinators is not None + assert instance.jaccl_devices is not None + assert instance.jaccl_coordinators is not None - matrix = instance.ibv_devices + matrix = instance.jaccl_devices assert len(matrix) == 3 for i in range(3): @@ -358,10 +358,10 @@ def test_tensor_rdma_backend_connectivity_matrix( assert matrix[idx_c][idx_a] == "rdma_en5" # Verify coordinators are set for all nodes - assert len(instance.ibv_coordinators) == 3 + assert len(instance.jaccl_coordinators) == 3 for node_id in assigned_nodes: - assert node_id in instance.ibv_coordinators - coordinator = instance.ibv_coordinators[node_id] + assert node_id in instance.jaccl_coordinators + coordinator = instance.jaccl_coordinators[node_id] assert ":" in coordinator # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs if node_id == assigned_nodes[0]: diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index 7bc8b7a7..7609d357 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -4,7 +4,7 @@ from exo.master.placement_utils import ( NodeWithProfile, filter_cycles_by_memory, get_hosts_from_subgraph, - get_mlx_ibv_coordinators, + get_mlx_jaccl_coordinators, get_shard_assignments, get_smallest_cycles, ) @@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph(): assert expected_host in hosts -def test_get_mlx_ibv_coordinators(): +def test_get_mlx_jaccl_coordinators(): # arrange node_a_id = NodeId() node_b_id = NodeId() @@ -295,11 +295,9 @@ def test_get_mlx_ibv_coordinators(): topology.add_connection(node_c_id, node_a_id, conn_c_a) topology.add_connection(node_a_id, node_c_id, conn_a_c) - cycle = [node_a_id, node_b_id, node_c_id] - # act - coordinators = get_mlx_ibv_coordinators( - cycle, coordinator_port=5000, cycle_digraph=topology + coordinators = get_mlx_jaccl_coordinators( + node_a_id, coordinator_port=5000, cycle_digraph=topology ) # assert diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 4e01f8ad..a9c87dde 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -247,7 +247,13 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for tb_ident in state.node_profiles[nid].tb_interfaces } as_rdma_conns = [ - (conn_map[tb_conn.sink_uuid][0], RDMAConnection(source_rdma_iface=conn_map[tb_conn.source_uuid][1], sink_rdma_iface=conn_map[tb_conn.sink_uuid][1])) + ( + conn_map[tb_conn.sink_uuid][0], + RDMAConnection( + source_rdma_iface=conn_map[tb_conn.source_uuid][1], + sink_rdma_iface=conn_map[tb_conn.sink_uuid][1], + ), + ) for tb_conn in info if tb_conn.source_uuid in conn_map if tb_conn.sink_uuid in conn_map diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 0f4ae44e..9e45f7e4 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -129,14 +129,20 @@ class Topology: def replace_all_out_tb_connections( self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]] ) -> None: - for conn_idx in self._graph.out_edge_indices(self._node_id_to_rx_id_map[source]): + for conn_idx in self._graph.out_edge_indices( + self._node_id_to_rx_id_map[source] + ): if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection): self._graph.remove_edge_from_index(conn_idx) for sink, conn in new_connections: self.add_connection(source, sink, conn) - def remove_connection(self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection) -> None: - for conn_idx in self._graph.edge_indices_from_endpoints(self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]): + def remove_connection( + self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection + ) -> None: + for conn_idx in self._graph.edge_indices_from_endpoints( + self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink] + ): if self._graph.get_edge_data_by_index(conn_idx) == edge: self._graph.remove_edge_from_index(conn_idx) diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index ea8e7887..4252146f 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance): class MlxJacclInstance(BaseInstance): - ibv_devices: list[list[str | None]] - ibv_coordinators: dict[NodeId, str] + jaccl_devices: list[list[str | None]] + jaccl_coordinators: dict[NodeId, str] # TODO: Single node instance diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 19d565ca..a63f0ef7 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -101,13 +101,7 @@ def mlx_distributed_init( bound_instance: BoundInstance, ) -> mx.distributed.Group: """ - Initialize the MLX distributed (runs in thread pool). - - Either hosts or mlx_ibv_devices must be provided: - - hosts: traditional host-based connectivity using MLX_HOSTFILE - - mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES - - mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup - - strict: if True, raise an error if the distributed backend is not available + Initialize the MLX distributed """ rank = bound_instance.bound_shard.device_rank logger.info(f"Starting initialization for rank {rank}") @@ -129,22 +123,22 @@ def mlx_distributed_init( group = mx.distributed.init(backend="ring", strict=True) case MlxJacclInstance( - ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators + jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators ): # Use RDMA connectivity matrix devices_file = f"./hosts_{rank}.json" - ibv_devices_json = json.dumps(ibv_devices) + jaccl_devices_json = json.dumps(jaccl_devices) with open(devices_file, "w") as f: - _ = f.write(ibv_devices_json) + _ = f.write(jaccl_devices_json) - ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id] + jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id] - logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}") - logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}") + logger.info(f"rank {rank} MLX_IBV_DEVICES: {jaccl_devices_json}") + logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {jaccl_coordinator}") os.environ["MLX_IBV_DEVICES"] = devices_file os.environ["MLX_RANK"] = str(rank) - os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator + os.environ["MLX_IBV_COORDINATOR"] = jaccl_coordinator group = mx.distributed.init(backend="jaccl", strict=True) logger.info(f"Rank {rank} mlx distributed initialization complete") diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index 24d30cb8..44c1cedc 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -22,7 +22,7 @@ def entrypoint( ) -> None: if ( isinstance(bound_instance.instance, MlxJacclInstance) - and len(bound_instance.instance.ibv_devices) >= 2 + and len(bound_instance.instance.jaccl_devices) >= 2 ): os.environ["MLX_METAL_FAST_SYNCH"] = "1"