Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)
proper collection of RDMA ports in placement
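For orientation: the core of this change is resolving a peer's IP address to an rdma_<interface> port name during placement. Below is a minimal standalone sketch of that lookup, using hypothetical names (Interface, find_rdma_port) rather than the project's actual API; only the interface allow-list and the early-continue structure are taken from the diff that follows.

from dataclasses import dataclass
from typing import Optional

# Interface allow-list as it appears in the diff below; treated here as the
# set of ports assumed to be RDMA-capable.
RDMA_CAPABLE = ("en2", "en3", "en4", "en5", "en6", "en7")

@dataclass
class Interface:  # hypothetical stand-in for the node's network interface info
    name: str
    ip_address: str

def find_rdma_port(interfaces: list[Interface], ip_address: str) -> Optional[str]:
    for interface in interfaces:
        if interface.name not in RDMA_CAPABLE:
            continue  # skip interfaces outside the assumed RDMA-capable allow-list
        if interface.ip_address != ip_address:
            continue  # early-continue style used by the refactored helper
        return f"rdma_{interface.name}"
    return None

# Example: an en3 interface whose IP matches resolves to "rdma_en3".
assert find_rdma_port([Interface("en3", "169.254.0.2")], "169.254.0.2") == "rdma_en3"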
@@ -260,9 +260,11 @@ def _find_interface_name_for_ip(
         if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
             continue
         logger.info(f" | {interface.name}: {interface.ip_address}")
-        if interface.ip_address == ip_address:
-            logger.info("Found")
-            return f"rdma_{interface.name}"
+        if interface.ip_address != ip_address:
+            continue
+
+        logger.info("Found")
+        return f"rdma_{interface.name}"
+
     return None

@@ -165,9 +165,7 @@ async def test_master():
     runner_id = list(
         events[1].event.instance.shard_assignments.runner_to_shard.keys()
     )[0]
-    assert events[1].event == InstanceCreated(
-        event_id=events[1].event.event_id,
-        instance=MlxRingInstance(
+    assert events[1].event.instance == MlxRingInstance(
         instance_id=events[1].event.instance.instance_id,
         shard_assignments=ShardAssignments(
             model_id=ModelId("llama-3.2-1b"),

@@ -189,8 +187,7 @@ async def test_master():
             node_to_runner={node_id: runner_id},
         ),
         hosts=[],
-        ),
-    )
+    )
     assert isinstance(events[2].event, TaskCreated)
     assert events[2].event.task.task_status == TaskStatus.Pending
     assert isinstance(events[2].event.task, ChatCompletionTask)

@@ -1,6 +1,7 @@
 from typing import Callable

 import pytest
+from loguru import logger

 from exo.master.placement import (
     get_instance_placements_after_create,

@@ -362,7 +363,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
         network_interfaces=[
             NetworkInterfaceInfo(
                 name="en3",
-                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
             ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
             ethernet_interface,
         ],

@@ -374,9 +379,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
         friendly_name="test",
         memory=node_b.node_profile.memory,
         network_interfaces=[
             NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
                 name="en4",
-                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
             ),
             ethernet_interface,
         ],

@@ -389,8 +398,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
         memory=node_c.node_profile.memory,
         network_interfaces=[
             NetworkInterfaceInfo(
-                name="en5",
-                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+                name="en3",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
             ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
             ethernet_interface,
         ],

@@ -403,6 +416,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
     topology.add_connection(conn_a_b)
     topology.add_connection(conn_b_c)
     topology.add_connection(conn_c_a)
+    topology.add_connection(conn_b_a)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_a_c)

     cic = CreateInstance(
         sharding=Sharding.Tensor,

@@ -436,9 +452,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
     idx_b = node_to_idx[node_id_b]
     idx_c = node_to_idx[node_id_c]

-    assert matrix[idx_a][idx_b] == "rdma_en3"
-    assert matrix[idx_b][idx_c] == "rdma_en4"
-    assert matrix[idx_c][idx_a] == "rdma_en5"
+    logger.info(matrix)
+
+    assert matrix[idx_a][idx_b] == "rdma_en4"
+    assert matrix[idx_b][idx_c] == "rdma_en3"
+    assert matrix[idx_c][idx_a] == "rdma_en3"

     assert ":" in instance.ibv_coordinator
     assert not instance.ibv_coordinator.startswith("169.254")

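For orientation, a minimal standalone sketch (assumed shapes, not the project's actual types) of the connectivity matrix these assertions exercise: a square matrix indexed by node position, whose entry [i][j] names the RDMA port chosen for the link from node i to node j.

# Hypothetical three-node example mirroring the expected values above; which
# node's physical interface each name refers to is an assumption here.
node_to_idx = {"node_a": 0, "node_b": 1, "node_c": 2}

matrix: list[list[str | None]] = [[None] * 3 for _ in range(3)]
matrix[node_to_idx["node_a"]][node_to_idx["node_b"]] = "rdma_en4"
matrix[node_to_idx["node_b"]][node_to_idx["node_c"]] = "rdma_en3"
matrix[node_to_idx["node_c"]][node_to_idx["node_a"]] = "rdma_en3"

assert matrix[0][1] == "rdma_en4"
assert matrix[1][2] == "rdma_en3"
assert matrix[2][0] == "rdma_en3"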
@@ -255,9 +255,9 @@ def test_get_hosts_from_subgraph(
     # assert
     assert len(hosts) == 3
     expected_hosts = [
-        Host(ip=("169.254.0.1"), port=5001),
-        Host(ip=("169.254.0.1"), port=5002),
-        Host(ip=("169.254.0.1"), port=5003),
+        Host(ip=("169.254.0.2"), port=5001),
+        Host(ip=("169.254.0.3"), port=5002),
+        Host(ip=("169.254.0.4"), port=5003),
     ]
     for expected_host in expected_hosts:
         assert expected_host in hosts

@@ -95,7 +95,9 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):

     buffer.ingest(*make_indexed_event(0))
     buffer.ingest(*event2_first)
-    buffer.ingest(*event2_second) # This duplicate should be ignored
+
+    with pytest.raises(AssertionError):
+        buffer.ingest(*event2_second) # This duplicate should be ignored

     drained = buffer.drain_indexed()
     assert len(drained) == 2

@@ -16,6 +16,7 @@ from exo.shared.types.common import NodeId, SessionId
 from exo.utils.channels import Receiver, Sender
 from exo.utils.pydantic_ext import CamelCaseModel

+DEFAULT_ELECTION_TIMEOUT = 3.0

 class ElectionMessage(CamelCaseModel):
     clock: int

@@ -151,7 +152,7 @@ class Election:
                 self._candidates = candidates
                 logger.debug(f"New candidates: {self._candidates}")
                 logger.debug("Starting new campaign")
-                self._tg.start_soon(self._campaign, candidates)
+                self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
                 logger.debug("Campaign started")
                 continue
             # Dismiss old messages

@@ -180,7 +181,7 @@ class Election:
             candidates: list[ElectionMessage] = []
             self._candidates = candidates
             logger.debug("Starting new campaign")
-            self._tg.start_soon(self._campaign, candidates)
+            self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
             logger.debug("Campaign started")
             self._connection_messages.append(first)
             self._connection_messages.extend(rest)

@@ -192,7 +193,7 @@ class Election:
             self.commands_seen += 1

     async def _campaign(
-        self, candidates: list[ElectionMessage], *, campaign_timeout: float = 3.0
+        self, candidates: list[ElectionMessage], campaign_timeout: float
     ) -> None:
         clock = self.clock

@@ -2,10 +2,10 @@ import pytest
 from anyio import create_task_group, fail_after, move_on_after

 from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
+from exo.shared.election import Election, ElectionMessage, ElectionResult
 from exo.shared.types.commands import ForwarderCommand, TestCommand
 from exo.shared.types.common import NodeId, SessionId
 from exo.utils.channels import channel
-from exo.shared.election import Election, ElectionMessage, ElectionResult

 # ======= #
 # Helpers #

@@ -40,6 +40,10 @@ def em(
 # TESTS #
 # ======================================= #

+@pytest.fixture(autouse=True)
+def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
+

 @pytest.mark.anyio
 async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:

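A plausible reading of why the timeout moved out of the _campaign signature and into a module-level DEFAULT_ELECTION_TIMEOUT passed at call time: a keyword default is captured once when the function is defined, whereas a constant read at the call site picks up whatever value the autouse fixture above has monkeypatched in. A minimal sketch of that difference, with hypothetical names:

DEFAULT_ELECTION_TIMEOUT = 3.0

def campaign_with_default(timeout: float = 3.0) -> float:
    # The default lives in the signature; patching a module constant later
    # cannot change what callers get when they omit the argument.
    return timeout

def start_campaign() -> float:
    # Reading the module constant at call time sees a monkeypatched value.
    return campaign_with_default(DEFAULT_ELECTION_TIMEOUT)

DEFAULT_ELECTION_TIMEOUT = 0.1  # roughly what monkeypatch.setattr does in the fixture

assert campaign_with_default() == 3.0  # baked-in default, unaffected by the patch
assert start_campaign() == 0.1         # patched value is picked up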
@@ -188,7 +192,7 @@ async def test_ignores_older_messages() -> None:
     await em_in_tx.send(em(clock=1, seniority=999, node_id="B"))

     got_second = False
-    with move_on_after(0.2):
+    with move_on_after(0.05):
         _ = await em_out_rx.receive()
         got_second = True
     assert not got_second, "Should not receive a broadcast for an older round"

@@ -13,7 +13,7 @@ class BaseCommand(TaggedModel):


 class TestCommand(BaseCommand):
-    pass
+    __test__ = False


 class KillCommand(BaseCommand):

@@ -26,7 +26,7 @@ class BaseEvent(TaggedModel):


 class TestEvent(BaseEvent):
-    pass
+    __test__ = False


 class TaskCreated(BaseEvent):

@@ -56,6 +56,12 @@ class TaskFailed(BaseEvent):
 class InstanceCreated(BaseEvent):
     instance: Instance

+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, InstanceCreated):
+            return self.instance == other.instance and self.event_id == other.event_id
+
+        return False
+

 class InstanceDeleted(BaseEvent):
     instance_id: InstanceId
