mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
ibv -> jaccl
This commit is contained in:
@@ -9,8 +9,8 @@ from exo.master.placement_utils import (
|
||||
NodeWithProfile,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_coordinators,
|
||||
get_mlx_ibv_devices_matrix,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_jaccl_devices_matrix,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -59,9 +59,7 @@ def place_instance(
|
||||
logger.info("finding cycles:")
|
||||
cycles = topology.get_cycles() + [[node] for node in all_nodes]
|
||||
logger.info(cycles)
|
||||
candidate_cycles = list(
|
||||
filter(lambda it: len(it) >= command.min_nodes, cycles)
|
||||
)
|
||||
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, node_profiles, command.model_meta.storage_size
|
||||
)
|
||||
@@ -108,7 +106,7 @@ def place_instance(
|
||||
|
||||
if len(selected_cycle) == 1:
|
||||
logger.warning(
|
||||
"You have likely selected ibv for a single node instance; falling back to MlxRing"
|
||||
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
|
||||
)
|
||||
|
||||
command.instance_meta = InstanceMeta.MlxRing
|
||||
@@ -116,20 +114,19 @@ def place_instance(
|
||||
# TODO: Single node instances
|
||||
match command.instance_meta:
|
||||
case InstanceMeta.MlxJaccl:
|
||||
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
|
||||
[node.node_id for node in selected_cycle],
|
||||
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
|
||||
cycle_digraph,
|
||||
)
|
||||
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
|
||||
[node.node_id for node in selected_cycle],
|
||||
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
||||
coordinator=selected_cycle[0].node_id,
|
||||
coordinator_port=random_ephemeral_port(),
|
||||
cycle_digraph=cycle_digraph,
|
||||
)
|
||||
target_instances[instance_id] = MlxJacclInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
ibv_coordinators=mlx_ibv_coordinators,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
|
||||
|
||||
@@ -195,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
return hosts
|
||||
|
||||
|
||||
def get_mlx_ibv_devices_matrix(
|
||||
selected_cycle: list[NodeId],
|
||||
def get_mlx_jaccl_devices_matrix(
|
||||
cycle_digraph: Topology,
|
||||
) -> list[list[str | None]]:
|
||||
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
|
||||
@@ -205,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
|
||||
to device j, or None if no connection exists or no interface name is found.
|
||||
Diagonal elements are always None.
|
||||
"""
|
||||
selected_cycle = list(cycle_digraph.list_nodes())
|
||||
num_nodes = len(selected_cycle)
|
||||
matrix: list[list[str | None]] = [
|
||||
[None for _ in range(num_nodes)] for _ in range(num_nodes)
|
||||
@@ -234,29 +234,29 @@ def _find_connection_ip(
|
||||
yield connection.sink_multiaddr.ip_address
|
||||
|
||||
|
||||
def get_mlx_ibv_coordinators(
|
||||
selected_cycle: list[NodeId],
|
||||
def get_mlx_jaccl_coordinators(
|
||||
coordinator: NodeId,
|
||||
coordinator_port: int,
|
||||
cycle_digraph: Topology,
|
||||
) -> dict[NodeId, str]:
|
||||
"""Get the coordinator addresses for MLX IBV (rank 0 device).
|
||||
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
|
||||
|
||||
Select an IP address that each node can reach for the rank 0 node. Returns
|
||||
address in format "X.X.X.X:PORT" per node.
|
||||
"""
|
||||
rank_0_node = selected_cycle[0]
|
||||
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node}")
|
||||
selected_cycle = list(cycle_digraph.list_nodes())
|
||||
logger.info(f"Selecting coordinator: {coordinator}")
|
||||
|
||||
def get_ip_for_node(n: NodeId) -> str:
|
||||
if n == rank_0_node:
|
||||
if n == coordinator:
|
||||
return "0.0.0.0"
|
||||
|
||||
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
||||
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
f"Failed to find directly connected ip between {n} and {rank_0_node}"
|
||||
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||
)
|
||||
raise ValueError("Current ibv backend requires all-to-all rdma connections")
|
||||
raise ValueError("Current jaccl backend requires all-to-all rdma connections")
|
||||
|
||||
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
|
||||
|
||||
@@ -244,7 +244,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
node_id_c = NodeId()
|
||||
node_id_d = NodeId()
|
||||
|
||||
|
||||
profiles = {
|
||||
node_id_a: create_node_profile(500),
|
||||
node_id_b: create_node_profile(600),
|
||||
@@ -267,7 +266,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
|
||||
logger.info(list(topology.list_connections()))
|
||||
|
||||
|
||||
cic = place_instance_command(
|
||||
model_meta=model_meta,
|
||||
)
|
||||
@@ -280,7 +278,9 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
instance = list(placements.values())[0]
|
||||
|
||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set((node_id_c, node_id_d))
|
||||
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
|
||||
(node_id_c, node_id_d)
|
||||
)
|
||||
|
||||
|
||||
def test_tensor_rdma_backend_connectivity_matrix(
|
||||
@@ -335,10 +335,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.ibv_coordinators is not None
|
||||
assert instance.jaccl_devices is not None
|
||||
assert instance.jaccl_coordinators is not None
|
||||
|
||||
matrix = instance.ibv_devices
|
||||
matrix = instance.jaccl_devices
|
||||
assert len(matrix) == 3
|
||||
|
||||
for i in range(3):
|
||||
@@ -358,10 +358,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.ibv_coordinators) == 3
|
||||
assert len(instance.jaccl_coordinators) == 3
|
||||
for node_id in assigned_nodes:
|
||||
assert node_id in instance.ibv_coordinators
|
||||
coordinator = instance.ibv_coordinators[node_id]
|
||||
assert node_id in instance.jaccl_coordinators
|
||||
coordinator = instance.jaccl_coordinators[node_id]
|
||||
assert ":" in coordinator
|
||||
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
|
||||
if node_id == assigned_nodes[0]:
|
||||
|
||||
@@ -4,7 +4,7 @@ from exo.master.placement_utils import (
|
||||
NodeWithProfile,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_coordinators,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph():
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_ibv_coordinators():
|
||||
def test_get_mlx_jaccl_coordinators():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
@@ -295,11 +295,9 @@ def test_get_mlx_ibv_coordinators():
|
||||
topology.add_connection(node_c_id, node_a_id, conn_c_a)
|
||||
topology.add_connection(node_a_id, node_c_id, conn_a_c)
|
||||
|
||||
cycle = [node_a_id, node_b_id, node_c_id]
|
||||
|
||||
# act
|
||||
coordinators = get_mlx_ibv_coordinators(
|
||||
cycle, coordinator_port=5000, cycle_digraph=topology
|
||||
coordinators = get_mlx_jaccl_coordinators(
|
||||
node_a_id, coordinator_port=5000, cycle_digraph=topology
|
||||
)
|
||||
|
||||
# assert
|
||||
|
||||
@@ -247,7 +247,13 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
for tb_ident in state.node_profiles[nid].tb_interfaces
|
||||
}
|
||||
as_rdma_conns = [
|
||||
(conn_map[tb_conn.sink_uuid][0], RDMAConnection(source_rdma_iface=conn_map[tb_conn.source_uuid][1], sink_rdma_iface=conn_map[tb_conn.sink_uuid][1]))
|
||||
(
|
||||
conn_map[tb_conn.sink_uuid][0],
|
||||
RDMAConnection(
|
||||
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
|
||||
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
|
||||
),
|
||||
)
|
||||
for tb_conn in info
|
||||
if tb_conn.source_uuid in conn_map
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
|
||||
@@ -129,14 +129,20 @@ class Topology:
|
||||
def replace_all_out_tb_connections(
|
||||
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
|
||||
) -> None:
|
||||
for conn_idx in self._graph.out_edge_indices(self._node_id_to_rx_id_map[source]):
|
||||
for conn_idx in self._graph.out_edge_indices(
|
||||
self._node_id_to_rx_id_map[source]
|
||||
):
|
||||
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
for sink, conn in new_connections:
|
||||
self.add_connection(source, sink, conn)
|
||||
|
||||
def remove_connection(self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection) -> None:
|
||||
for conn_idx in self._graph.edge_indices_from_endpoints(self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]):
|
||||
def remove_connection(
|
||||
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
|
||||
) -> None:
|
||||
for conn_idx in self._graph.edge_indices_from_endpoints(
|
||||
self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]
|
||||
):
|
||||
if self._graph.get_edge_data_by_index(conn_idx) == edge:
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance):
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
ibv_devices: list[list[str | None]]
|
||||
ibv_coordinators: dict[NodeId, str]
|
||||
jaccl_devices: list[list[str | None]]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
|
||||
@@ -101,13 +101,7 @@ def mlx_distributed_init(
|
||||
bound_instance: BoundInstance,
|
||||
) -> mx.distributed.Group:
|
||||
"""
|
||||
Initialize the MLX distributed (runs in thread pool).
|
||||
|
||||
Either hosts or mlx_ibv_devices must be provided:
|
||||
- hosts: traditional host-based connectivity using MLX_HOSTFILE
|
||||
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
|
||||
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
|
||||
- strict: if True, raise an error if the distributed backend is not available
|
||||
Initialize the MLX distributed
|
||||
"""
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
@@ -129,22 +123,22 @@ def mlx_distributed_init(
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
|
||||
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||
|
||||
with open(devices_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
_ = f.write(jaccl_devices_json)
|
||||
|
||||
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {jaccl_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
|
||||
os.environ["MLX_IBV_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
@@ -22,7 +22,7 @@ def entrypoint(
|
||||
) -> None:
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
and len(bound_instance.instance.jaccl_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user