mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
rework topology
This commit is contained in:
@@ -207,6 +207,7 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -262,6 +263,7 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -427,8 +429,8 @@ class API:
|
||||
total_available = Memory()
|
||||
|
||||
for node in self.state.topology.list_nodes():
|
||||
if node.node_profile is not None:
|
||||
total_available += node.node_profile.memory.ram_available
|
||||
if node in self.state.node_profiles:
|
||||
total_available += self.state.node_profiles[node].memory.ram_available
|
||||
|
||||
return total_available
|
||||
|
||||
|
||||
@@ -158,6 +158,7 @@ class Master:
|
||||
command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_profiles,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
|
||||
@@ -5,23 +5,23 @@ import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import ConnectionProfile
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.topology import Connection, TBConnection
|
||||
|
||||
|
||||
class TopologySnapshot(BaseModel):
|
||||
nodes: list[NodeId]
|
||||
connections: list[Connection]
|
||||
connections: list[tuple[NodeId, NodeId, Connection | TBConnection]]
|
||||
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
|
||||
|
||||
|
||||
class Topology:
|
||||
def __init__(self) -> None:
|
||||
self._graph: rx.PyDiGraph[NodeId, Connection] = rx.PyDiGraph()
|
||||
self._graph: rx.PyDiGraph[NodeId, Connection | TBConnection] = rx.PyDiGraph()
|
||||
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
|
||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[Connection | TBConnection, int] = dict()
|
||||
|
||||
def to_snapshot(self) -> TopologySnapshot:
|
||||
return TopologySnapshot(
|
||||
@@ -37,8 +37,8 @@ class Topology:
|
||||
with contextlib.suppress(ValueError):
|
||||
topology.add_node(node)
|
||||
|
||||
for connection in snapshot.connections:
|
||||
topology.add_connection(connection)
|
||||
for source, sink, connection in snapshot.connections:
|
||||
topology.add_connection(source, sink, connection)
|
||||
|
||||
return topology
|
||||
|
||||
@@ -61,7 +61,7 @@ class Topology:
|
||||
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
|
||||
]
|
||||
|
||||
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
|
||||
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection | TBConnection]]:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return []
|
||||
return [
|
||||
@@ -74,23 +74,25 @@ class Topology:
|
||||
def contains_node(self, node_id: NodeId) -> bool:
|
||||
return node_id in self._node_id_to_rx_id_map
|
||||
|
||||
def contains_connection(self, connection: Connection) -> bool:
|
||||
def contains_connection(self, connection: Connection | TBConnection) -> bool:
|
||||
return connection in self._edge_id_to_rx_id_map
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
connection: Connection,
|
||||
source: NodeId,
|
||||
sink: NodeId,
|
||||
connection: Connection | TBConnection,
|
||||
) -> None:
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(connection.local_node_id)
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(connection.send_back_node_id)
|
||||
if source not in self._node_id_to_rx_id_map:
|
||||
self.add_node(source)
|
||||
if sink not in self._node_id_to_rx_id_map:
|
||||
self.add_node(sink)
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
src_id = self._node_id_to_rx_id_map[source]
|
||||
sink_id = self._node_id_to_rx_id_map[sink]
|
||||
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
@@ -98,30 +100,17 @@ class Topology:
|
||||
def list_nodes(self) -> Iterable[NodeId]:
|
||||
return (self._graph[i] for i in self._graph.node_indices())
|
||||
|
||||
def list_connections(self) -> Iterable[Connection]:
|
||||
return (connection for _, _, connection in self._graph.weighted_edge_list())
|
||||
|
||||
def update_connection_profile(self, connection: Connection) -> None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.update_edge_by_index(rx_idx, connection)
|
||||
|
||||
def get_connection_profile(
|
||||
self, connection: Connection
|
||||
) -> ConnectionProfile | None:
|
||||
try:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
|
||||
except KeyError:
|
||||
return None
|
||||
def list_connections(self) -> Iterable[tuple[NodeId, NodeId, Connection | TBConnection]]:
|
||||
return ((self._rx_id_to_node_id_map[src_id], self._rx_id_to_node_id_map[sink_id], connection) for src_id, sink_id, connection in self._graph.weighted_edge_list())
|
||||
|
||||
def remove_node(self, node_id: NodeId) -> None:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
for connection in self.list_connections():
|
||||
for src, sink, connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_id
|
||||
or connection.send_back_node_id == node_id
|
||||
src == node_id
|
||||
or sink == node_id
|
||||
):
|
||||
self.remove_connection(connection)
|
||||
|
||||
@@ -131,7 +120,7 @@ class Topology:
|
||||
del self._node_id_to_rx_id_map[node_id]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
|
||||
def remove_connection(self, connection: Connection) -> None:
|
||||
def remove_connection(self, connection: Connection | TBConnection) -> None:
|
||||
if connection not in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
@@ -158,7 +147,8 @@ class Topology:
|
||||
tb_graph.add_nodes_from(self._graph.nodes())
|
||||
|
||||
for u, v, conn in tb_edges:
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
if isinstance(conn, Connection):
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
|
||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||
cycles: list[list[NodeId]] = []
|
||||
@@ -174,12 +164,12 @@ class Topology:
|
||||
topology = Topology()
|
||||
for rx_idx in rx_idxs:
|
||||
topology.add_node(self._graph[rx_idx])
|
||||
for connection in self.list_connections():
|
||||
for source, sink, connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id in node_idxs
|
||||
and connection.send_back_node_id in node_idxs
|
||||
source in node_idxs
|
||||
and sink in node_idxs
|
||||
):
|
||||
topology.add_connection(connection)
|
||||
topology.add_connection(source, sink, connection)
|
||||
return topology
|
||||
|
||||
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
|
||||
|
||||
@@ -1,32 +1,18 @@
|
||||
from exo.shared.types.common import NodeId
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import ConnectionProfile
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class Connection(CamelCaseModel):
|
||||
local_node_id: NodeId
|
||||
send_back_node_id: NodeId
|
||||
send_back_multiaddr: Multiaddr
|
||||
connection_profile: ConnectionProfile | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.send_back_multiaddr.address,
|
||||
)
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Connection):
|
||||
raise ValueError("Cannot compare Connection with non-Connection")
|
||||
return (
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
||||
)
|
||||
class TBConnection(CamelCaseModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
|
||||
logger.warning("duh")
|
||||
return True
|
||||
|
||||
class Connection(CamelCaseModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
Reference in New Issue
Block a user