mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
fix all master tests except rdma placement
This commit is contained in:
@@ -57,15 +57,15 @@ def place_instance(
|
||||
all_nodes = list(topology.list_nodes())
|
||||
|
||||
logger.info("finding cycles:")
|
||||
cycles = topology.get_cycles()
|
||||
singleton_cycles = [[node] for node in all_nodes]
|
||||
cycles = topology.get_cycles() + [[node] for node in all_nodes]
|
||||
logger.info(cycles)
|
||||
candidate_cycles = list(
|
||||
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
|
||||
filter(lambda it: len(it) >= command.min_nodes, cycles)
|
||||
)
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, node_profiles, command.model_meta.storage_size
|
||||
)
|
||||
if not cycles_with_sufficient_memory:
|
||||
if len(cycles_with_sufficient_memory) == 0:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
@@ -1,38 +1,36 @@
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
|
||||
|
||||
def create_node(memory: int, node_id: NodeId | None = None) -> tuple[NodeId, NodePerformanceProfile]:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return (node_id,
|
||||
NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryUsage.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
def create_node_profile(memory: int) -> NodePerformanceProfile:
|
||||
return NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryUsage.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
)
|
||||
|
||||
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
def create_connection(sink_port: int, ip: int) -> SocketConnection:
|
||||
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"
|
||||
),
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||
)
|
||||
|
||||
|
||||
def create_rdma_connection(iface: int) -> RDMAConnection:
|
||||
return RDMAConnection(
|
||||
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
|
||||
)
|
||||
|
||||
@@ -19,15 +19,13 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodePerformanceMeasured,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
MemoryUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
@@ -83,21 +81,14 @@ async def test_master():
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodePerformanceMeasured(
|
||||
NodeGatheredInfo(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
info=MemoryUsage(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -161,7 +152,7 @@ async def test_master():
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[0].event, NodeGatheredInfo)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(
|
||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
|
||||
@@ -7,14 +5,18 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
create_connection,
|
||||
create_node_profile,
|
||||
create_rdma_connection,
|
||||
)
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -26,11 +28,6 @@ from exo.shared.types.worker.runners import ShardAssignments
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
return Topology()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def instance() -> Instance:
|
||||
return MlxRingInstance(
|
||||
@@ -74,30 +71,33 @@ def test_get_instance_placements_create_instance(
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
expected_layers: tuple[int, int, int],
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
model_meta.n_layers = total_layers
|
||||
model_meta.storage_size.in_bytes = sum(
|
||||
available_memory
|
||||
) # make it exactly fit across all nodes
|
||||
topology = Topology()
|
||||
|
||||
cic = place_instance_command(model_meta)
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
topology.add_node(create_node(available_memory[0], node_id_a))
|
||||
topology.add_node(create_node(available_memory[1], node_id_b))
|
||||
topology.add_node(create_node(available_memory[2], node_id_c))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
profiles = {
|
||||
node_id_a: create_node_profile(available_memory[0]),
|
||||
node_id_b: create_node_profile(available_memory[1]),
|
||||
node_id_c: create_node_profile(available_memory[2]),
|
||||
}
|
||||
topology.add_node(node_id_a)
|
||||
topology.add_node(node_id_b)
|
||||
topology.add_node(node_id_c)
|
||||
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_c, create_connection(2))
|
||||
topology.add_connection(node_id_c, node_id_a, create_connection(3))
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {})
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
@@ -123,12 +123,11 @@ def test_get_instance_placements_create_instance(
|
||||
assert shards_sorted[-1].end_layer == total_layers
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_exact_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||
cic = place_instance_command(
|
||||
ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
@@ -137,7 +136,7 @@ def test_get_instance_placements_one_node_exact_fit(
|
||||
n_layers=10,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -148,12 +147,11 @@ def test_get_instance_placements_one_node_exact_fit(
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1001 * 1024, node_id))
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1001 * 1024)}
|
||||
cic = place_instance_command(
|
||||
ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
@@ -162,7 +160,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
n_layers=10,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -173,12 +171,11 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_not_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
def test_get_instance_placements_one_node_not_fit() -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||
cic = place_instance_command(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
@@ -189,7 +186,7 @@ def test_get_instance_placements_one_node_not_fit(
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
|
||||
place_instance(cic, topology, {})
|
||||
place_instance(cic, topology, {}, profiles)
|
||||
|
||||
|
||||
def test_get_transition_events_no_change(instance: Instance):
|
||||
@@ -235,190 +232,92 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
|
||||
|
||||
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
|
||||
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
|
||||
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
|
||||
# cycle having LESS total memory than D-E-F. The algorithm should still choose
|
||||
# the cycle that contains a leaf node.
|
||||
# arrange
|
||||
topology = Topology()
|
||||
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
model_meta.n_layers = 12
|
||||
model_meta.storage_size = Memory.from_bytes(1000)
|
||||
|
||||
# Create node ids
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
node_id_d = NodeId()
|
||||
node_id_e = NodeId()
|
||||
node_id_f = NodeId()
|
||||
|
||||
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
|
||||
node_id_x = NodeId()
|
||||
node_id_y = NodeId()
|
||||
node_id_z = NodeId()
|
||||
|
||||
# A-B-C cycle total memory = 1600 (< D-E-F total)
|
||||
topology.add_node(create_node(400, node_id_a))
|
||||
topology.add_node(create_node(400, node_id_b))
|
||||
topology.add_node(create_node(800, node_id_c))
|
||||
profiles = {
|
||||
node_id_a: create_node_profile(500),
|
||||
node_id_b: create_node_profile(600),
|
||||
node_id_c: create_node_profile(600),
|
||||
node_id_d: create_node_profile(500),
|
||||
}
|
||||
|
||||
# D-E-F cycle total memory = 1800 (> A-B-C total)
|
||||
topology.add_node(create_node(600, node_id_d))
|
||||
topology.add_node(create_node(600, node_id_e))
|
||||
topology.add_node(create_node(600, node_id_f))
|
||||
topology.add_node(node_id_a)
|
||||
topology.add_node(node_id_b)
|
||||
topology.add_node(node_id_c)
|
||||
topology.add_node(node_id_d)
|
||||
|
||||
# Extra nodes with tiny memory so they can't form singleton placements
|
||||
topology.add_node(create_node(10, node_id_x))
|
||||
topology.add_node(create_node(10, node_id_y))
|
||||
topology.add_node(create_node(10, node_id_z))
|
||||
# Daisy chain topology
|
||||
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_a, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_c, create_connection(1))
|
||||
topology.add_connection(node_id_c, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_c, node_id_d, create_connection(1))
|
||||
topology.add_connection(node_id_d, node_id_c, create_connection(1))
|
||||
|
||||
# Build directed cycles
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
logger.info(list(topology.list_connections()))
|
||||
|
||||
topology.add_connection(create_connection(node_id_d, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_f))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_d))
|
||||
|
||||
# Add extra outgoing edges from D/E/F so none of them are leaves
|
||||
topology.add_connection(create_connection(node_id_d, node_id_x))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_y))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_z))
|
||||
|
||||
cic = place_instance_command(
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
# Act
|
||||
placements = place_instance(cic, topology, {})
|
||||
# act
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
|
||||
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
|
||||
# D-E-F has more total memory.
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance = placements[instance_id]
|
||||
instance = list(placements.values())[0]
|
||||
|
||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
|
||||
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
|
||||
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set((node_id_c, node_id_d))
|
||||
|
||||
|
||||
def test_tensor_rdma_backend_connectivity_matrix(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
topology = Topology()
|
||||
model_meta.n_layers = 12
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
node_c = NodeId()
|
||||
|
||||
node_a = create_node(500, node_id_a)
|
||||
node_b = create_node(500, node_id_b)
|
||||
node_c = create_node(500, node_id_c)
|
||||
profiles = {
|
||||
node_a: create_node_profile(500),
|
||||
node_b: create_node_profile(500),
|
||||
node_c: create_node_profile(500),
|
||||
}
|
||||
|
||||
ethernet_interface = NetworkInterfaceInfo(
|
||||
name="en0",
|
||||
ip_address="192.168.1.100",
|
||||
)
|
||||
|
||||
assert node_a.node_profile is not None
|
||||
assert node_b.node_profile is not None
|
||||
assert node_c.node_profile is not None
|
||||
|
||||
conn_a_b = create_connection(node_id_a, node_id_b)
|
||||
conn_b_c = create_connection(node_id_b, node_id_c)
|
||||
conn_c_a = create_connection(node_id_c, node_id_a)
|
||||
|
||||
conn_b_a = create_connection(node_id_b, node_id_a)
|
||||
conn_c_b = create_connection(node_id_c, node_id_b)
|
||||
conn_a_c = create_connection(node_id_a, node_id_c)
|
||||
|
||||
assert conn_a_b.send_back_multiaddr is not None
|
||||
assert conn_b_c.send_back_multiaddr is not None
|
||||
assert conn_c_a.send_back_multiaddr is not None
|
||||
|
||||
assert conn_b_a.send_back_multiaddr is not None
|
||||
assert conn_c_b.send_back_multiaddr is not None
|
||||
assert conn_a_c.send_back_multiaddr is not None
|
||||
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_a.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_a.node_profile.system,
|
||||
)
|
||||
node_b.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_b.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_b.node_profile.system,
|
||||
)
|
||||
node_c.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_c.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_c.node_profile.system,
|
||||
)
|
||||
profiles[node_a].network_interfaces = [ethernet_interface]
|
||||
profiles[node_b].network_interfaces = [ethernet_interface]
|
||||
profiles[node_c].network_interfaces = [ethernet_interface]
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_a_c)
|
||||
topology.add_connection(node_a, node_b, create_rdma_connection(3))
|
||||
topology.add_connection(node_b, node_c, create_rdma_connection(4))
|
||||
topology.add_connection(node_c, node_a, create_rdma_connection(5))
|
||||
topology.add_connection(node_b, node_a, create_rdma_connection(3))
|
||||
topology.add_connection(node_c, node_b, create_rdma_connection(4))
|
||||
topology.add_connection(node_a, node_c, create_rdma_connection(5))
|
||||
|
||||
cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
@@ -428,7 +327,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
placements = place_instance(cic, topology, {})
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -448,15 +347,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
||||
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
||||
|
||||
idx_a = node_to_idx[node_id_a]
|
||||
idx_b = node_to_idx[node_id_b]
|
||||
idx_c = node_to_idx[node_id_c]
|
||||
idx_a = node_to_idx[node_a]
|
||||
idx_b = node_to_idx[node_b]
|
||||
idx_c = node_to_idx[node_c]
|
||||
|
||||
logger.info(matrix)
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.ibv_coordinators) == 3
|
||||
|
||||
@@ -1,56 +1,48 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
NodeWithProfile,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node_profile
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
topology = Topology()
|
||||
return topology
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_filter_cycles_by_memory():
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
connection1 = create_connection(1)
|
||||
connection2 = create_connection(2)
|
||||
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
topology.add_connection(node1_id, node2_id, connection1)
|
||||
topology.add_connection(node2_id, node1_id, connection2)
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
assert len(cycles) == 1
|
||||
assert len(cycles[0]) == 2
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
cycles, node_profiles, Memory.from_bytes(1)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -58,64 +50,65 @@ def test_filter_cycles_by_memory(
|
||||
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
|
||||
|
||||
|
||||
def test_filter_cycles_by_insufficient_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_filter_cycles_by_insufficient_memory():
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
connection1 = create_connection(1)
|
||||
connection2 = create_connection(2)
|
||||
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
topology.add_connection(node1_id, node2_id, connection1)
|
||||
topology.add_connection(node2_id, node1_id, connection2)
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
topology.get_cycles(), Memory.from_kb(2001)
|
||||
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 0
|
||||
|
||||
|
||||
def test_filter_multiple_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_filter_multiple_cycles_by_memory():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
node_a = create_node_profile(500 * 1024)
|
||||
node_b = create_node_profile(500 * 1024)
|
||||
node_c = create_node_profile(1000 * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_b_id))
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
cycles, node_profiles, Memory.from_kb(1500)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -127,31 +120,38 @@ def test_filter_multiple_cycles_by_memory(
|
||||
}
|
||||
|
||||
|
||||
def test_get_smallest_cycles(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_get_smallest_cycles():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
node_a = create_node_profile(500 * 1024)
|
||||
node_b = create_node_profile(500 * 1024)
|
||||
node_c = create_node_profile(1000 * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
|
||||
cycles = [
|
||||
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||
for cycle in topology.get_cycles()
|
||||
]
|
||||
|
||||
# act
|
||||
smallest_cycles = get_smallest_cycles(topology.get_cycles())
|
||||
smallest_cycles = get_smallest_cycles(cycles)
|
||||
|
||||
# assert
|
||||
assert len(smallest_cycles) == 1
|
||||
@@ -168,9 +168,6 @@ def test_get_smallest_cycles(
|
||||
],
|
||||
)
|
||||
def test_get_shard_assignments(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
expected_layers: tuple[int, int, int],
|
||||
@@ -179,19 +176,25 @@ def test_get_shard_assignments(
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node(available_memory[0] * 1024, node_a_id)
|
||||
node_b = create_node(available_memory[1] * 1024, node_b_id)
|
||||
node_c = create_node(available_memory[2] * 1024, node_c_id)
|
||||
node_a = create_node_profile(available_memory[0] * 1024)
|
||||
node_b = create_node_profile(available_memory[1] * 1024)
|
||||
node_c = create_node_profile(available_memory[2] * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_c_id, create_connection(2))
|
||||
topology.add_connection(node_c_id, node_a_id, create_connection(3))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(4))
|
||||
|
||||
model_meta = ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
@@ -199,7 +202,11 @@ def test_get_shard_assignments(
|
||||
n_layers=total_layers,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
|
||||
cycles = [
|
||||
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||
for cycle in topology.get_cycles()
|
||||
]
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
# act
|
||||
@@ -228,28 +235,21 @@ def test_get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def test_get_hosts_from_subgraph(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
def test_get_hosts_from_subgraph():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node(500, node_a_id)
|
||||
node_b = create_node(500, node_b_id)
|
||||
node_c = create_node(1000, node_c_id)
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
|
||||
# act
|
||||
hosts = get_hosts_from_subgraph(topology)
|
||||
@@ -257,104 +257,45 @@ def test_get_hosts_from_subgraph(
|
||||
# assert
|
||||
assert len(hosts) == 3
|
||||
expected_hosts = [
|
||||
Host(ip=("169.254.0.2"), port=5001),
|
||||
Host(ip=("169.254.0.3"), port=5002),
|
||||
Host(ip=("169.254.0.4"), port=5003),
|
||||
Host(ip=("169.254.0.2"), port=1234),
|
||||
Host(ip=("169.254.0.3"), port=1234),
|
||||
Host(ip=("169.254.0.4"), port=1234),
|
||||
]
|
||||
for expected_host in expected_hosts:
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_ibv_coordinators(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
def test_get_mlx_ibv_coordinators():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
|
||||
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
|
||||
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
|
||||
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
|
||||
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
|
||||
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
|
||||
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
|
||||
# Update node profiles with network interfaces before adding to topology
|
||||
assert node_a.node_profile is not None
|
||||
assert node_b.node_profile is not None
|
||||
assert node_c.node_profile is not None
|
||||
conn_a_b = create_connection(1)
|
||||
conn_b_a = create_connection(2)
|
||||
conn_b_c = create_connection(3)
|
||||
conn_c_b = create_connection(4)
|
||||
conn_c_a = create_connection(5)
|
||||
conn_a_c = create_connection(6)
|
||||
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_a.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_a.node_profile.system,
|
||||
)
|
||||
node_b.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_b.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_b.node_profile.system,
|
||||
)
|
||||
node_c.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_c.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_c.node_profile.system,
|
||||
)
|
||||
topology.add_connection(node_a_id, node_b_id, conn_a_b)
|
||||
topology.add_connection(node_b_id, node_a_id, conn_b_a)
|
||||
topology.add_connection(node_b_id, node_c_id, conn_b_c)
|
||||
topology.add_connection(node_c_id, node_b_id, conn_c_b)
|
||||
topology.add_connection(node_c_id, node_a_id, conn_c_a)
|
||||
topology.add_connection(node_a_id, node_c_id, conn_a_c)
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_a_c)
|
||||
|
||||
cycle = [node_a, node_b, node_c]
|
||||
cycle = [node_a_id, node_b_id, node_c_id]
|
||||
|
||||
# act
|
||||
coordinators = get_mlx_ibv_coordinators(
|
||||
@@ -387,11 +328,11 @@ def test_get_mlx_ibv_coordinators(
|
||||
|
||||
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
|
||||
# node_b uses the IP from conn_b_a (node_b -> node_a)
|
||||
assert coordinators[node_b_id] == (
|
||||
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_b should use the IP from conn_b_a"
|
||||
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
|
||||
"node_b should use the IP from conn_b_a"
|
||||
)
|
||||
|
||||
# node_c uses the IP from conn_c_a (node_c -> node_a)
|
||||
assert coordinators[node_c_id] == (
|
||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_c should use the IP from conn_c_a"
|
||||
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
|
||||
"node_c should use the IP from conn_c_a"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
@@ -8,7 +9,6 @@ from exo.shared.types.profiling import (
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -50,18 +50,13 @@ def test_add_node(topology: Topology):
|
||||
assert topology.node_is_leaf(node_id)
|
||||
|
||||
|
||||
|
||||
def test_add_connection(
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
def test_add_connection(topology: Topology, connection: SocketConnection):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(
|
||||
node_b
|
||||
)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
|
||||
# act
|
||||
@@ -73,6 +68,7 @@ def test_add_connection(
|
||||
assert topology.node_is_leaf(node_a)
|
||||
assert topology.node_is_leaf(node_b)
|
||||
|
||||
|
||||
def test_remove_connection_still_connected(
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
@@ -80,34 +76,24 @@ def test_remove_connection_still_connected(
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
node_b
|
||||
)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
|
||||
# act
|
||||
topology.remove_connection(connection)
|
||||
topology.remove_connection(node_a, node_b, connection)
|
||||
|
||||
# assert
|
||||
assert list(topology.get_all_connections_between(node_a, node_b)) == []
|
||||
|
||||
|
||||
def test_remove_node_still_connected(
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
node_b
|
||||
)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
|
||||
@@ -118,19 +104,13 @@ def test_remove_node_still_connected(
|
||||
assert list(topology.out_edges(node_a)) == []
|
||||
|
||||
|
||||
def test_list_nodes(
|
||||
topology: Topology, connection: SocketConnection
|
||||
):
|
||||
def test_list_nodes(topology: Topology, connection: SocketConnection):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(
|
||||
node_a
|
||||
)
|
||||
topology.add_node(
|
||||
node_b
|
||||
)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
|
||||
@@ -140,6 +120,4 @@ def test_list_nodes(
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, NodeId) for node in nodes)
|
||||
assert {node for node in nodes} == {
|
||||
node_a, node_b
|
||||
}
|
||||
assert {node for node in nodes} == {node_a, node_b}
|
||||
|
||||
@@ -212,6 +212,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
|
||||
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.add_node(event.node_id)
|
||||
info = event.info
|
||||
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
|
||||
# TODO: should be broken up into individual events instead of this monster
|
||||
@@ -273,7 +274,11 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
|
||||
new_profiles = {**state.node_profiles, event.node_id: profile}
|
||||
return state.model_copy(
|
||||
update={"node_profiles": new_profiles, "last_seen": last_seen, "topology": topology}
|
||||
update={
|
||||
"node_profiles": new_profiles,
|
||||
"last_seen": last_seen,
|
||||
"topology": topology,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -285,8 +290,6 @@ def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> Sta
|
||||
|
||||
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
if not topology.contains_connection(event.edge):
|
||||
return state
|
||||
topology.remove_connection(event.edge)
|
||||
topology.remove_connection(event.sink, event.source, event.edge)
|
||||
# TODO: Clean up removing the reverse connection
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
@@ -22,10 +22,6 @@ class Topology:
|
||||
rx.PyDiGraph()
|
||||
)
|
||||
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
|
||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[SocketConnection | RDMAConnection, int] = (
|
||||
dict()
|
||||
)
|
||||
|
||||
def to_snapshot(self) -> TopologySnapshot:
|
||||
return TopologySnapshot(
|
||||
@@ -51,7 +47,7 @@ class Topology:
|
||||
return
|
||||
rx_id = self._graph.add_node(node)
|
||||
self._node_id_to_rx_id_map[node] = rx_id
|
||||
self._rx_id_to_node_id_map[rx_id] = node
|
||||
self._graph[rx_id] = node
|
||||
|
||||
def node_is_leaf(self, node_id: NodeId) -> bool:
|
||||
return (
|
||||
@@ -61,7 +57,7 @@ class Topology:
|
||||
|
||||
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
||||
return [
|
||||
self._rx_id_to_node_id_map[rx_id]
|
||||
self._graph[rx_id]
|
||||
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
|
||||
]
|
||||
|
||||
@@ -71,7 +67,7 @@ class Topology:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return []
|
||||
return (
|
||||
(self._rx_id_to_node_id_map[nid], conn)
|
||||
(self._graph[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(
|
||||
self._node_id_to_rx_id_map[node_id]
|
||||
)
|
||||
@@ -80,11 +76,6 @@ class Topology:
|
||||
def contains_node(self, node_id: NodeId) -> bool:
|
||||
return node_id in self._node_id_to_rx_id_map
|
||||
|
||||
def contains_connection(
|
||||
self, connection: SocketConnection | RDMAConnection
|
||||
) -> bool:
|
||||
return connection in self._edge_id_to_rx_id_map
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
source: NodeId,
|
||||
@@ -96,14 +87,10 @@ class Topology:
|
||||
if sink not in self._node_id_to_rx_id_map:
|
||||
self.add_node(sink)
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
src_id = self._node_id_to_rx_id_map[source]
|
||||
sink_id = self._node_id_to_rx_id_map[sink]
|
||||
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
self._graph.add_edge(src_id, sink_id, connection)
|
||||
|
||||
def get_all_connections_between(
|
||||
self, source: NodeId, sink: NodeId
|
||||
@@ -123,8 +110,8 @@ class Topology:
|
||||
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
|
||||
return (
|
||||
(
|
||||
self._rx_id_to_node_id_map[src_id],
|
||||
self._rx_id_to_node_id_map[sink_id],
|
||||
self._graph[src_id],
|
||||
self._graph[sink_id],
|
||||
connection,
|
||||
)
|
||||
for src_id, sink_id, connection in self._graph.weighted_edge_list()
|
||||
@@ -134,31 +121,24 @@ class Topology:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
for src, sink, connection in self.list_connections():
|
||||
if src == node_id or sink == node_id:
|
||||
self.remove_connection(connection)
|
||||
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph.remove_node(rx_idx)
|
||||
|
||||
del self._node_id_to_rx_id_map[node_id]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
|
||||
def replace_all_out_tb_connections(
|
||||
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
|
||||
) -> None:
|
||||
for conn in self.out_edges(source):
|
||||
if isinstance(conn, RDMAConnection):
|
||||
self.remove_connection(conn)
|
||||
for conn_idx in self._graph.out_edge_indices(self._node_id_to_rx_id_map[source]):
|
||||
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
for sink, conn in new_connections:
|
||||
self.add_connection(source, sink, conn)
|
||||
|
||||
def remove_connection(self, connection: SocketConnection | RDMAConnection) -> None:
|
||||
if connection not in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.remove_edge_from_index(rx_idx)
|
||||
del self._edge_id_to_rx_id_map[connection]
|
||||
def remove_connection(self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection) -> None:
|
||||
for conn_idx in self._graph.edge_indices_from_endpoints(self._node_id_to_rx_id_map[source], self._node_id_to_rx_id_map[sink]):
|
||||
if self._graph.get_edge_data_by_index(conn_idx) == edge:
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
|
||||
def get_cycles(self) -> list[list[NodeId]]:
|
||||
cycle_idxs = rx.simple_cycles(self._graph)
|
||||
|
||||
@@ -10,9 +10,6 @@ class RDMAConnection(CamelCaseModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.source_rdma_iface, self.sink_rdma_iface))
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
logger.warning("duh")
|
||||
return True
|
||||
@@ -28,8 +25,5 @@ class LinkType(str, Enum):
|
||||
class SocketConnection(CamelCaseModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
@@ -11,8 +11,8 @@ from collections.abc import Sequence
|
||||
import anyio
|
||||
from anyio import create_task_group, open_process
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from anyio.streams.buffered import BufferedByteReceiveStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.constants import EXO_CONFIG_FILE
|
||||
|
||||
Reference in New Issue
Block a user