mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
fix topology tests
This commit is contained in:
@@ -1,29 +1,22 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
MemoryUsage,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_node():
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
def create_node(memory: int, node_id: NodeId | None = None) -> tuple[NodeId, NodePerformanceProfile]:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
return (node_id,
|
||||
NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
memory=MemoryUsage.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
@@ -34,34 +27,12 @@ def create_node():
|
||||
),
|
||||
)
|
||||
|
||||
return _create_node
|
||||
|
||||
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
@pytest.fixture
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
ip_counter = 1
|
||||
def create_connection(sink_port: int, ip: int) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"
|
||||
),
|
||||
)
|
||||
|
||||
def _create_connection(
|
||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
||||
) -> Connection:
|
||||
nonlocal port_counter
|
||||
nonlocal ip_counter
|
||||
# assign unique ips
|
||||
ip_counter += 1
|
||||
if send_back_port is None:
|
||||
send_back_port = port_counter
|
||||
port_counter += 1
|
||||
return Connection(
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
return _create_connection
|
||||
|
||||
@@ -3,11 +3,12 @@ import pytest
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
MemoryUsage,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,20 +17,15 @@ def topology() -> Topology:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection() -> Connection:
|
||||
return Connection(
|
||||
local_node_id=NodeId(),
|
||||
send_back_node_id=NodeId(),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
def connection() -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_profile() -> NodePerformanceProfile:
|
||||
memory_profile = MemoryPerformanceProfile.from_bytes(
|
||||
memory_profile = MemoryUsage.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
)
|
||||
system_profile = SystemPerformanceProfile()
|
||||
@@ -43,162 +39,107 @@ def node_profile() -> NodePerformanceProfile:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_profile() -> ConnectionProfile:
|
||||
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
|
||||
|
||||
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
||||
def test_add_node(topology: Topology):
|
||||
# arrange
|
||||
node_id = NodeId()
|
||||
|
||||
# act
|
||||
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
|
||||
topology.add_node(node_id)
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(node_id)
|
||||
assert data == node_profile
|
||||
assert topology.node_is_leaf(node_id)
|
||||
|
||||
|
||||
|
||||
def test_add_connection(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
node_b
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
|
||||
# act
|
||||
data = topology.get_connection_profile(connection)
|
||||
data = list(conn for _, _, conn in topology.list_connections())
|
||||
|
||||
# assert
|
||||
assert data == connection.connection_profile
|
||||
|
||||
|
||||
def test_update_node_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
)
|
||||
|
||||
# act
|
||||
topology.update_node_profile(
|
||||
connection.local_node_id, node_profile=new_node_profile
|
||||
)
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(connection.local_node_id)
|
||||
assert data == new_node_profile
|
||||
|
||||
|
||||
def test_update_connection_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_connection_profile = ConnectionProfile(
|
||||
throughput=2000, latency=2000, jitter=2000
|
||||
)
|
||||
connection = Connection(
|
||||
local_node_id=connection.local_node_id,
|
||||
send_back_node_id=connection.send_back_node_id,
|
||||
send_back_multiaddr=connection.send_back_multiaddr,
|
||||
connection_profile=new_connection_profile,
|
||||
)
|
||||
|
||||
# act
|
||||
topology.update_connection_profile(connection)
|
||||
|
||||
# assert
|
||||
data = topology.get_connection_profile(connection)
|
||||
assert data == new_connection_profile
|
||||
assert data == [connection]
|
||||
|
||||
assert topology.node_is_leaf(node_a)
|
||||
assert topology.node_is_leaf(node_b)
|
||||
|
||||
def test_remove_connection_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
node_b
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
|
||||
# act
|
||||
topology.remove_connection(connection)
|
||||
|
||||
# assert
|
||||
assert topology.get_connection_profile(connection) is None
|
||||
assert list(topology.get_all_connections_between(node_a, node_b)) == []
|
||||
|
||||
|
||||
def test_remove_node_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
node_b
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
|
||||
# act
|
||||
topology.remove_node(connection.local_node_id)
|
||||
topology.remove_node(node_b)
|
||||
|
||||
# assert
|
||||
assert topology.get_node_profile(connection.local_node_id) is None
|
||||
assert list(topology.out_edges(node_a)) == []
|
||||
|
||||
|
||||
def test_list_nodes(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
node_b
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
|
||||
# act
|
||||
nodes = list(topology.list_nodes())
|
||||
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {
|
||||
connection.local_node_id,
|
||||
connection.send_back_node_id,
|
||||
assert all(isinstance(node, NodeId) for node in nodes)
|
||||
assert {node for node in nodes} == {
|
||||
node_a, node_b
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ class Topology:
|
||||
def node_is_leaf(self, node_id: NodeId) -> bool:
|
||||
return (
|
||||
node_id in self._node_id_to_rx_id_map
|
||||
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
|
||||
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) <= 1
|
||||
)
|
||||
|
||||
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
||||
@@ -67,15 +67,15 @@ class Topology:
|
||||
|
||||
def out_edges(
|
||||
self, node_id: NodeId
|
||||
) -> list[tuple[NodeId, SocketConnection | RDMAConnection]]:
|
||||
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return []
|
||||
return [
|
||||
return (
|
||||
(self._rx_id_to_node_id_map[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(
|
||||
self._node_id_to_rx_id_map[node_id]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def contains_node(self, node_id: NodeId) -> bool:
|
||||
return node_id in self._node_id_to_rx_id_map
|
||||
@@ -110,7 +110,10 @@ class Topology:
|
||||
) -> Iterable[SocketConnection | RDMAConnection]:
|
||||
src_id = self._node_id_to_rx_id_map[source]
|
||||
sink_id = self._node_id_to_rx_id_map[sink]
|
||||
return self._graph.get_all_edge_data(src_id, sink_id)
|
||||
try:
|
||||
return self._graph.get_all_edge_data(src_id, sink_id)
|
||||
except rx.NoEdgeBetweenNodes:
|
||||
return []
|
||||
|
||||
def list_nodes(self) -> Iterable[NodeId]:
|
||||
return (self._graph[i] for i in self._graph.node_indices())
|
||||
|
||||
Reference in New Issue
Block a user