From a0c00f9dfd1bdbcfa93bd57dcbb9ca1998ee9ffa Mon Sep 17 00:00:00 2001 From: Alex Cheema <41707476+AlexCheema@users.noreply.github.com> Date: Thu, 7 May 2026 00:00:15 -0700 Subject: [PATCH] fix(placement): gate RDMA on nodeRdmaCtl.enabled at both endpoints (#2014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fixes a bug where `POST /place_instance` (and the dashboard UI) would accept an MlxJaccl/RDMA instance spanning nodes whose `nodeRdmaCtl.enabled` was `false`, because topology + placement consulted Thunderbolt-derived RDMA edges without checking the per-node `rdma_ctl` status. - Three-layer fix: topology only emits `RDMAConnection` edges when both endpoints have `nodeRdmaCtl.enabled = true`; flipping a node to disabled immediately purges every RDMA edge touching it; `place_instance` additionally rejects RDMA cycles containing any disabled or unobserved node as a defense-in-depth check on the API/master path. ## Details - `src/exo/shared/apply.py` - `MacThunderboltConnections` case now filters out RDMA connections whose source or sink lacks observed-and-enabled `rdma_ctl` status (missing entry → treated as disabled). - `RdmaCtlStatus` case now calls `topology.remove_all_rdma_connections_touching(node_id)` when the node reports disabled, so consumers don't have to wait for the next TB poll. - `src/exo/shared/topology.py` - New `Topology.remove_all_rdma_connections_touching(node_id)` removes every RDMA edge incident to the node (incoming and outgoing) while leaving socket edges intact. - `src/exo/master/placement.py` - `place_instance` accepts `node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None`. The `is_rdma_cycle` filter now also requires `nodeRdmaCtl.enabled` for every node in the cycle. MlxJaccl placement raises the existing "no RDMA-connected cycles available" error if no qualifying cycle remains. - `src/exo/api/main.py`, `src/exo/master/main.py` - Both placement entrypoints now pass `state.node_rdma_ctl` through. ## Tests - `src/exo/shared/tests/test_apply/test_apply_rdma_gating.py` (new): six unit tests covering enabled/disabled/missing combinations on apply, the immediate-purge transition, and that purging RDMA edges leaves socket edges untouched. - `src/exo/master/tests/test_placement.py`: existing `test_tensor_rdma_backend_connectivity_matrix` updated to pass `node_rdma_ctl`. Two new tests assert MlxJaccl placement is rejected when any cycle node is `enabled=false` or has no `rdma_ctl` entry. ## Test plan - [x] `uv run basedpyright` — 0 errors - [x] `uv run ruff check` — clean - [x] `nix fmt` - [x] `uv run pytest` — 429 passed, 1 skipped - [ ] On a real mixed cluster (s15/s16 disabled, s17/s18 enabled), confirm: - [ ] `POST /place_instance` for an RDMA instance including s15 or s16 returns an error - [ ] An RDMA instance can still be placed across {s17, s18} - [ ] `GET /state` shows no `sourceRdmaIface`/`sinkRdmaIface` on s15↔s16 connections - [ ] Dashboard previews don't surface RDMA-spanning options that include s15/s16 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.7 (1M context) --- src/exo/api/main.py | 2 + src/exo/master/main.py | 1 + src/exo/master/placement.py | 15 +- src/exo/master/tests/test_placement.py | 146 ++++++++++- src/exo/shared/apply.py | 25 ++ .../test_apply/test_apply_rdma_gating.py | 231 ++++++++++++++++++ src/exo/shared/topology.py | 16 ++ 7 files changed, 432 insertions(+), 4 deletions(-) create mode 100644 src/exo/shared/tests/test_apply/test_apply_rdma_gating.py diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 8fe0cfbec..b4c789326 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -481,6 +481,7 @@ class API: topology=self.state.topology, current_instances=self.state.instances, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @@ -544,6 +545,7 @@ class API: current_instances=self.state.instances, required_nodes=required_nodes, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) except ValueError as exc: if (model_card.model_id, sharding, instance_meta, 0) not in seen: diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 044dfbf45..85abbe390 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -365,6 +365,7 @@ class Master: self.state.node_memory, self.state.node_network, download_status=self.state.downloads, + node_rdma_ctl=self.state.node_rdma_ctl, ) transition_events = get_transition_events( self.state.instances, placement, self.state.tasks diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index 715939514..b55571c7f 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -28,7 +28,7 @@ from exo.shared.types.events import ( TaskStatusUpdated, ) from exo.shared.types.memory import Memory -from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo +from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.downloads import ( DownloadCompleted, @@ -105,6 +105,7 @@ def place_instance( node_network: Mapping[NodeId, NodeNetworkInfo], required_nodes: set[NodeId] | None = None, download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None, + node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None, ) -> dict[InstanceId, Instance]: cycles = topology.get_cycles() candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles)) @@ -166,8 +167,18 @@ def place_instance( smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory) + rdma_ctl_status = node_rdma_ctl or {} + + def _all_rdma_ctl_enabled(cycle: Cycle) -> bool: + return all( + ((status := rdma_ctl_status.get(node_id)) is not None and status.enabled) + for node_id in cycle + ) + smallest_rdma_cycles = [ - cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle) + cycle + for cycle in smallest_cycles + if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle) ] if command.instance_meta == InstanceMeta.MlxJaccl: diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 3e6b9d928..d3acd24f1 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -21,7 +21,11 @@ from exo.shared.types.events import ( ) from exo.shared.types.memory import Memory from exo.shared.types.multiaddr import Multiaddr -from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo +from exo.shared.types.profiling import ( + NetworkInterfaceInfo, + NodeNetworkInfo, + NodeRdmaCtlStatus, +) from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration from exo.shared.types.text_generation import ( InputMessage, @@ -439,8 +443,21 @@ def test_tensor_rdma_backend_connectivity_matrix( min_nodes=1, ) + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=True), + } + # act - placements = place_instance(cic, topology, {}, node_memory, node_network) + placements = place_instance( + cic, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) # assert assert len(placements) == 1 @@ -482,6 +499,131 @@ def test_tensor_rdma_backend_connectivity_matrix( assert len(ip_part.split(".")) == 4 +def _build_three_node_rdma_topology() -> tuple[ + Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] +]: + topology = Topology() + node_a = NodeId() + node_b = NodeId() + node_c = NodeId() + + ethernet_interface = NetworkInterfaceInfo(name="en0", ip_address="10.0.0.1") + ethernet_conn = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + node_network = { + node_a: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_b: NodeNetworkInfo(interfaces=[ethernet_interface]), + node_c: NodeNetworkInfo(interfaces=[ethernet_interface]), + } + + for n in (node_a, node_b, node_c): + topology.add_node(n) + + rdma_pairs = [ + (node_a, node_b, 3), + (node_b, node_a, 3), + (node_b, node_c, 4), + (node_c, node_b, 4), + (node_a, node_c, 5), + (node_c, node_a, 5), + ] + for src, sink, iface in rdma_pairs: + topology.add_connection( + Connection(source=src, sink=sink, edge=create_rdma_connection(iface)) + ) + + socket_pairs = [ + (node_a, node_b), + (node_b, node_c), + (node_c, node_a), + (node_a, node_c), + (node_b, node_a), + (node_c, node_b), + ] + for src, sink in socket_pairs: + topology.add_connection(Connection(source=src, sink=sink, edge=ethernet_conn)) + + return topology, node_a, node_b, node_c, node_network + + +def test_place_mlx_jaccl_rejects_when_a_node_has_rdma_ctl_disabled( + model_card: ModelCard, +): + # arrange + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + node_c: NodeRdmaCtlStatus(enabled=False), + } + cic = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + # act / assert + with pytest.raises( + ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles" + ): + place_instance( + cic, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + +def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCard): + """A node with no observed rdma_ctl status must not participate in RDMA placement.""" + # arrange + model_card = model_card.model_copy( + update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)} + ) + topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology() + node_memory = { + node_a: create_node_memory(500), + node_b: create_node_memory(500), + node_c: create_node_memory(500), + } + # node_c has no rdma_ctl entry at all + node_rdma_ctl = { + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + } + cic = PlaceInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxJaccl, + command_id=CommandId(), + model_card=model_card, + min_nodes=3, + ) + + # act / assert + with pytest.raises(ValueError): + place_instance( + cic, + topology, + {}, + node_memory, + node_network, + node_rdma_ctl=node_rdma_ctl, + ) + + def _make_task( instance_id: InstanceId, status: TaskStatus = TaskStatus.Running, diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 959f7765b..450c35e7f 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -65,6 +65,18 @@ from exo.utils.info_gatherer.info_gatherer import ( ) +def _is_rdma_ctl_enabled( + node_id: NodeId, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] +) -> bool: + """A node is RDMA-capable only if rdma_ctl status has been observed as enabled. + + Missing entries default to ``False`` — if we have not yet observed (or the node + cannot run) ``rdma_ctl``, it must not participate in an RDMA-backed instance. + """ + status = node_rdma_ctl.get(node_id) + return status is not None and status.enabled + + def event_apply(event: Event, state: State) -> State: """Apply an event to state.""" match event: @@ -397,6 +409,9 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for nid in state.node_thunderbolt for tb_ident in state.node_thunderbolt[nid].interfaces } + source_is_rdma_enabled = _is_rdma_ctl_enabled( + event.node_id, state.node_rdma_ctl + ) as_rdma_conns = [ Connection( source=event.node_id, @@ -409,6 +424,10 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: for tb_conn in info.conns if tb_conn.source_uuid in conn_map if tb_conn.sink_uuid in conn_map + if source_is_rdma_enabled + and _is_rdma_ctl_enabled( + conn_map[tb_conn.sink_uuid][0], state.node_rdma_ctl + ) ] topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns) case ThunderboltBridgeInfo(): @@ -432,6 +451,12 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: **state.node_rdma_ctl, event.node_id: NodeRdmaCtlStatus(enabled=info.enabled), } + # If RDMA just got disabled on this node, drop any RDMA edges touching it + # so placement / topology consumers cannot pick a disabled node for an + # RDMA-backed instance. (Edges will repopulate on the next + # MacThunderboltConnections poll once both endpoints are enabled again.) + if not info.enabled: + topology.remove_all_rdma_connections_touching(event.node_id) return state.model_copy(update=update) diff --git a/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py new file mode 100644 index 000000000..492e3fc5e --- /dev/null +++ b/src/exo/shared/tests/test_apply/test_apply_rdma_gating.py @@ -0,0 +1,231 @@ +from datetime import datetime, timezone + +from exo.shared.apply import apply_node_gathered_info +from exo.shared.topology import Topology +from exo.shared.types.common import NodeId +from exo.shared.types.events import NodeGatheredInfo +from exo.shared.types.profiling import ( + NodeRdmaCtlStatus, + NodeThunderboltInfo, +) +from exo.shared.types.state import State +from exo.shared.types.thunderbolt import ThunderboltConnection, ThunderboltIdentifier +from exo.shared.types.topology import RDMAConnection +from exo.utils.info_gatherer.info_gatherer import ( + MacThunderboltConnections, + RdmaCtlStatus, +) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _make_state_with_thunderbolt_idents( + *node_ids_and_uuids: tuple[NodeId, str, str], + rdma_ctl: dict[NodeId, NodeRdmaCtlStatus] | None = None, +) -> State: + """Build a State with Thunderbolt identifiers per node so the apply MacThunderboltConnections + case can resolve uuid -> (node, iface).""" + node_thunderbolt = { + nid: NodeThunderboltInfo( + interfaces=[ThunderboltIdentifier(rdma_interface=iface, domain_uuid=uuid)] + ) + for nid, uuid, iface in node_ids_and_uuids + } + return State( + node_thunderbolt=node_thunderbolt, + node_rdma_ctl=rdma_ctl or {}, + ) + + +def _has_rdma_edge(topology: Topology, source: NodeId, sink: NodeId) -> bool: + return any( + isinstance(edge, RDMAConnection) + for edge in topology.get_all_connections_between(source, sink) + ) + + +def test_mac_thunderbolt_connections_emits_rdma_when_both_endpoints_enabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_source_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=False), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_sink_rdma_ctl_disabled(): + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=False), + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_mac_thunderbolt_connections_skips_rdma_when_rdma_ctl_status_missing(): + """Missing rdma_ctl status defaults to not-enabled — node is RDMA-incapable.""" + node_a = NodeId() + node_b = NodeId() + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + # node_b intentionally absent + }, + ) + + event = NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ) + + new_state = apply_node_gathered_info(event, state) + + assert not _has_rdma_edge(new_state.topology, node_a, node_b) + + +def test_rdma_ctl_status_disabled_purges_existing_rdma_edges(): + """When a node reports rdma_ctl disabled, all RDMA edges touching it must be removed.""" + node_a = NodeId() + node_b = NodeId() + + # Start with both nodes RDMA-enabled and existing RDMA edges in the topology. + state = _make_state_with_thunderbolt_idents( + (node_a, "uuid-a", "rdma_en1"), + (node_b, "uuid-b", "rdma_en1"), + rdma_ctl={ + node_a: NodeRdmaCtlStatus(enabled=True), + node_b: NodeRdmaCtlStatus(enabled=True), + }, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")] + ), + ), + state, + ) + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_b, + when=_now(), + info=MacThunderboltConnections( + conns=[ThunderboltConnection(source_uuid="uuid-b", sink_uuid="uuid-a")] + ), + ), + state, + ) + assert _has_rdma_edge(state.topology, node_a, node_b) + assert _has_rdma_edge(state.topology, node_b, node_a) + + # Now node_a flips to rdma_ctl disabled — both directions of RDMA edge must drop. + state = apply_node_gathered_info( + NodeGatheredInfo( + node_id=node_a, when=_now(), info=RdmaCtlStatus(enabled=False) + ), + state, + ) + + assert not _has_rdma_edge(state.topology, node_a, node_b) + assert not _has_rdma_edge(state.topology, node_b, node_a) + assert state.node_rdma_ctl[node_a].enabled is False + + +def test_topology_remove_all_rdma_connections_touching_keeps_socket_edges(): + """Purging RDMA edges for a disabled node must not affect non-RDMA edges.""" + from exo.shared.types.multiaddr import Multiaddr + from exo.shared.types.topology import Connection, SocketConnection + + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection( + source=node_a, + sink=node_b, + edge=RDMAConnection( + source_rdma_iface="rdma_en1", sink_rdma_iface="rdma_en1" + ), + ) + ) + socket_edge = SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000") + ) + topology.add_connection(Connection(source=node_a, sink=node_b, edge=socket_edge)) + + topology.remove_all_rdma_connections_touching(node_a) + + assert not _has_rdma_edge(topology, node_a, node_b) + # Socket edge survives. + assert any( + isinstance(edge, SocketConnection) + for edge in topology.get_all_connections_between(node_a, node_b) + ) diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 9d649a6f4..121d5af2d 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -169,6 +169,22 @@ class Topology: for conn in new_connections: self.add_connection(conn) + def remove_all_rdma_connections_touching(self, node_id: NodeId) -> None: + """Remove every RDMA edge incident to ``node_id`` (incoming or outgoing).""" + if node_id not in self._vertex_indices: + return + rx_idx = self._vertex_indices[node_id] + rdma_edge_idxs = [ + edge_idx + for edge_idx in ( + *self._graph.out_edge_indices(rx_idx), + *self._graph.in_edge_indices(rx_idx), + ) + if isinstance(self._graph.get_edge_data_by_index(edge_idx), RDMAConnection) + ] + for edge_idx in rdma_edge_idxs: + self._graph.remove_edge_from_index(edge_idx) + def remove_connection(self, conn: Connection) -> None: if ( conn.source not in self._vertex_indices