mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-19 04:05:23 -04:00
fix(placement): gate RDMA on nodeRdmaCtl.enabled at both endpoints (#2014)
## Summary
- Fixes a bug where `POST /place_instance` (and the dashboard UI) would
accept an MlxJaccl/RDMA instance spanning nodes whose
`nodeRdmaCtl.enabled` was `false`, because topology + placement
consulted Thunderbolt-derived RDMA edges without checking the per-node
`rdma_ctl` status.
- Three-layer fix: topology only emits `RDMAConnection` edges when both
endpoints have `nodeRdmaCtl.enabled = true`; flipping a node to disabled
immediately purges every RDMA edge touching it; `place_instance`
additionally rejects RDMA cycles containing any disabled or unobserved
node as a defense-in-depth check on the API/master path.
## Details
- `src/exo/shared/apply.py`
- `MacThunderboltConnections` case now filters out RDMA connections
whose source or sink lacks observed-and-enabled `rdma_ctl` status
(missing entry → treated as disabled).
- `RdmaCtlStatus` case now calls
`topology.remove_all_rdma_connections_touching(node_id)` when the node
reports disabled, so consumers don't have to wait for the next TB poll.
- `src/exo/shared/topology.py`
- New `Topology.remove_all_rdma_connections_touching(node_id)` removes
every RDMA edge incident to the node (incoming and outgoing) while
leaving socket edges intact.
- `src/exo/master/placement.py`
- `place_instance` accepts `node_rdma_ctl: Mapping[NodeId,
NodeRdmaCtlStatus] | None`. The `is_rdma_cycle` filter now also requires
`nodeRdmaCtl.enabled` for every node in the cycle. MlxJaccl placement
raises the existing "no RDMA-connected cycles available" error if no
qualifying cycle remains.
- `src/exo/api/main.py`, `src/exo/master/main.py`
- Both placement entrypoints now pass `state.node_rdma_ctl` through.
## Tests
- `src/exo/shared/tests/test_apply/test_apply_rdma_gating.py` (new): six
unit tests covering enabled/disabled/missing combinations on apply, the
immediate-purge transition, and that purging RDMA edges leaves socket
edges untouched.
- `src/exo/master/tests/test_placement.py`: existing
`test_tensor_rdma_backend_connectivity_matrix` updated to pass
`node_rdma_ctl`. Two new tests assert MlxJaccl placement is rejected
when any cycle node is `enabled=false` or has no `rdma_ctl` entry.
## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — clean
- [x] `nix fmt`
- [x] `uv run pytest` — 429 passed, 1 skipped
- [ ] On a real mixed cluster (s15/s16 disabled, s17/s18 enabled),
confirm:
- [ ] `POST /place_instance` for an RDMA instance including s15 or s16
returns an error
- [ ] An RDMA instance can still be placed across {s17, s18}
- [ ] `GET /state` shows no `sourceRdmaIface`/`sinkRdmaIface` on s15↔s16
connections
- [ ] Dashboard previews don't surface RDMA-spanning options that
include s15/s16
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -481,6 +481,7 @@ class API:
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
download_status=self.state.downloads,
|
||||
node_rdma_ctl=self.state.node_rdma_ctl,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
@@ -544,6 +545,7 @@ class API:
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
download_status=self.state.downloads,
|
||||
node_rdma_ctl=self.state.node_rdma_ctl,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
|
||||
@@ -365,6 +365,7 @@ class Master:
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
download_status=self.state.downloads,
|
||||
node_rdma_ctl=self.state.node_rdma_ctl,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
|
||||
@@ -28,7 +28,7 @@ from exo.shared.types.events import (
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo, NodeRdmaCtlStatus
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
@@ -105,6 +105,7 @@ def place_instance(
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
required_nodes: set[NodeId] | None = None,
|
||||
download_status: Mapping[NodeId, Sequence[DownloadProgress]] | None = None,
|
||||
node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] | None = None,
|
||||
) -> dict[InstanceId, Instance]:
|
||||
cycles = topology.get_cycles()
|
||||
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
|
||||
@@ -166,8 +167,18 @@ def place_instance(
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
rdma_ctl_status = node_rdma_ctl or {}
|
||||
|
||||
def _all_rdma_ctl_enabled(cycle: Cycle) -> bool:
|
||||
return all(
|
||||
((status := rdma_ctl_status.get(node_id)) is not None and status.enabled)
|
||||
for node_id in cycle
|
||||
)
|
||||
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if topology.is_rdma_cycle(cycle) and _all_rdma_ctl_enabled(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl:
|
||||
|
||||
@@ -21,7 +21,11 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.profiling import (
|
||||
NetworkInterfaceInfo,
|
||||
NodeNetworkInfo,
|
||||
NodeRdmaCtlStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import (
|
||||
InputMessage,
|
||||
@@ -439,8 +443,21 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
node_rdma_ctl = {
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
node_c: NodeRdmaCtlStatus(enabled=True),
|
||||
}
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {}, node_memory, node_network)
|
||||
placements = place_instance(
|
||||
cic,
|
||||
topology,
|
||||
{},
|
||||
node_memory,
|
||||
node_network,
|
||||
node_rdma_ctl=node_rdma_ctl,
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
@@ -482,6 +499,131 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _build_three_node_rdma_topology() -> tuple[
|
||||
Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo]
|
||||
]:
|
||||
topology = Topology()
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
node_c = NodeId()
|
||||
|
||||
ethernet_interface = NetworkInterfaceInfo(name="en0", ip_address="10.0.0.1")
|
||||
ethernet_conn = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
|
||||
)
|
||||
node_network = {
|
||||
node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),
|
||||
node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),
|
||||
node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),
|
||||
}
|
||||
|
||||
for n in (node_a, node_b, node_c):
|
||||
topology.add_node(n)
|
||||
|
||||
rdma_pairs = [
|
||||
(node_a, node_b, 3),
|
||||
(node_b, node_a, 3),
|
||||
(node_b, node_c, 4),
|
||||
(node_c, node_b, 4),
|
||||
(node_a, node_c, 5),
|
||||
(node_c, node_a, 5),
|
||||
]
|
||||
for src, sink, iface in rdma_pairs:
|
||||
topology.add_connection(
|
||||
Connection(source=src, sink=sink, edge=create_rdma_connection(iface))
|
||||
)
|
||||
|
||||
socket_pairs = [
|
||||
(node_a, node_b),
|
||||
(node_b, node_c),
|
||||
(node_c, node_a),
|
||||
(node_a, node_c),
|
||||
(node_b, node_a),
|
||||
(node_c, node_b),
|
||||
]
|
||||
for src, sink in socket_pairs:
|
||||
topology.add_connection(Connection(source=src, sink=sink, edge=ethernet_conn))
|
||||
|
||||
return topology, node_a, node_b, node_c, node_network
|
||||
|
||||
|
||||
def test_place_mlx_jaccl_rejects_when_a_node_has_rdma_ctl_disabled(
|
||||
model_card: ModelCard,
|
||||
):
|
||||
# arrange
|
||||
model_card = model_card.model_copy(
|
||||
update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)}
|
||||
)
|
||||
topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology()
|
||||
node_memory = {
|
||||
node_a: create_node_memory(500),
|
||||
node_b: create_node_memory(500),
|
||||
node_c: create_node_memory(500),
|
||||
}
|
||||
node_rdma_ctl = {
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
node_c: NodeRdmaCtlStatus(enabled=False),
|
||||
}
|
||||
cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
instance_meta=InstanceMeta.MlxJaccl,
|
||||
command_id=CommandId(),
|
||||
model_card=model_card,
|
||||
min_nodes=3,
|
||||
)
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(
|
||||
ValueError, match="Requested RDMA \\(MlxJaccl\\) but no RDMA-connected cycles"
|
||||
):
|
||||
place_instance(
|
||||
cic,
|
||||
topology,
|
||||
{},
|
||||
node_memory,
|
||||
node_network,
|
||||
node_rdma_ctl=node_rdma_ctl,
|
||||
)
|
||||
|
||||
|
||||
def test_place_mlx_jaccl_rejects_when_node_rdma_ctl_missing(model_card: ModelCard):
|
||||
"""A node with no observed rdma_ctl status must not participate in RDMA placement."""
|
||||
# arrange
|
||||
model_card = model_card.model_copy(
|
||||
update={"n_layers": 12, "storage_size": Memory.from_bytes(1500)}
|
||||
)
|
||||
topology, node_a, node_b, node_c, node_network = _build_three_node_rdma_topology()
|
||||
node_memory = {
|
||||
node_a: create_node_memory(500),
|
||||
node_b: create_node_memory(500),
|
||||
node_c: create_node_memory(500),
|
||||
}
|
||||
# node_c has no rdma_ctl entry at all
|
||||
node_rdma_ctl = {
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
}
|
||||
cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
instance_meta=InstanceMeta.MlxJaccl,
|
||||
command_id=CommandId(),
|
||||
model_card=model_card,
|
||||
min_nodes=3,
|
||||
)
|
||||
|
||||
# act / assert
|
||||
with pytest.raises(ValueError):
|
||||
place_instance(
|
||||
cic,
|
||||
topology,
|
||||
{},
|
||||
node_memory,
|
||||
node_network,
|
||||
node_rdma_ctl=node_rdma_ctl,
|
||||
)
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
|
||||
@@ -65,6 +65,18 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
)
|
||||
|
||||
|
||||
def _is_rdma_ctl_enabled(
|
||||
node_id: NodeId, node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus]
|
||||
) -> bool:
|
||||
"""A node is RDMA-capable only if rdma_ctl status has been observed as enabled.
|
||||
|
||||
Missing entries default to ``False`` — if we have not yet observed (or the node
|
||||
cannot run) ``rdma_ctl``, it must not participate in an RDMA-backed instance.
|
||||
"""
|
||||
status = node_rdma_ctl.get(node_id)
|
||||
return status is not None and status.enabled
|
||||
|
||||
|
||||
def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
@@ -397,6 +409,9 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
for nid in state.node_thunderbolt
|
||||
for tb_ident in state.node_thunderbolt[nid].interfaces
|
||||
}
|
||||
source_is_rdma_enabled = _is_rdma_ctl_enabled(
|
||||
event.node_id, state.node_rdma_ctl
|
||||
)
|
||||
as_rdma_conns = [
|
||||
Connection(
|
||||
source=event.node_id,
|
||||
@@ -409,6 +424,10 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
for tb_conn in info.conns
|
||||
if tb_conn.source_uuid in conn_map
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
if source_is_rdma_enabled
|
||||
and _is_rdma_ctl_enabled(
|
||||
conn_map[tb_conn.sink_uuid][0], state.node_rdma_ctl
|
||||
)
|
||||
]
|
||||
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
|
||||
case ThunderboltBridgeInfo():
|
||||
@@ -432,6 +451,12 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
**state.node_rdma_ctl,
|
||||
event.node_id: NodeRdmaCtlStatus(enabled=info.enabled),
|
||||
}
|
||||
# If RDMA just got disabled on this node, drop any RDMA edges touching it
|
||||
# so placement / topology consumers cannot pick a disabled node for an
|
||||
# RDMA-backed instance. (Edges will repopulate on the next
|
||||
# MacThunderboltConnections poll once both endpoints are enabled again.)
|
||||
if not info.enabled:
|
||||
topology.remove_all_rdma_connections_touching(event.node_id)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
231
src/exo/shared/tests/test_apply/test_apply_rdma_gating.py
Normal file
231
src/exo/shared/tests/test_apply/test_apply_rdma_gating.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from exo.shared.apply import apply_node_gathered_info
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import NodeGatheredInfo
|
||||
from exo.shared.types.profiling import (
|
||||
NodeRdmaCtlStatus,
|
||||
NodeThunderboltInfo,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.thunderbolt import ThunderboltConnection, ThunderboltIdentifier
|
||||
from exo.shared.types.topology import RDMAConnection
|
||||
from exo.utils.info_gatherer.info_gatherer import (
|
||||
MacThunderboltConnections,
|
||||
RdmaCtlStatus,
|
||||
)
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _make_state_with_thunderbolt_idents(
|
||||
*node_ids_and_uuids: tuple[NodeId, str, str],
|
||||
rdma_ctl: dict[NodeId, NodeRdmaCtlStatus] | None = None,
|
||||
) -> State:
|
||||
"""Build a State with Thunderbolt identifiers per node so the apply MacThunderboltConnections
|
||||
case can resolve uuid -> (node, iface)."""
|
||||
node_thunderbolt = {
|
||||
nid: NodeThunderboltInfo(
|
||||
interfaces=[ThunderboltIdentifier(rdma_interface=iface, domain_uuid=uuid)]
|
||||
)
|
||||
for nid, uuid, iface in node_ids_and_uuids
|
||||
}
|
||||
return State(
|
||||
node_thunderbolt=node_thunderbolt,
|
||||
node_rdma_ctl=rdma_ctl or {},
|
||||
)
|
||||
|
||||
|
||||
def _has_rdma_edge(topology: Topology, source: NodeId, sink: NodeId) -> bool:
|
||||
return any(
|
||||
isinstance(edge, RDMAConnection)
|
||||
for edge in topology.get_all_connections_between(source, sink)
|
||||
)
|
||||
|
||||
|
||||
def test_mac_thunderbolt_connections_emits_rdma_when_both_endpoints_enabled():
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
state = _make_state_with_thunderbolt_idents(
|
||||
(node_a, "uuid-a", "rdma_en1"),
|
||||
(node_b, "uuid-b", "rdma_en1"),
|
||||
rdma_ctl={
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
},
|
||||
)
|
||||
|
||||
event = NodeGatheredInfo(
|
||||
node_id=node_a,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")]
|
||||
),
|
||||
)
|
||||
|
||||
new_state = apply_node_gathered_info(event, state)
|
||||
|
||||
assert _has_rdma_edge(new_state.topology, node_a, node_b)
|
||||
|
||||
|
||||
def test_mac_thunderbolt_connections_skips_rdma_when_source_rdma_ctl_disabled():
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
state = _make_state_with_thunderbolt_idents(
|
||||
(node_a, "uuid-a", "rdma_en1"),
|
||||
(node_b, "uuid-b", "rdma_en1"),
|
||||
rdma_ctl={
|
||||
node_a: NodeRdmaCtlStatus(enabled=False),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
},
|
||||
)
|
||||
|
||||
event = NodeGatheredInfo(
|
||||
node_id=node_a,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")]
|
||||
),
|
||||
)
|
||||
|
||||
new_state = apply_node_gathered_info(event, state)
|
||||
|
||||
assert not _has_rdma_edge(new_state.topology, node_a, node_b)
|
||||
|
||||
|
||||
def test_mac_thunderbolt_connections_skips_rdma_when_sink_rdma_ctl_disabled():
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
state = _make_state_with_thunderbolt_idents(
|
||||
(node_a, "uuid-a", "rdma_en1"),
|
||||
(node_b, "uuid-b", "rdma_en1"),
|
||||
rdma_ctl={
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=False),
|
||||
},
|
||||
)
|
||||
|
||||
event = NodeGatheredInfo(
|
||||
node_id=node_a,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")]
|
||||
),
|
||||
)
|
||||
|
||||
new_state = apply_node_gathered_info(event, state)
|
||||
|
||||
assert not _has_rdma_edge(new_state.topology, node_a, node_b)
|
||||
|
||||
|
||||
def test_mac_thunderbolt_connections_skips_rdma_when_rdma_ctl_status_missing():
|
||||
"""Missing rdma_ctl status defaults to not-enabled — node is RDMA-incapable."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
state = _make_state_with_thunderbolt_idents(
|
||||
(node_a, "uuid-a", "rdma_en1"),
|
||||
(node_b, "uuid-b", "rdma_en1"),
|
||||
rdma_ctl={
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
# node_b intentionally absent
|
||||
},
|
||||
)
|
||||
|
||||
event = NodeGatheredInfo(
|
||||
node_id=node_a,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")]
|
||||
),
|
||||
)
|
||||
|
||||
new_state = apply_node_gathered_info(event, state)
|
||||
|
||||
assert not _has_rdma_edge(new_state.topology, node_a, node_b)
|
||||
|
||||
|
||||
def test_rdma_ctl_status_disabled_purges_existing_rdma_edges():
|
||||
"""When a node reports rdma_ctl disabled, all RDMA edges touching it must be removed."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
# Start with both nodes RDMA-enabled and existing RDMA edges in the topology.
|
||||
state = _make_state_with_thunderbolt_idents(
|
||||
(node_a, "uuid-a", "rdma_en1"),
|
||||
(node_b, "uuid-b", "rdma_en1"),
|
||||
rdma_ctl={
|
||||
node_a: NodeRdmaCtlStatus(enabled=True),
|
||||
node_b: NodeRdmaCtlStatus(enabled=True),
|
||||
},
|
||||
)
|
||||
state = apply_node_gathered_info(
|
||||
NodeGatheredInfo(
|
||||
node_id=node_a,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-a", sink_uuid="uuid-b")]
|
||||
),
|
||||
),
|
||||
state,
|
||||
)
|
||||
state = apply_node_gathered_info(
|
||||
NodeGatheredInfo(
|
||||
node_id=node_b,
|
||||
when=_now(),
|
||||
info=MacThunderboltConnections(
|
||||
conns=[ThunderboltConnection(source_uuid="uuid-b", sink_uuid="uuid-a")]
|
||||
),
|
||||
),
|
||||
state,
|
||||
)
|
||||
assert _has_rdma_edge(state.topology, node_a, node_b)
|
||||
assert _has_rdma_edge(state.topology, node_b, node_a)
|
||||
|
||||
# Now node_a flips to rdma_ctl disabled — both directions of RDMA edge must drop.
|
||||
state = apply_node_gathered_info(
|
||||
NodeGatheredInfo(
|
||||
node_id=node_a, when=_now(), info=RdmaCtlStatus(enabled=False)
|
||||
),
|
||||
state,
|
||||
)
|
||||
|
||||
assert not _has_rdma_edge(state.topology, node_a, node_b)
|
||||
assert not _has_rdma_edge(state.topology, node_b, node_a)
|
||||
assert state.node_rdma_ctl[node_a].enabled is False
|
||||
|
||||
|
||||
def test_topology_remove_all_rdma_connections_touching_keeps_socket_edges():
|
||||
"""Purging RDMA edges for a disabled node must not affect non-RDMA edges."""
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
|
||||
topology = Topology()
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(
|
||||
Connection(
|
||||
source=node_a,
|
||||
sink=node_b,
|
||||
edge=RDMAConnection(
|
||||
source_rdma_iface="rdma_en1", sink_rdma_iface="rdma_en1"
|
||||
),
|
||||
)
|
||||
)
|
||||
socket_edge = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
|
||||
)
|
||||
topology.add_connection(Connection(source=node_a, sink=node_b, edge=socket_edge))
|
||||
|
||||
topology.remove_all_rdma_connections_touching(node_a)
|
||||
|
||||
assert not _has_rdma_edge(topology, node_a, node_b)
|
||||
# Socket edge survives.
|
||||
assert any(
|
||||
isinstance(edge, SocketConnection)
|
||||
for edge in topology.get_all_connections_between(node_a, node_b)
|
||||
)
|
||||
@@ -169,6 +169,22 @@ class Topology:
|
||||
for conn in new_connections:
|
||||
self.add_connection(conn)
|
||||
|
||||
def remove_all_rdma_connections_touching(self, node_id: NodeId) -> None:
|
||||
"""Remove every RDMA edge incident to ``node_id`` (incoming or outgoing)."""
|
||||
if node_id not in self._vertex_indices:
|
||||
return
|
||||
rx_idx = self._vertex_indices[node_id]
|
||||
rdma_edge_idxs = [
|
||||
edge_idx
|
||||
for edge_idx in (
|
||||
*self._graph.out_edge_indices(rx_idx),
|
||||
*self._graph.in_edge_indices(rx_idx),
|
||||
)
|
||||
if isinstance(self._graph.get_edge_data_by_index(edge_idx), RDMAConnection)
|
||||
]
|
||||
for edge_idx in rdma_edge_idxs:
|
||||
self._graph.remove_edge_from_index(edge_idx)
|
||||
|
||||
def remove_connection(self, conn: Connection) -> None:
|
||||
if (
|
||||
conn.source not in self._vertex_indices
|
||||
|
||||
Reference in New Issue
Block a user