proper collection of rdma ports in placement

This commit is contained in:
Evan Quiney
2025-12-05 16:42:20 +00:00
committed by GitHub
parent e702313b32
commit f5783d6455
9 changed files with 56 additions and 26 deletions

View File

@@ -260,9 +260,11 @@ def _find_interface_name_for_ip(
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address == ip_address:
logger.info("Found")
return f"rdma_{interface.name}"
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None

View File

@@ -165,9 +165,7 @@ async def test_master():
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event == InstanceCreated(
event_id=events[1].event.event_id,
instance=MlxRingInstance(
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
@@ -189,8 +187,7 @@ async def test_master():
node_to_runner={node_id: runner_id},
),
hosts=[],
),
)
)
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -1,6 +1,7 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_instance_placements_after_create,
@@ -362,7 +363,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
@@ -374,9 +379,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
@@ -389,8 +398,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en5",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
@@ -403,6 +416,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = CreateInstance(
sharding=Sharding.Tensor,
@@ -436,9 +452,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert ":" in instance.ibv_coordinator
assert not instance.ibv_coordinator.startswith("169.254")

View File

@@ -255,9 +255,9 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.1"), port=5001),
Host(ip=("169.254.0.1"), port=5002),
Host(ip=("169.254.0.1"), port=5003),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts

View File

@@ -95,7 +95,9 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
buffer.ingest(*make_indexed_event(0))
buffer.ingest(*event2_first)
buffer.ingest(*event2_second) # This duplicate should be ignored
with pytest.raises(AssertionError):
buffer.ingest(*event2_second) # Re-ingesting this duplicate index is expected to raise AssertionError
drained = buffer.drain_indexed()
assert len(drained) == 2

View File

@@ -16,6 +16,7 @@ from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
from exo.utils.pydantic_ext import CamelCaseModel
# Campaign timeout that Election passes to `_campaign` whenever it starts a
# new campaign. Kept as a module-level name so it is looked up at call time.
# NOTE(review): presumably expressed in seconds — confirm against `_campaign`.
DEFAULT_ELECTION_TIMEOUT = 3.0
class ElectionMessage(CamelCaseModel):
clock: int
@@ -151,7 +152,7 @@ class Election:
self._candidates = candidates
logger.debug(f"New candidates: {self._candidates}")
logger.debug("Starting new campaign")
self._tg.start_soon(self._campaign, candidates)
self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
logger.debug("Campaign started")
continue
# Dismiss old messages
@@ -180,7 +181,7 @@ class Election:
candidates: list[ElectionMessage] = []
self._candidates = candidates
logger.debug("Starting new campaign")
self._tg.start_soon(self._campaign, candidates)
self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
logger.debug("Campaign started")
self._connection_messages.append(first)
self._connection_messages.extend(rest)
@@ -192,7 +193,7 @@ class Election:
self.commands_seen += 1
async def _campaign(
self, candidates: list[ElectionMessage], *, campaign_timeout: float = 3.0
self, candidates: list[ElectionMessage], campaign_timeout: float
) -> None:
clock = self.clock

View File

@@ -2,10 +2,10 @@ import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel
from exo.shared.election import Election, ElectionMessage, ElectionResult
# ======= #
# Helpers #
@@ -40,6 +40,10 @@ def em(
# TESTS #
# ======================================= #
@pytest.fixture(autouse=True)
def fast_election_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch the election module's default timeout to 0.1 for each test.

    Declared ``autouse`` so every test in this module gets the shortened
    timeout without requesting the fixture; ``monkeypatch`` restores the
    original value when the test finishes.
    """
    patched_attribute = "exo.shared.election.DEFAULT_ELECTION_TIMEOUT"
    shortened_timeout = 0.1
    monkeypatch.setattr(patched_attribute, shortened_timeout)
@pytest.mark.anyio
async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:
@@ -188,7 +192,7 @@ async def test_ignores_older_messages() -> None:
await em_in_tx.send(em(clock=1, seniority=999, node_id="B"))
got_second = False
with move_on_after(0.2):
with move_on_after(0.05):
_ = await em_out_rx.receive()
got_second = True
assert not got_second, "Should not receive a broadcast for an older round"

View File

@@ -13,7 +13,7 @@ class BaseCommand(TaggedModel):
class TestCommand(BaseCommand):
    # A command type, not a test case. `__test__ = False` is the pytest
    # convention for keeping a `Test*`-named class out of test collection.
    __test__ = False
class KillCommand(BaseCommand):

View File

@@ -26,7 +26,7 @@ class BaseEvent(TaggedModel):
class TestEvent(BaseEvent):
    # An event type, not a test case. `__test__ = False` is the pytest
    # convention for keeping a `Test*`-named class out of test collection.
    __test__ = False
class TaskCreated(BaseEvent):
@@ -56,6 +56,12 @@ class TaskFailed(BaseEvent):
class InstanceCreated(BaseEvent):
    # Event whose payload is an `Instance`.
    instance: Instance

    # NOTE(review): under plain Python rules, overriding __eq__ without
    # __hash__ makes the class unhashable unless the base/metaclass restores
    # it — confirm if these events are ever used as dict keys or set members.
    def __eq__(self, other: object) -> bool:
        """Compare two ``InstanceCreated`` events by instance and event id.

        Returns ``NotImplemented`` for any other type so Python can try the
        reflected comparison on ``other`` before falling back to its default
        (which yields ``False`` for ``==`` between unrelated objects).
        """
        if not isinstance(other, InstanceCreated):
            return NotImplemented
        return self.instance == other.instance and self.event_id == other.event_id
class InstanceDeleted(BaseEvent):
instance_id: InstanceId