diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 88563713..8cb81adb 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -260,9 +260,11 @@ def _find_interface_name_for_ip( if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]: continue logger.info(f" | {interface.name}: {interface.ip_address}") - if interface.ip_address == ip_address: - logger.info("Found") - return f"rdma_{interface.name}" + if interface.ip_address != ip_address: + continue + + logger.info("Found") + return f"rdma_{interface.name}" return None diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index c5d3ae47..90c55c5b 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -165,9 +165,7 @@ async def test_master(): runner_id = list( events[1].event.instance.shard_assignments.runner_to_shard.keys() )[0] - assert events[1].event == InstanceCreated( - event_id=events[1].event.event_id, - instance=MlxRingInstance( + assert events[1].event.instance == MlxRingInstance( instance_id=events[1].event.instance.instance_id, shard_assignments=ShardAssignments( model_id=ModelId("llama-3.2-1b"), @@ -189,8 +187,7 @@ async def test_master(): node_to_runner={node_id: runner_id}, ), hosts=[], - ), - ) + ) assert isinstance(events[2].event, TaskCreated) assert events[2].event.task.task_status == TaskStatus.Pending assert isinstance(events[2].event.task, ChatCompletionTask) diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 699b2ff1..0eb7bd67 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -1,6 +1,7 @@ from typing import Callable import pytest +from loguru import logger from exo.master.placement import ( get_instance_placements_after_create, @@ -362,7 +363,11 @@ def test_tensor_rdma_backend_connectivity_matrix( network_interfaces=[ NetworkInterfaceInfo( name="en3", - ip_address=conn_a_b.send_back_multiaddr.ip_address, + ip_address=conn_c_a.send_back_multiaddr.ip_address, + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_b_a.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -374,9 +379,13 @@ def test_tensor_rdma_backend_connectivity_matrix( friendly_name="test", memory=node_b.node_profile.memory, network_interfaces=[ + NetworkInterfaceInfo( + name="en3", + ip_address=conn_c_b.send_back_multiaddr.ip_address, + ), NetworkInterfaceInfo( name="en4", - ip_address=conn_b_c.send_back_multiaddr.ip_address, + ip_address=conn_a_b.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -389,8 +398,12 @@ def test_tensor_rdma_backend_connectivity_matrix( memory=node_c.node_profile.memory, network_interfaces=[ NetworkInterfaceInfo( - name="en5", - ip_address=conn_c_a.send_back_multiaddr.ip_address, + name="en3", + ip_address=conn_a_c.send_back_multiaddr.ip_address, + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_b_c.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -403,6 +416,9 @@ def test_tensor_rdma_backend_connectivity_matrix( topology.add_connection(conn_a_b) topology.add_connection(conn_b_c) topology.add_connection(conn_c_a) + topology.add_connection(conn_b_a) + topology.add_connection(conn_c_b) + topology.add_connection(conn_a_c) cic = CreateInstance( sharding=Sharding.Tensor, @@ -436,9 +452,11 @@ def test_tensor_rdma_backend_connectivity_matrix( idx_b = node_to_idx[node_id_b] idx_c = node_to_idx[node_id_c] - assert matrix[idx_a][idx_b] == "rdma_en3" - assert matrix[idx_b][idx_c] == "rdma_en4" - assert matrix[idx_c][idx_a] == "rdma_en5" + logger.info(matrix) + + assert matrix[idx_a][idx_b] == "rdma_en4" + assert matrix[idx_b][idx_c] == "rdma_en3" + assert matrix[idx_c][idx_a] == "rdma_en3" assert ":" in instance.ibv_coordinator assert not instance.ibv_coordinator.startswith("169.254") diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index d5f42ccf..eb1d4e10 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -255,9 +255,9 @@ def test_get_hosts_from_subgraph( # assert assert len(hosts) == 3 expected_hosts = [ - Host(ip=("169.254.0.1"), port=5001), - Host(ip=("169.254.0.1"), port=5002), - Host(ip=("169.254.0.1"), port=5003), + Host(ip=("169.254.0.2"), port=5001), + Host(ip=("169.254.0.3"), port=5002), + Host(ip=("169.254.0.4"), port=5003), ] for expected_host in expected_hosts: assert expected_host in hosts diff --git a/src/exo/routing/tests/test_event_buffer.py b/src/exo/routing/tests/test_event_buffer.py index a6f48a96..0e3e458c 100644 --- a/src/exo/routing/tests/test_event_buffer.py +++ b/src/exo/routing/tests/test_event_buffer.py @@ -95,7 +95,9 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]): buffer.ingest(*make_indexed_event(0)) buffer.ingest(*event2_first) - buffer.ingest(*event2_second) # This duplicate should be ignored + + with pytest.raises(AssertionError): + buffer.ingest(*event2_second) # This duplicate should be ignored drained = buffer.drain_indexed() assert len(drained) == 2 diff --git a/src/exo/shared/election.py b/src/exo/shared/election.py index 9d90642c..ccbbee52 100644 --- a/src/exo/shared/election.py +++ b/src/exo/shared/election.py @@ -16,6 +16,7 @@ from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import Receiver, Sender from exo.utils.pydantic_ext import CamelCaseModel +DEFAULT_ELECTION_TIMEOUT = 3.0 class ElectionMessage(CamelCaseModel): clock: int @@ -151,7 +152,7 @@ class Election: self._candidates = candidates logger.debug(f"New candidates: {self._candidates}") logger.debug("Starting new campaign") - self._tg.start_soon(self._campaign, candidates) + self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT) logger.debug("Campaign started") continue # Dismiss old messages @@ -180,7 +181,7 @@ class Election: candidates: list[ElectionMessage] = [] self._candidates = candidates logger.debug("Starting new campaign") - self._tg.start_soon(self._campaign, candidates) + self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT) logger.debug("Campaign started") self._connection_messages.append(first) self._connection_messages.extend(rest) @@ -192,7 +193,7 @@ class Election: self.commands_seen += 1 async def _campaign( - self, candidates: list[ElectionMessage], *, campaign_timeout: float = 3.0 + self, candidates: list[ElectionMessage], campaign_timeout: float ) -> None: clock = self.clock diff --git a/src/exo/shared/tests/test_election.py b/src/exo/shared/tests/test_election.py index 894c55ce..525b35a2 100644 --- a/src/exo/shared/tests/test_election.py +++ b/src/exo/shared/tests/test_election.py @@ -2,10 +2,10 @@ import pytest from anyio import create_task_group, fail_after, move_on_after from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType -from exo.shared.election import Election, ElectionMessage, ElectionResult from exo.shared.types.commands import ForwarderCommand, TestCommand from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import channel +from exo.shared.election import Election, ElectionMessage, ElectionResult # ======= # # Helpers # @@ -40,6 +40,10 @@ def em( # TESTS # # ======================================= # +@pytest.fixture(autouse=True) +def fast_election_timeout(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1) + @pytest.mark.anyio async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None: @@ -188,7 +192,7 @@ async def test_ignores_older_messages() -> None: await em_in_tx.send(em(clock=1, seniority=999, node_id="B")) got_second = False - with move_on_after(0.2): + with move_on_after(0.05): _ = await em_out_rx.receive() got_second = True assert not got_second, "Should not receive a broadcast for an older round" diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 39c117f9..1ea4027a 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -13,7 +13,7 @@ class BaseCommand(TaggedModel): class TestCommand(BaseCommand): - pass + __test__ = False class KillCommand(BaseCommand): diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index 3cc1c872..7ad465d4 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -26,7 +26,7 @@ class BaseEvent(TaggedModel): class TestEvent(BaseEvent): - pass + __test__ = False class TaskCreated(BaseEvent): @@ -56,6 +56,12 @@ class TaskFailed(BaseEvent): class InstanceCreated(BaseEvent): instance: Instance + def __eq__(self, other: object) -> bool: + if isinstance(other, InstanceCreated): + return self.instance == other.instance and self.event_id == other.event_id + + return False + class InstanceDeleted(BaseEvent): instance_id: InstanceId