fix topology tests

This commit is contained in:
Evan
2025-12-19 19:21:48 +00:00
parent a30c13c4d9
commit 83037d289e
3 changed files with 71 additions and 156 deletions

View File

@@ -1,29 +1,22 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
from exo.shared.types.topology import SocketConnection
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
def create_node(memory: int, node_id: NodeId | None = None) -> tuple[NodeId, NodePerformanceProfile]:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
return (node_id,
NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
@@ -34,34 +27,12 @@ def create_node():
),
)
return _create_node
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def create_connection(sink_port: int, ip: int) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"
),
)
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection

View File

@@ -3,11 +3,12 @@ import pytest
from exo.shared.topology import Topology
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.common import NodeId
@pytest.fixture
@@ -16,20 +17,15 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -43,162 +39,107 @@ def node_profile() -> NodePerformanceProfile:
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
topology.add_node(node_id)
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
assert topology.node_is_leaf(node_id)
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
node_b
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
topology.add_connection(node_a, node_b, connection)
# act
data = topology.get_connection_profile(connection)
data = list(conn for _, _, conn in topology.list_connections())
# assert
assert data == connection.connection_profile
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
assert data == [connection]
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
node_a
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
node_b
)
topology.add_connection(connection)
topology.add_connection(node_a, node_b, connection)
# act
topology.remove_connection(connection)
# assert
assert topology.get_connection_profile(connection) is None
assert list(topology.get_all_connections_between(node_a, node_b)) == []
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
node_a
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
node_b
)
topology.add_connection(connection)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
topology.remove_node(connection.local_node_id)
topology.remove_node(node_b)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
assert list(topology.out_edges(node_a)) == []
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
node_a
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
node_b
)
topology.add_connection(connection)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {
node_a, node_b
}

View File

@@ -56,7 +56,7 @@ class Topology:
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) <= 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
@@ -67,15 +67,15 @@ class Topology:
def out_edges(
self, node_id: NodeId
) -> list[tuple[NodeId, SocketConnection | RDMAConnection]]:
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return [
return (
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
@@ -110,7 +110,10 @@ class Topology:
) -> Iterable[SocketConnection | RDMAConnection]:
src_id = self._node_id_to_rx_id_map[source]
sink_id = self._node_id_to_rx_id_map[sink]
return self._graph.get_all_edge_data(src_id, sink_id)
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
def list_nodes(self) -> Iterable[NodeId]:
return (self._graph[i] for i in self._graph.node_indices())