rework topology

This commit is contained in:
Evan
2025-12-19 13:51:21 +00:00
parent e54ab7aa2c
commit 2112e273f5
4 changed files with 48 additions and 69 deletions

View File

@@ -207,6 +207,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -262,6 +263,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -427,8 +429,8 @@ class API:
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
if node in self.state.node_profiles:
total_available += self.state.node_profiles[node].memory.ram_available
return total_available

View File

@@ -158,6 +158,7 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement

View File

@@ -5,23 +5,23 @@ import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile
from exo.shared.types.topology import Connection
from exo.shared.types.topology import Connection, TBConnection
class TopologySnapshot(BaseModel):
nodes: list[NodeId]
connections: list[Connection]
connections: list[tuple[NodeId, NodeId, Connection | TBConnection]]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
class Topology:
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeId, Connection] = rx.PyDiGraph()
self._graph: rx.PyDiGraph[NodeId, Connection | TBConnection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
self._edge_id_to_rx_id_map: dict[Connection | TBConnection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
@@ -37,8 +37,8 @@ class Topology:
with contextlib.suppress(ValueError):
topology.add_node(node)
for connection in snapshot.connections:
topology.add_connection(connection)
for source, sink, connection in snapshot.connections:
topology.add_connection(source, sink, connection)
return topology
@@ -61,7 +61,7 @@ class Topology:
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection | TBConnection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return [
@@ -74,23 +74,25 @@ class Topology:
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
def contains_connection(self, connection: Connection | TBConnection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
source: NodeId,
sink: NodeId,
connection: Connection | TBConnection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(connection.local_node_id)
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(connection.send_back_node_id)
if source not in self._node_id_to_rx_id_map:
self.add_node(source)
if sink not in self._node_id_to_rx_id_map:
self.add_node(sink)
if connection in self._edge_id_to_rx_id_map:
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._node_id_to_rx_id_map[source]
sink_id = self._node_id_to_rx_id_map[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
@@ -98,30 +100,17 @@ class Topology:
def list_nodes(self) -> Iterable[NodeId]:
return (self._graph[i] for i in self._graph.node_indices())
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def list_connections(self) -> Iterable[tuple[NodeId, NodeId, Connection | TBConnection]]:
return ((self._rx_id_to_node_id_map[src_id], self._rx_id_to_node_id_map[sink_id], connection) for src_id, sink_id, connection in self._graph.weighted_edge_list())
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._node_id_to_rx_id_map:
return
for connection in self.list_connections():
for src, sink, connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
src == node_id
or sink == node_id
):
self.remove_connection(connection)
@@ -131,7 +120,7 @@ class Topology:
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def remove_connection(self, connection: Connection) -> None:
def remove_connection(self, connection: Connection | TBConnection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
rx_idx = self._edge_id_to_rx_id_map[connection]
@@ -158,7 +147,8 @@ class Topology:
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
tb_graph.add_edge(u, v, conn)
if isinstance(conn, Connection):
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeId]] = []
@@ -174,12 +164,12 @@ class Topology:
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
for source, sink, connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
source in node_idxs
and sink in node_idxs
):
topology.add_connection(connection)
topology.add_connection(source, sink, connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:

View File

@@ -1,32 +1,18 @@
from exo.shared.types.common import NodeId
from loguru import logger
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile
from exo.utils.pydantic_ext import CamelCaseModel
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
class TBConnection(CamelCaseModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
logger.warning("duh")
return True
class Connection(CamelCaseModel):
sink_multiaddr: Multiaddr
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")