From 2112e273f59dafa87fa0f7ea2d7cf7c4dc4773d4 Mon Sep 17 00:00:00 2001 From: Evan Date: Fri, 19 Dec 2025 13:51:21 +0000 Subject: [PATCH] rework topology --- src/exo/master/api.py | 6 ++- src/exo/master/main.py | 1 + src/exo/shared/topology.py | 70 ++++++++++++++------------------ src/exo/shared/types/topology.py | 40 ++++++------------ 4 files changed, 48 insertions(+), 69 deletions(-) diff --git a/src/exo/master/api.py b/src/exo/master/api.py index 7778ca3d..ae31add2 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -207,6 +207,7 @@ class API: instance_meta=instance_meta, min_nodes=min_nodes, ), + node_profiles=self.state.node_profiles, topology=self.state.topology, current_instances=self.state.instances, ) @@ -262,6 +263,7 @@ class API: instance_meta=instance_meta, min_nodes=min_nodes, ), + node_profiles=self.state.node_profiles, topology=self.state.topology, current_instances=self.state.instances, ) @@ -427,8 +429,8 @@ class API: total_available = Memory() for node in self.state.topology.list_nodes(): - if node.node_profile is not None: - total_available += node.node_profile.memory.ram_available + if node in self.state.node_profiles: + total_available += self.state.node_profiles[node].memory.ram_available return total_available diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 49e734d4..961a51ff 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -158,6 +158,7 @@ class Master: command, self.state.topology, self.state.instances, + self.state.node_profiles, ) transition_events = get_transition_events( self.state.instances, placement diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 0c956c76..94689645 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -5,23 +5,23 @@ import rustworkx as rx from pydantic import BaseModel, ConfigDict from exo.shared.types.common import NodeId -from exo.shared.types.profiling import ConnectionProfile -from exo.shared.types.topology import Connection +from exo.shared.types.topology import Connection, TBConnection class TopologySnapshot(BaseModel): nodes: list[NodeId] - connections: list[Connection] + connections: list[tuple[NodeId, NodeId, Connection | TBConnection]] model_config = ConfigDict(frozen=True, extra="forbid", strict=True) + class Topology: def __init__(self) -> None: - self._graph: rx.PyDiGraph[NodeId, Connection] = rx.PyDiGraph() + self._graph: rx.PyDiGraph[NodeId, Connection | TBConnection] = rx.PyDiGraph() self._node_id_to_rx_id_map: dict[NodeId, int] = dict() self._rx_id_to_node_id_map: dict[int, NodeId] = dict() - self._edge_id_to_rx_id_map: dict[Connection, int] = dict() + self._edge_id_to_rx_id_map: dict[Connection | TBConnection, int] = dict() def to_snapshot(self) -> TopologySnapshot: return TopologySnapshot( @@ -37,8 +37,8 @@ class Topology: with contextlib.suppress(ValueError): topology.add_node(node) - for connection in snapshot.connections: - topology.add_connection(connection) + for source, sink, connection in snapshot.connections: + topology.add_connection(source, sink, connection) return topology @@ -61,7 +61,7 @@ class Topology: for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id]) ] - def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]: + def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection | TBConnection]]: if node_id not in self._node_id_to_rx_id_map: return [] return [ @@ -74,23 +74,25 @@ class Topology: def contains_node(self, node_id: NodeId) -> bool: return node_id in self._node_id_to_rx_id_map - def contains_connection(self, connection: Connection) -> bool: + def contains_connection(self, connection: Connection | TBConnection) -> bool: return connection in self._edge_id_to_rx_id_map def add_connection( self, - connection: Connection, + source: NodeId, + sink: NodeId, + connection: Connection | TBConnection, ) -> None: - if connection.local_node_id not in self._node_id_to_rx_id_map: - self.add_node(connection.local_node_id) - if connection.send_back_node_id not in self._node_id_to_rx_id_map: - self.add_node(connection.send_back_node_id) + if source not in self._node_id_to_rx_id_map: + self.add_node(source) + if sink not in self._node_id_to_rx_id_map: + self.add_node(sink) if connection in self._edge_id_to_rx_id_map: return - src_id = self._node_id_to_rx_id_map[connection.local_node_id] - sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id] + src_id = self._node_id_to_rx_id_map[source] + sink_id = self._node_id_to_rx_id_map[sink] rx_id = self._graph.add_edge(src_id, sink_id, connection) self._edge_id_to_rx_id_map[connection] = rx_id @@ -98,30 +100,17 @@ class Topology: def list_nodes(self) -> Iterable[NodeId]: return (self._graph[i] for i in self._graph.node_indices()) - def list_connections(self) -> Iterable[Connection]: - return (connection for _, _, connection in self._graph.weighted_edge_list()) - - def update_connection_profile(self, connection: Connection) -> None: - rx_idx = self._edge_id_to_rx_id_map[connection] - self._graph.update_edge_by_index(rx_idx, connection) - - def get_connection_profile( - self, connection: Connection - ) -> ConnectionProfile | None: - try: - rx_idx = self._edge_id_to_rx_id_map[connection] - return self._graph.get_edge_data_by_index(rx_idx).connection_profile - except KeyError: - return None + def list_connections(self) -> Iterable[tuple[NodeId, NodeId, Connection | TBConnection]]: + return ((self._rx_id_to_node_id_map[src_id], self._rx_id_to_node_id_map[sink_id], connection) for src_id, sink_id, connection in self._graph.weighted_edge_list()) def remove_node(self, node_id: NodeId) -> None: if node_id not in self._node_id_to_rx_id_map: return - for connection in self.list_connections(): + for src, sink, connection in self.list_connections(): if ( - connection.local_node_id == node_id - or connection.send_back_node_id == node_id + src == node_id + or sink == node_id ): self.remove_connection(connection) @@ -131,7 +120,7 @@ class Topology: del self._node_id_to_rx_id_map[node_id] del self._rx_id_to_node_id_map[rx_idx] - def remove_connection(self, connection: Connection) -> None: + def remove_connection(self, connection: Connection | TBConnection) -> None: if connection not in self._edge_id_to_rx_id_map: return rx_idx = self._edge_id_to_rx_id_map[connection] @@ -158,7 +147,8 @@ class Topology: tb_graph.add_nodes_from(self._graph.nodes()) for u, v, conn in tb_edges: - tb_graph.add_edge(u, v, conn) + if isinstance(conn, Connection): + tb_graph.add_edge(u, v, conn) cycle_idxs = rx.simple_cycles(tb_graph) cycles: list[list[NodeId]] = [] @@ -174,12 +164,12 @@ class Topology: topology = Topology() for rx_idx in rx_idxs: topology.add_node(self._graph[rx_idx]) - for connection in self.list_connections(): + for source, sink, connection in self.list_connections(): if ( - connection.local_node_id in node_idxs - and connection.send_back_node_id in node_idxs + source in node_idxs + and sink in node_idxs ): - topology.add_connection(connection) + topology.add_connection(source, sink, connection) return topology def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool: diff --git a/src/exo/shared/types/topology.py b/src/exo/shared/types/topology.py index e24d7d2e..ba6f5950 100644 --- a/src/exo/shared/types/topology.py +++ b/src/exo/shared/types/topology.py @@ -1,32 +1,18 @@ -from exo.shared.types.common import NodeId +from loguru import logger + from exo.shared.types.multiaddr import Multiaddr -from exo.shared.types.profiling import ConnectionProfile from exo.utils.pydantic_ext import CamelCaseModel - -class Connection(CamelCaseModel): - local_node_id: NodeId - send_back_node_id: NodeId - send_back_multiaddr: Multiaddr - connection_profile: ConnectionProfile | None = None - - def __hash__(self) -> int: - return hash( - ( - self.local_node_id, - self.send_back_node_id, - self.send_back_multiaddr.address, - ) - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Connection): - raise ValueError("Cannot compare Connection with non-Connection") - return ( - self.local_node_id == other.local_node_id - and self.send_back_node_id == other.send_back_node_id - and self.send_back_multiaddr == other.send_back_multiaddr - ) +class TBConnection(CamelCaseModel): + source_rdma_iface: str + sink_rdma_iface: str def is_thunderbolt(self) -> bool: - return str(self.send_back_multiaddr.ipv4_address).startswith("169.254") + logger.warning("duh") + return True + +class Connection(CamelCaseModel): + sink_multiaddr: Multiaddr + + def is_thunderbolt(self) -> bool: + return str(self.sink_multiaddr.ipv4_address).startswith("169.254")