diff --git a/src/exo/master/tests/conftest.py b/src/exo/master/tests/conftest.py index 8441cef8..b12fea76 100644 --- a/src/exo/master/tests/conftest.py +++ b/src/exo/master/tests/conftest.py @@ -1,29 +1,22 @@ -from typing import Callable - -import pytest - from exo.shared.types.common import NodeId from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import ( - MemoryPerformanceProfile, + MemoryUsage, NodePerformanceProfile, SystemPerformanceProfile, ) -from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo +from exo.shared.types.topology import SocketConnection -@pytest.fixture -def create_node(): - def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo: +def create_node(memory: int, node_id: NodeId | None = None) -> tuple[NodeId, NodePerformanceProfile]: if node_id is None: node_id = NodeId() - return NodeInfo( - node_id=node_id, - node_profile=NodePerformanceProfile( + return (node_id, + NodePerformanceProfile( model_id="test", chip_id="test", friendly_name="test", - memory=MemoryPerformanceProfile.from_bytes( + memory=MemoryUsage.from_bytes( ram_total=1000, ram_available=memory, swap_total=1000, @@ -34,34 +27,12 @@ def create_node(): ), ) - return _create_node - # TODO: this is a hack to get the port for the send_back_multiaddr -@pytest.fixture -def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]: - port_counter = 1235 - ip_counter = 1 +def create_connection(sink_port: int, ip: int) -> SocketConnection: + return SocketConnection( + sink_multiaddr=Multiaddr( + address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}" + ), + ) - def _create_connection( - source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None - ) -> Connection: - nonlocal port_counter - nonlocal ip_counter - # assign unique ips - ip_counter += 1 - if send_back_port is None: - send_back_port = port_counter - port_counter += 1 - return Connection( - local_node_id=source_node_id, - send_back_node_id=sink_node_id, - send_back_multiaddr=Multiaddr( - address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}" - ), - connection_profile=ConnectionProfile( - throughput=1000, latency=1000, jitter=1000 - ), - ) - - return _create_connection diff --git a/src/exo/master/tests/test_topology.py b/src/exo/master/tests/test_topology.py index d6afb339..365be18e 100644 --- a/src/exo/master/tests/test_topology.py +++ b/src/exo/master/tests/test_topology.py @@ -3,11 +3,12 @@ import pytest from exo.shared.topology import Topology from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import ( - MemoryPerformanceProfile, + MemoryUsage, NodePerformanceProfile, SystemPerformanceProfile, ) -from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo +from exo.shared.types.topology import SocketConnection +from exo.shared.types.common import NodeId @pytest.fixture @@ -16,20 +17,15 @@ def topology() -> Topology: @pytest.fixture -def connection() -> Connection: - return Connection( - local_node_id=NodeId(), - send_back_node_id=NodeId(), - send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"), - connection_profile=ConnectionProfile( - throughput=1000, latency=1000, jitter=1000 - ), +def connection() -> SocketConnection: + return SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"), ) @pytest.fixture def node_profile() -> NodePerformanceProfile: - memory_profile = MemoryPerformanceProfile.from_bytes( + memory_profile = MemoryUsage.from_bytes( ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 ) system_profile = SystemPerformanceProfile() @@ -43,162 +39,107 @@ def node_profile() -> NodePerformanceProfile: ) -@pytest.fixture -def connection_profile() -> ConnectionProfile: - return ConnectionProfile(throughput=1000, latency=1000, jitter=1000) - - -def test_add_node(topology: Topology, node_profile: NodePerformanceProfile): +def test_add_node(topology: Topology): # arrange node_id = NodeId() # act - topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile)) + topology.add_node(node_id) # assert - data = topology.get_node_profile(node_id) - assert data == node_profile + assert topology.node_is_leaf(node_id) + def test_add_connection( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection + topology: Topology, connection: SocketConnection ): # arrange + node_a = NodeId() + node_b = NodeId() + + topology.add_node(node_a) topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + node_b ) - topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) - ) - topology.add_connection(connection) + topology.add_connection(node_a, node_b, connection) # act - data = topology.get_connection_profile(connection) + data = list(conn for _, _, conn in topology.list_connections()) # assert - assert data == connection.connection_profile - - -def test_update_node_profile( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection -): - # arrange - topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) - ) - topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) - ) - topology.add_connection(connection) - - new_node_profile = NodePerformanceProfile( - model_id="test", - chip_id="test", - friendly_name="test", - memory=MemoryPerformanceProfile.from_bytes( - ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 - ), - network_interfaces=[], - system=SystemPerformanceProfile(), - ) - - # act - topology.update_node_profile( - connection.local_node_id, node_profile=new_node_profile - ) - - # assert - data = topology.get_node_profile(connection.local_node_id) - assert data == new_node_profile - - -def test_update_connection_profile( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection -): - # arrange - topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) - ) - topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) - ) - topology.add_connection(connection) - - new_connection_profile = ConnectionProfile( - throughput=2000, latency=2000, jitter=2000 - ) - connection = Connection( - local_node_id=connection.local_node_id, - send_back_node_id=connection.send_back_node_id, - send_back_multiaddr=connection.send_back_multiaddr, - connection_profile=new_connection_profile, - ) - - # act - topology.update_connection_profile(connection) - - # assert - data = topology.get_connection_profile(connection) - assert data == new_connection_profile + assert data == [connection] + assert topology.node_is_leaf(node_a) + assert topology.node_is_leaf(node_b) def test_remove_connection_still_connected( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection + topology: Topology, connection: SocketConnection ): # arrange + node_a = NodeId() + node_b = NodeId() + topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + node_a ) topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) + node_b ) - topology.add_connection(connection) + topology.add_connection(node_a, node_b, connection) # act topology.remove_connection(connection) # assert - assert topology.get_connection_profile(connection) is None + assert list(topology.get_all_connections_between(node_a, node_b)) == [] def test_remove_node_still_connected( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection + topology: Topology, connection: SocketConnection ): # arrange + node_a = NodeId() + node_b = NodeId() + topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + node_a ) topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) + node_b ) - topology.add_connection(connection) + topology.add_connection(node_a, node_b, connection) + assert list(topology.out_edges(node_a)) == [(node_b, connection)] # act - topology.remove_node(connection.local_node_id) + topology.remove_node(node_b) # assert - assert topology.get_node_profile(connection.local_node_id) is None + assert list(topology.out_edges(node_a)) == [] def test_list_nodes( - topology: Topology, node_profile: NodePerformanceProfile, connection: Connection + topology: Topology, connection: SocketConnection ): # arrange + node_a = NodeId() + node_b = NodeId() + topology.add_node( - NodeInfo(node_id=connection.local_node_id, node_profile=node_profile) + node_a ) topology.add_node( - NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile) + node_b ) - topology.add_connection(connection) + topology.add_connection(node_a, node_b, connection) + assert list(topology.out_edges(node_a)) == [(node_b, connection)] # act nodes = list(topology.list_nodes()) # assert assert len(nodes) == 2 - assert all(isinstance(node, NodeInfo) for node in nodes) - assert {node.node_id for node in nodes} == { - connection.local_node_id, - connection.send_back_node_id, + assert all(isinstance(node, NodeId) for node in nodes) + assert {node for node in nodes} == { + node_a, node_b } diff --git a/src/exo/shared/topology.py b/src/exo/shared/topology.py index 6deef4d6..c3471ca1 100644 --- a/src/exo/shared/topology.py +++ b/src/exo/shared/topology.py @@ -56,7 +56,7 @@ class Topology: def node_is_leaf(self, node_id: NodeId) -> bool: return ( node_id in self._node_id_to_rx_id_map - and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1 + and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) <= 1 ) def neighbours(self, node_id: NodeId) -> list[NodeId]: @@ -67,15 +67,15 @@ class Topology: def out_edges( self, node_id: NodeId - ) -> list[tuple[NodeId, SocketConnection | RDMAConnection]]: + ) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]: if node_id not in self._node_id_to_rx_id_map: return [] - return [ + return ( (self._rx_id_to_node_id_map[nid], conn) for _, nid, conn in self._graph.out_edges( self._node_id_to_rx_id_map[node_id] ) - ] + ) def contains_node(self, node_id: NodeId) -> bool: return node_id in self._node_id_to_rx_id_map @@ -110,7 +110,10 @@ class Topology: ) -> Iterable[SocketConnection | RDMAConnection]: src_id = self._node_id_to_rx_id_map[source] sink_id = self._node_id_to_rx_id_map[sink] - return self._graph.get_all_edge_data(src_id, sink_id) + try: + return self._graph.get_all_edge_data(src_id, sink_id) + except rx.NoEdgeBetweenNodes: + return [] def list_nodes(self) -> Iterable[NodeId]: return (self._graph[i] for i in self._graph.node_indices())