fix(placement): gate RDMA on nodeRdmaCtl.enabled at both endpoints (#2014)

## Summary

- Fixes a bug where `POST /place_instance` (and the dashboard UI) would
accept an MlxJaccl/RDMA instance spanning nodes whose
`nodeRdmaCtl.enabled` was `false`, because topology + placement
consulted Thunderbolt-derived RDMA edges without checking the per-node
`rdma_ctl` status.
- Three-layer fix: topology only emits `RDMAConnection` edges when both
endpoints have `nodeRdmaCtl.enabled = true`; flipping a node to disabled
immediately purges every RDMA edge touching it; `place_instance`
additionally rejects RDMA cycles containing any disabled or unobserved
node as a defense-in-depth check on the API/master path.

## Details

- `src/exo/shared/apply.py`
  - `MacThunderboltConnections` case now filters out RDMA connections
    whose source or sink lacks observed-and-enabled `rdma_ctl` status
    (missing entry → treated as disabled).
  - `RdmaCtlStatus` case now calls
    `topology.remove_all_rdma_connections_touching(node_id)` when the node
    reports disabled, so consumers don't have to wait for the next TB poll.
- `src/exo/shared/topology.py`
  - New `Topology.remove_all_rdma_connections_touching(node_id)` removes
    every RDMA edge incident to the node (incoming and outgoing) while
    leaving socket edges intact.
- `src/exo/master/placement.py`
  - `place_instance` accepts
    `node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None`. The
    `is_rdma_cycle` filter now also requires `nodeRdmaCtl.enabled` for
    every node in the cycle. MlxJaccl placement raises the existing
    "no RDMA-connected cycles available" error if no qualifying cycle
    remains.
- `src/exo/api/main.py`, `src/exo/master/main.py`
  - Both placement entrypoints now pass `state.node_rdma_ctl` through.

## Tests

- `src/exo/shared/tests/test_apply/test_apply_rdma_gating.py` (new): six
unit tests covering enabled/disabled/missing combinations on apply, the
immediate-purge transition, and that purging RDMA edges leaves socket
edges untouched.
- `src/exo/master/tests/test_placement.py`: existing
`test_tensor_rdma_backend_connectivity_matrix` updated to pass
`node_rdma_ctl`. Two new tests assert MlxJaccl placement is rejected
when any cycle node is `enabled=false` or has no `rdma_ctl` entry.

## Test plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt`
- [x] `uv run pytest` — 429 passed, 1 skipped
- [ ] On a real mixed cluster (s15/s16 disabled, s17/s18 enabled),
  confirm:
  - [ ] `POST /place_instance` for an RDMA instance including s15 or s16
    returns an error
  - [ ] An RDMA instance can still be placed across {s17, s18}
  - [ ] `GET /state` shows no `sourceRdmaIface`/`sinkRdmaIface` on s15↔s16
    connections
  - [ ] Dashboard previews don't surface RDMA-spanning options that
    include s15/s16

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex Cheema
2026-05-07 00:00:15 -07:00
committed by GitHub
parent 89d20c1888
commit a0c00f9dfd
7 changed files with 432 additions and 4 deletions

View File

@@ -481,6 +481,7 @@ class API:
topology=self.state.topology,
current_instances=self.state.instances,
download_status=self.state.downloads,
node_rdma_ctl=self.state.node_rdma_ctl,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@@ -544,6 +545,7 @@ class API:
current_instances=self.state.instances,
required_nodes=required_nodes,
download_status=self.state.downloads,
node_rdma_ctl=self.state.node_rdma_ctl,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:

View File

@@ -365,6 +365,7 @@ class Master:
self.state.node_memory,
self.state.node_network,
download_status=self.state.downloads,
node_rdma_ctl=self.state.node_rdma_ctl,
)
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks

View File

@@ -28,7 +28,7 @@ from exo.shared.types.events import (
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadCompleted,
@@ -105,6 +105,7 @@ def place_instance(
node_network: Mapping[NodeId, NodeNetworkInfo],
required_nodes: set[NodeId] | None = None,
download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None,
node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
@@ -166,8 +167,18 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
rdma_ctl_status = node_rdma_ctl or {}
def _all_rdma_ctl_enabled(cycle: Cycle) -> bool:
return all(
((status := rdma_ctl_status.get(node_id)) is not None and status.enabled)
for node_id in cycle
)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
cycle
for cycle in smallest_cycles
if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl:

View File

@@ -21,7 +21,11 @@ from exo.shared.types.events import (
)
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
NodeRdmaCtlStatus,
)
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import (
InputMessage,
@@ -439,8 +443,21 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
node_rdma_ctl = {
node_a: NodeRdmaCtlStatus(enabled=True),
node_b: NodeRdmaCtlStatus(enabled=True),
node_c: NodeRdmaCtlStatus(enabled=True),
}
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(
cic,
topology,
{},
node_memory,
node_network,
node_rdma_ctl=node_rdma_ctl,
)
# assert
assert len(placements) == 1
@@ -482,6 +499,131 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert len(ip_part.split(".")) == 4
def _build_three_node_rdma_topology() -> tuple[
    Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo]
]:
    """Build a three-node topology with RDMA and socket edges in both directions.

    Every ordered pair of the three nodes receives one RDMA edge and one socket
    edge. All nodes share the same ethernet interface description and the same
    socket edge object.

    Returns:
        The topology, the three node ids (a, b, c), and a mapping from each
        node id to its network info.
    """
    node_a, node_b, node_c = NodeId(), NodeId(), NodeId()
    all_nodes = (node_a, node_b, node_c)

    shared_interface = NetworkInterfaceInfo(name="en0", ip_address="10.0.0.1")
    shared_socket_edge = SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
    )
    node_network = {
        node: NodeNetworkInfo(interfaces=[shared_interface]) for node in all_nodes
    }

    topology = Topology()
    for node in all_nodes:
        topology.add_node(node)

    # (source, sink, value handed to create_rdma_connection); the two
    # directions of each node pair use the same value.
    for source, sink, iface in (
        (node_a, node_b, 3),
        (node_b, node_a, 3),
        (node_b, node_c, 4),
        (node_c, node_b, 4),
        (node_a, node_c, 5),
        (node_c, node_a, 5),
    ):
        rdma_link = Connection(
            source=source, sink=sink, edge=create_rdma_connection(iface)
        )
        topology.add_connection(rdma_link)

    # Socket edges are added after the RDMA edges, in this fixed order.
    for source, sink in (
        (node_a, node_b),
        (node_b, node_c),
        (node_c, node_a),
        (node_a, node_c),
        (node_b, node_a),
        (node_c, node_b),
    ):
        socket_link = Connection(source=source, sink=sink, edge=shared_socket_edge)
        topology.add_connection(socket_link)

    return topology, node_a, node_b, node_c, node_network
def test_place_mlx_jaccl_rejects_when_a_node_has_rdma_ctl_disabled(
    model_card: ModelCard,
):
    """MlxJaccl placement over three nodes must fail if one reports rdma_ctl disabled."""
    # arrange
    sized_card = model_card.model_copy(
        update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)}
    )
    topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology()
    node_memory = {
        node: create_node_memory(500) for node in (node_a, node_b, node_c)
    }
    # node_c is the only node whose rdma_ctl status is disabled.
    node_rdma_ctl = {
        node_a: NodeRdmaCtlStatus(enabled=True),
        node_b: NodeRdmaCtlStatus(enabled=True),
        node_c: NodeRdmaCtlStatus(enabled=False),
    }
    command = PlaceInstance(
        sharding=Sharding.Tensor,
        instance_meta=InstanceMeta.MlxJaccl,
        command_id=CommandId(),
        model_card=sized_card,
        min_nodes=3,
    )
    expected_error = "Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles"

    # act / assert
    with pytest.raises(ValueError, match=expected_error):
        place_instance(
            command,
            topology,
            {},
            node_memory,
            node_network,
            node_rdma_ctl=node_rdma_ctl,
        )
def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCard):
    """A node with no observed rdma_ctl status must not participate in RDMA placement."""
    # arrange
    model_card = model_card.model_copy(
        update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)}
    )
    topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology()
    node_memory = {
        node_a: create_node_memory(500),
        node_b: create_node_memory(500),
        node_c: create_node_memory(500),
    }
    # node_c has no rdma_ctl entry at all
    node_rdma_ctl = {
        node_a: NodeRdmaCtlStatus(enabled=True),
        node_b: NodeRdmaCtlStatus(enabled=True),
    }
    cic = PlaceInstance(
        sharding=Sharding.Tensor,
        instance_meta=InstanceMeta.MlxJaccl,
        command_id=CommandId(),
        model_card=model_card,
        min_nodes=3,
    )

    # act / assert
    # Pin the RDMA-specific message (same pattern as the rdma_ctl-disabled
    # test) so an unrelated ValueError cannot make this test pass.
    with pytest.raises(
        ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles"
    ):
        place_instance(
            cic,
            topology,
            {},
            node_memory,
            node_network,
            node_rdma_ctl=node_rdma_ctl,
        )
def _make_task(
instance_id: InstanceId,
status: TaskStatus = TaskStatus.Running,

View File

@@ -65,6 +65,18 @@ from exo.utils.info_gatherer.info_gatherer import (
)
def _is_rdma_ctl_enabled(
    node_id: NodeId, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus]
) -> bool:
    """Report whether ``node_id`` has an observed ``rdma_ctl`` status that is enabled.

    A node without an entry in ``node_rdma_ctl`` counts as not enabled: until
    ``rdma_ctl`` has been observed for it (or if the node cannot run it), the
    node must stay out of RDMA-backed instances.
    """
    observed = node_rdma_ctl.get(node_id)
    if observed is None:
        # Never observed -> treated the same as disabled.
        return False
    return observed.enabled
def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
@@ -397,6 +409,9 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
for nid in state.node_thunderbolt
for tb_ident in state.node_thunderbolt[nid].interfaces
}
source_is_rdma_enabled = _is_rdma_ctl_enabled(
event.node_id, state.node_rdma_ctl
)
as_rdma_conns = [
Connection(
source=event.node_id,
@@ -409,6 +424,10 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
if source_is_rdma_enabled
and _is_rdma_ctl_enabled(
conn_map[tb_conn.sink_uuid][0], state.node_rdma_ctl
)
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
case ThunderboltBridgeInfo():
@@ -432,6 +451,12 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
**state.node_rdma_ctl,
event.node_id: NodeRdmaCtlStatus(enabled=info.enabled),
}
# If RDMA just got disabled on this node, drop any RDMA edges touching it
# so placement / topology consumers cannot pick a disabled node for an
# RDMA-backed instance. (Edges will repopulate on the next
# MacThunderboltConnections poll once both endpoints are enabled again.)
if not info.enabled:
topology.remove_all_rdma_connections_touching(event.node_id)
return state.model_copy(update=update)

View File

@@ -0,0 +1,231 @@
from datetime import datetime, timezone
from exo.shared.apply import apply_node_gathered_info
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeGatheredInfo
from exo.shared.types.profiling import (
NodeRdmaCtlStatus,
NodeThunderboltInfo,
)
from exo.shared.types.state import State
from exo.shared.types.thunderbolt import ThunderboltConnection, ThunderboltIdentifier
from exo.shared.types.topology import RDMAConnection
from exo.utils.info_gatherer.info_gatherer import (
MacThunderboltConnections,
RdmaCtlStatus,
)
def _now() -> str:
    """Return the current UTC time as an ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _make_state_with_thunderbolt_idents(
    *node_ids_and_uuids: tuple[NodeId, str, str],
    rdma_ctl: dict[NodeId, NodeRdmaCtlStatus] | None = None,
) -> State:
    """Create a State whose nodes each carry one Thunderbolt identifier.

    Each positional argument is a ``(node_id, domain_uuid, rdma_interface)``
    tuple. The identifiers let the MacThunderboltConnections apply case resolve
    uuid -> (node, iface). ``rdma_ctl`` seeds ``node_rdma_ctl``; a missing or
    empty value yields an empty mapping.
    """
    node_thunderbolt = {}
    for node_id, domain_uuid, rdma_iface in node_ids_and_uuids:
        identifier = ThunderboltIdentifier(
            rdma_interface=rdma_iface, domain_uuid=domain_uuid
        )
        node_thunderbolt[node_id] = NodeThunderboltInfo(interfaces=[identifier])

    return State(node_thunderbolt=node_thunderbolt, node_rdma_ctl=rdma_ctl or {})
def _has_rdma_edge(topology: Topology, source: NodeId, sink: NodeId) -> bool:
    """Return True if any connection the topology reports for (source, sink) is RDMA."""
    for edge in topology.get_all_connections_between(source, sink):
        if isinstance(edge, RDMAConnection):
            return True
    return False
def test_mac_thunderbolt_connections_emits_rdma_when_both_endpoints_enabled():
    """An RDMA edge is emitted when both endpoints have rdma_ctl enabled."""
    source_node, sink_node = NodeId(), NodeId()
    rdma_ctl = {
        source_node: NodeRdmaCtlStatus(enabled=True),
        sink_node: NodeRdmaCtlStatus(enabled=True),
    }
    initial_state = _make_state_with_thunderbolt_idents(
        (source_node, "uuid-a", "rdma_en1"),
        (sink_node, "uuid-b", "rdma_en1"),
        rdma_ctl=rdma_ctl,
    )
    tb_link = ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")
    gathered = NodeGatheredInfo(
        node_id=source_node,
        when=_now(),
        info=MacThunderboltConnections(conns=[tb_link]),
    )

    resulting_state = apply_node_gathered_info(gathered, initial_state)

    assert _has_rdma_edge(resulting_state.topology, source_node, sink_node)
def test_mac_thunderbolt_connections_skips_rdma_when_source_rdma_ctl_disabled():
    """No RDMA edge is emitted when the reporting (source) node has rdma_ctl disabled."""
    source_node, sink_node = NodeId(), NodeId()
    rdma_ctl = {
        source_node: NodeRdmaCtlStatus(enabled=False),
        sink_node: NodeRdmaCtlStatus(enabled=True),
    }
    initial_state = _make_state_with_thunderbolt_idents(
        (source_node, "uuid-a", "rdma_en1"),
        (sink_node, "uuid-b", "rdma_en1"),
        rdma_ctl=rdma_ctl,
    )
    tb_link = ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")
    gathered = NodeGatheredInfo(
        node_id=source_node,
        when=_now(),
        info=MacThunderboltConnections(conns=[tb_link]),
    )

    resulting_state = apply_node_gathered_info(gathered, initial_state)

    assert not _has_rdma_edge(resulting_state.topology, source_node, sink_node)
def test_mac_thunderbolt_connections_skips_rdma_when_sink_rdma_ctl_disabled():
    """No RDMA edge is emitted when the sink node has rdma_ctl disabled."""
    source_node, sink_node = NodeId(), NodeId()
    rdma_ctl = {
        source_node: NodeRdmaCtlStatus(enabled=True),
        sink_node: NodeRdmaCtlStatus(enabled=False),
    }
    initial_state = _make_state_with_thunderbolt_idents(
        (source_node, "uuid-a", "rdma_en1"),
        (sink_node, "uuid-b", "rdma_en1"),
        rdma_ctl=rdma_ctl,
    )
    tb_link = ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")
    gathered = NodeGatheredInfo(
        node_id=source_node,
        when=_now(),
        info=MacThunderboltConnections(conns=[tb_link]),
    )

    resulting_state = apply_node_gathered_info(gathered, initial_state)

    assert not _has_rdma_edge(resulting_state.topology, source_node, sink_node)
def test_mac_thunderbolt_connections_skips_rdma_when_rdma_ctl_status_missing():
    """Missing rdma_ctl status defaults to not-enabled — node is RDMA-incapable."""
    source_node, sink_node = NodeId(), NodeId()
    rdma_ctl = {
        source_node: NodeRdmaCtlStatus(enabled=True),
        # sink_node intentionally absent
    }
    initial_state = _make_state_with_thunderbolt_idents(
        (source_node, "uuid-a", "rdma_en1"),
        (sink_node, "uuid-b", "rdma_en1"),
        rdma_ctl=rdma_ctl,
    )
    tb_link = ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")
    gathered = NodeGatheredInfo(
        node_id=source_node,
        when=_now(),
        info=MacThunderboltConnections(conns=[tb_link]),
    )

    resulting_state = apply_node_gathered_info(gathered, initial_state)

    assert not _has_rdma_edge(resulting_state.topology, source_node, sink_node)
def test_rdma_ctl_status_disabled_purges_existing_rdma_edges():
    """When a node reports rdma_ctl disabled, all RDMA edges touching it must be removed."""
    node_a, node_b = NodeId(), NodeId()

    # Both nodes start RDMA-enabled.
    state = _make_state_with_thunderbolt_idents(
        (node_a, "uuid-a", "rdma_en1"),
        (node_b, "uuid-b", "rdma_en1"),
        rdma_ctl={
            node_a: NodeRdmaCtlStatus(enabled=True),
            node_b: NodeRdmaCtlStatus(enabled=True),
        },
    )

    # Each node reports its Thunderbolt link to the other (node_a first), which
    # populates RDMA edges in both directions.
    tb_reports = (
        (node_a, "uuid-a", "uuid-b"),
        (node_b, "uuid-b", "uuid-a"),
    )
    for reporter, local_uuid, remote_uuid in tb_reports:
        tb_link = ThunderboltConnection(source_uuid=local_uuid, sink_uuid=remote_uuid)
        state = apply_node_gathered_info(
            NodeGatheredInfo(
                node_id=reporter,
                when=_now(),
                info=MacThunderboltConnections(conns=[tb_link]),
            ),
            state,
        )

    both_directions = ((node_a, node_b), (node_b, node_a))
    for source, sink in both_directions:
        assert _has_rdma_edge(state.topology, source, sink)

    # Now node_a flips to rdma_ctl disabled — both directions of RDMA edge must drop.
    disable_event = NodeGatheredInfo(
        node_id=node_a, when=_now(), info=RdmaCtlStatus(enabled=False)
    )
    state = apply_node_gathered_info(disable_event, state)

    for source, sink in both_directions:
        assert not _has_rdma_edge(state.topology, source, sink)
    assert state.node_rdma_ctl[node_a].enabled is False
def test_topology_remove_all_rdma_connections_touching_keeps_socket_edges():
    """Purging RDMA edges for a disabled node must not affect non-RDMA edges."""
    from exo.shared.types.multiaddr import Multiaddr
    from exo.shared.types.topology import Connection, SocketConnection

    node_a, node_b = NodeId(), NodeId()
    topology = Topology()
    for node in (node_a, node_b):
        topology.add_node(node)

    # One RDMA edge and one socket edge, both from node_a to node_b.
    rdma_edge = RDMAConnection(
        source_rdma_iface="rdma_en1", sink_rdma_iface="rdma_en1"
    )
    socket_edge = SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
    )
    for edge in (rdma_edge, socket_edge):
        topology.add_connection(Connection(source=node_a, sink=node_b, edge=edge))

    topology.remove_all_rdma_connections_touching(node_a)

    assert not _has_rdma_edge(topology, node_a, node_b)
    # Socket edge survives.
    surviving_socket_edges = [
        edge
        for edge in topology.get_all_connections_between(node_a, node_b)
        if isinstance(edge, SocketConnection)
    ]
    assert surviving_socket_edges

View File

@@ -169,6 +169,22 @@ class Topology:
for conn in new_connections:
self.add_connection(conn)
def remove_all_rdma_connections_touching(self, node_id: NodeId) -> None:
    """Delete every RDMA edge that starts or ends at ``node_id``.

    Edges whose payload is not an ``RDMAConnection`` are left in place. A
    ``node_id`` that is not registered in the topology is ignored.
    """
    if node_id not in self._vertex_indices:
        return

    graph = self._graph
    vertex = self._vertex_indices[node_id]

    # Gather the matching edge indices before removing anything
    # (presumably so the graph is not mutated while it is being scanned).
    doomed_edges = []
    for incident_edges in (
        graph.out_edge_indices(vertex),
        graph.in_edge_indices(vertex),
    ):
        for edge_idx in incident_edges:
            if isinstance(graph.get_edge_data_by_index(edge_idx), RDMAConnection):
                doomed_edges.append(edge_idx)

    for edge_idx in doomed_edges:
        graph.remove_edge_from_index(edge_idx)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices