ibv -> jaccl

This commit is contained in:
Evan
2025-12-20 12:34:25 +00:00
parent 646d553bc5
commit 0a0ea37c64
9 changed files with 60 additions and 59 deletions

View File

@@ -9,8 +9,8 @@ from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_shard_assignments,
get_smallest_cycles,
)
@@ -59,9 +59,7 @@ def place_instance(
logger.info("finding cycles:")
cycles = topology.get_cycles() + [[node] for node in all_nodes]
logger.info(cycles)
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles)
)
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
)
@@ -108,7 +106,7 @@ def place_instance(
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -116,20 +114,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
[node.node_id for node in selected_cycle],
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
cycle_digraph,
)
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
[node.node_id for node in selected_cycle],
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)

View File

@@ -195,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeId],
def get_mlx_jaccl_devices_matrix(
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -205,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -234,29 +234,29 @@ def _find_connection_ip(
yield connection.sink_multiaddr.ip_address
def get_mlx_ibv_coordinators(
selected_cycle: list[NodeId],
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX IBV (rank 0 device).
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node}")
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
def get_ip_for_node(n: NodeId) -> str:
if n == rank_0_node:
if n == coordinator:
return "0.0.0.0"
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {rank_0_node}"
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
raise ValueError("Current jaccl backend requires all-to-all rdma connections")
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}

View File

@@ -244,7 +244,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
node_id_c = NodeId()
node_id_d = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
@@ -267,7 +266,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
logger.info(list(topology.list_connections()))
cic = place_instance_command(
model_meta=model_meta,
)
@@ -280,7 +278,9 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
instance = list(placements.values())[0]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set((node_id_c, node_id_d))
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
def test_tensor_rdma_backend_connectivity_matrix(
@@ -335,10 +335,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
@@ -358,10 +358,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert matrix[idx_c][idx_a] == "rdma_en5"
# Verify coordinators are set for all nodes
assert len(instance.ibv_coordinators) == 3
assert len(instance.jaccl_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -4,7 +4,7 @@ from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
@@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph():
assert expected_host in hosts
def test_get_mlx_ibv_coordinators():
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
@@ -295,11 +295,9 @@ def test_get_mlx_ibv_coordinators():
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
cycle = [node_a_id, node_b_id, node_c_id]
# act
coordinators = get_mlx_ibv_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
)
# assert

View File

@@ -247,7 +247,13 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(conn_map[tb_conn.sink_uuid][0], RDMAConnection(source_rdma_iface=conn_map[tb_conn.source_uuid][1], sink_rdma_iface=conn_map[tb_conn.sink_uuid][1]))
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map

View File

@@ -129,14 +129,20 @@ class Topology:
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._node_id_to_rx_id_map[source]):
for conn_idx in self._graph.out_edge_indices(
self._node_id_to_rx_id_map[source]
):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection) -> None:
for conn_idx in self._graph.edge_indices_from_endpoints(self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]):
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
for conn_idx in self._graph.edge_indices_from_endpoints(
self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)

View File

@@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -101,13 +101,7 @@ def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
"""
Initialize the MLX distributed (runs in thread pool).
Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
Initialize the MLX distributed
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
@@ -129,22 +123,22 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
_ = f.write(jaccl_devices_json)
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
logger.info(f"rank {rank} MLX_IBV_DEVICES: {jaccl_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
os.environ["MLX_IBV_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")

View File

@@ -22,7 +22,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
and len(bound_instance.instance.jaccl_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"