placement: pass different ibv_coordinator per node

Jake Hillion
2025-12-05 17:23:22 +00:00
committed by GitHub
parent 39d76aa0a5
commit e8566a3f95
15 changed files with 263 additions and 60 deletions

View File

@@ -8,7 +8,7 @@ from loguru import logger
 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
-    get_mlx_ibv_coordinator,
+    get_mlx_ibv_coordinators,
     get_mlx_ibv_devices_matrix,
     get_shard_assignments,
     get_smallest_cycles,
@@ -110,15 +110,16 @@ def get_instance_placements_after_create(
                 selected_cycle,
                 cycle_digraph,
             )
-            mlx_ibv_coordinator = get_mlx_ibv_coordinator(
+            mlx_ibv_coordinators = get_mlx_ibv_coordinators(
                 selected_cycle,
                 coordinator_port=random_ephemeral_port(),
+                cycle_digraph=cycle_digraph,
             )
             target_instances[instance_id] = MlxJacclInstance(
                 instance_id=instance_id,
                 shard_assignments=shard_assignments,
                 ibv_devices=mlx_ibv_devices,
-                ibv_coordinator=mlx_ibv_coordinator,
+                ibv_coordinators=mlx_ibv_coordinators,
             )
         case InstanceMeta.MlxRing:
             hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)

View File

@@ -269,20 +269,31 @@ def _find_interface_name_for_ip(
     return None


-def get_mlx_ibv_coordinator(
+def get_mlx_ibv_coordinators(
     selected_cycle: list[NodeInfo],
     coordinator_port: int,
-) -> str:
-    """Get the coordinator address for MLX IBV (rank 0 device).
+    cycle_digraph: Topology,
+) -> dict[NodeId, str]:
+    """Get the coordinator addresses for MLX IBV (rank 0 device).

-    Selects a non-thunderbolt IP address from rank 0 node as a heuristic for
-    ethernet accessibility. Returns address in format "X.X.X.X:PORT".
+    Select an IP address that each node can reach for the rank 0 node. Returns
+    address in format "X.X.X.X:PORT" per node.
     """
     rank_0_node = selected_cycle[0]
     logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
     assert rank_0_node.node_profile is not None
-    for iface in rank_0_node.node_profile.network_interfaces:
-        if iface.name == "en0" and "." in iface.ip_address:
-            return f"{iface.ip_address}:{coordinator_port}"
-    raise ValueError("No en0 iface found for device")
+
+    def get_ip_for_node(n: NodeInfo) -> str:
+        if n.node_id == rank_0_node.node_id:
+            return "0.0.0.0"
+        for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
+            return ip
+        logger.warning(
+            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
+        )
+        raise ValueError("Current ibv backend requires all-to-all rdma connections")
+
+    return {
+        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
+    }
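
For a three-node cycle the new helper returns one address per node, all pointing at rank 0's coordinator port: rank 0 gets the wildcard bind address, and every other node gets rank 0's IP on the link the two nodes share. A minimal sketch of the expected shape, with hypothetical node objects and IPs (none of these values are from this commit):

    # Sketch only: node_a/node_b/node_c, topology and the IPs are illustrative.
    coordinators = get_mlx_ibv_coordinators(
        selected_cycle=[node_a, node_b, node_c],  # node_a is rank 0
        coordinator_port=5000,
        cycle_digraph=topology,
    )
    assert coordinators == {
        node_a.node_id: "0.0.0.0:5000",       # rank 0 binds all interfaces
        node_b.node_id: "192.168.1.10:5000",  # rank 0's IP on the a<->b link
        node_c.node_id: "192.168.2.10:5000",  # rank 0's IP on the a<->c link
    }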

View File

@@ -166,28 +166,28 @@ async def test_master():
         events[1].event.instance.shard_assignments.runner_to_shard.keys()
     )[0]
     assert events[1].event.instance == MlxRingInstance(
-            instance_id=events[1].event.instance.instance_id,
-            shard_assignments=ShardAssignments(
-                model_id=ModelId("llama-3.2-1b"),
-                runner_to_shard={
-                    (runner_id): PipelineShardMetadata(
-                        start_layer=0,
-                        end_layer=16,
-                        n_layers=16,
-                        model_meta=ModelMetadata(
-                            model_id=ModelId("llama-3.2-1b"),
-                            pretty_name="Llama 3.2 1B",
-                            n_layers=16,
-                            storage_size=Memory.from_bytes(678948),
-                        ),
-                        device_rank=0,
-                        world_size=1,
-                    )
-                },
-                node_to_runner={node_id: runner_id},
-            ),
-            hosts=[],
-        )
+        instance_id=events[1].event.instance.instance_id,
+        shard_assignments=ShardAssignments(
+            model_id=ModelId("llama-3.2-1b"),
+            runner_to_shard={
+                (runner_id): PipelineShardMetadata(
+                    start_layer=0,
+                    end_layer=16,
+                    n_layers=16,
+                    model_meta=ModelMetadata(
+                        model_id=ModelId("llama-3.2-1b"),
+                        pretty_name="Llama 3.2 1B",
+                        n_layers=16,
+                        storage_size=Memory.from_bytes(678948),
+                    ),
+                    device_rank=0,
+                    world_size=1,
+                )
+            },
+            node_to_runner={node_id: runner_id},
+        ),
+        hosts=[],
+    )
     assert isinstance(events[2].event, TaskCreated)
     assert events[2].event.task.task_status == TaskStatus.Pending
     assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert isinstance(instance, MlxJacclInstance)
     assert instance.ibv_devices is not None
-    assert instance.ibv_coordinator is not None
+    assert instance.ibv_coordinators is not None

     matrix = instance.ibv_devices
     assert len(matrix) == 3
@@ -458,5 +458,17 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert matrix[idx_b][idx_c] == "rdma_en3"
     assert matrix[idx_c][idx_a] == "rdma_en3"

-    assert ":" in instance.ibv_coordinator
-    assert not instance.ibv_coordinator.startswith("169.254")
+    # Verify coordinators are set for all nodes
+    assert len(instance.ibv_coordinators) == 3
+    for node_id in assigned_nodes:
+        assert node_id in instance.ibv_coordinators
+        coordinator = instance.ibv_coordinators[node_id]
+        assert ":" in coordinator
+        # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
+        if node_id == assigned_nodes[0]:
+            assert coordinator.startswith("0.0.0.0:")
+        else:
+            # Non-rank-0 nodes should have valid IP addresses (can be link-local)
+            ip_part = coordinator.split(":")[0]
+            # Just verify it's a valid IP format
+            assert len(ip_part.split(".")) == 4

View File

@@ -5,6 +5,7 @@ import pytest
 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
+    get_mlx_ibv_coordinators,
     get_shard_assignments,
     get_smallest_cycles,
 )
@@ -12,6 +13,7 @@ from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
+from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
 from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.shards import Sharding
@@ -261,3 +263,135 @@ def test_get_hosts_from_subgraph(
     ]
     for expected_host in expected_hosts:
         assert expected_host in hosts
+
+
+def test_get_mlx_ibv_coordinators(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
+):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)
+
+    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
+    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
+    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
+    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
+    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
+    conn_a_c = create_connection(node_a_id, node_c_id, 5006)
+
+    # Update node profiles with network interfaces before adding to topology
+    assert node_a.node_profile is not None
+    assert node_b.node_profile is not None
+    assert node_c.node_profile is not None
+    node_a.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_a.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_a.node_profile.system,
+    )
+    node_b.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_b.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_b.node_profile.system,
+    )
+    node_c.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_c.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_c.node_profile.system,
+    )
+
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)
+    topology.add_connection(conn_a_b)
+    topology.add_connection(conn_b_a)
+    topology.add_connection(conn_b_c)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_c_a)
+    topology.add_connection(conn_a_c)
+
+    cycle = [node_a, node_b, node_c]
+
+    # act
+    coordinators = get_mlx_ibv_coordinators(
+        cycle, coordinator_port=5000, cycle_digraph=topology
+    )
+
+    # assert
+    assert len(coordinators) == 3
+    assert node_a_id in coordinators
+    assert node_b_id in coordinators
+    assert node_c_id in coordinators
+
+    # All coordinators should have IP:PORT format
+    for node_id, coordinator in coordinators.items():
+        assert ":" in coordinator, (
+            f"Coordinator for {node_id} should have ':' separator"
+        )
+
+    # Verify port is correct
+    for node_id, coordinator in coordinators.items():
+        assert coordinator.endswith(":5000"), (
+            f"Coordinator for {node_id} should use port 5000"
+        )
+
+    # Rank 0 (node_a) treats this as the listen socket so should listen on all
+    # IPs
+    assert coordinators[node_a_id].startswith("0.0.0.0:"), (
+        "Rank 0 node should bind to all interfaces (0.0.0.0)"
+    )
+
+    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
+    # node_b uses the IP from conn_b_a (node_b -> node_a)
+    assert coordinators[node_b_id] == (
+        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_b should use the IP from conn_b_a"
+
+    # node_c uses the IP from conn_c_a (node_c -> node_a)
+    assert coordinators[node_c_id] == (
+        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_c should use the IP from conn_c_a"

View File

@@ -95,7 +95,7 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
     buffer.ingest(*make_indexed_event(0))
     buffer.ingest(*event2_first)
     with pytest.raises(AssertionError):
-        buffer.ingest(*event2_second)  # This duplicate should be ignored
+        buffer.ingest(*event2_second)

View File

@@ -18,6 +18,7 @@ from exo.utils.pydantic_ext import CamelCaseModel

 DEFAULT_ELECTION_TIMEOUT = 3.0


 class ElectionMessage(CamelCaseModel):
     clock: int
+    seniority: int
@@ -152,7 +153,9 @@ class Election:
                 self._candidates = candidates
                 logger.debug(f"New candidates: {self._candidates}")
                 logger.debug("Starting new campaign")
-                self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+                self._tg.start_soon(
+                    self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+                )
                 logger.debug("Campaign started")
                 continue
             # Dismiss old messages
@@ -181,7 +184,9 @@ class Election:
             candidates: list[ElectionMessage] = []
             self._candidates = candidates
             logger.debug("Starting new campaign")
-            self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+            self._tg.start_soon(
+                self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+            )
             logger.debug("Campaign started")
         self._connection_messages.append(first)
         self._connection_messages.extend(rest)

View File

@@ -40,6 +40,7 @@ def em(
 # ======================================= #
 # TESTS #
 # ======================================= #


+@pytest.fixture(autouse=True)
+def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)

View File

@@ -2,7 +2,7 @@ from enum import Enum
 from pydantic import model_validator

-from exo.shared.types.common import Host, Id
+from exo.shared.types.common import Host, Id, NodeId
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
 class MlxJacclInstance(BaseInstance):
     ibv_devices: list[list[str | None]]
-    ibv_coordinator: str
+    ibv_coordinators: dict[NodeId, str]


 # TODO: Single node instance
@@ -40,6 +40,7 @@ Instance = MlxRingInstance | MlxJacclInstance
 class BoundInstance(CamelCaseModel):
     instance: Instance
     bound_runner_id: RunnerId
+    bound_node_id: NodeId

     @property
     def bound_shard(self) -> ShardMetadata:
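
Together the two model changes mean an MlxJacclInstance carries one coordinator address per participating node, and a BoundInstance can tell its runner which node it is bound to. A sketch of constructing the updated models, with hypothetical IDs, devices, and addresses (illustrative values only):

    # Sketch only: all IDs and addresses here are made up for illustration.
    instance = MlxJacclInstance(
        instance_id=instance_id,
        shard_assignments=shard_assignments,
        ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
        ibv_coordinators={
            node_a_id: "0.0.0.0:5000",       # rank 0 listens on all interfaces
            node_b_id: "192.168.1.10:5000",  # how node_b reaches rank 0
        },
    )
    bound = BoundInstance(
        instance=instance, bound_runner_id=runner_id, bound_node_id=node_b_id
    )
    coordinator = instance.ibv_coordinators[bound.bound_node_id]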

View File

@@ -128,7 +128,9 @@ def mlx_distributed_init(
             os.environ["MLX_RING_VERBOSE"] = "1"
             group = mx.distributed.init(backend="ring", strict=True)
-        case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
+        case MlxJacclInstance(
+            ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
+        ):
             # Use RDMA connectivity matrix
             devices_file = f"./hosts_{rank}.json"
             ibv_devices_json = json.dumps(ibv_devices)
@@ -136,6 +138,8 @@
             with open(devices_file, "w") as f:
                 _ = f.write(ibv_devices_json)

+            ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
+
             logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
             logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
             os.environ["MLX_IBV_DEVICES"] = devices_file
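
With the per-node map in place, each rank resolves its own coordinator address before initialising the jaccl backend: rank 0 ends up listening on 0.0.0.0 while every other rank dials the rank 0 IP reachable from its node. A sketch of the per-rank lookup, reusing the hypothetical values from above; the MLX_IBV_COORDINATOR export is an assumption, mirroring how MLX_IBV_DEVICES is exported in this hunk:

    # Sketch only: each rank picks the entry for its own node.
    ibv_coordinators = {
        node_a_id: "0.0.0.0:5000",       # rank 0 binds all interfaces
        node_b_id: "192.168.1.10:5000",  # rank 1 dials rank 0 directly
    }
    ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
    os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator  # assumed export, as with MLX_IBV_DEVICES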

View File

@@ -95,7 +95,9 @@ def _create_runner(
     return CreateRunner(
         instance_id=instance.instance_id,
-        bound_instance=BoundInstance(instance=instance, bound_runner_id=runner_id),
+        bound_instance=BoundInstance(
+            instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+        ),
     )

View File

@@ -35,7 +35,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -76,7 +78,9 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -126,7 +130,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -173,7 +179,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )

View File

@@ -36,7 +36,9 @@ def test_plan_kills_runner_when_instance_missing():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
     runners = {RUNNER_1_ID: runner}
@@ -71,7 +73,9 @@ def test_plan_kills_runner_when_sibling_failed():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
     runners = {RUNNER_1_ID: runner}
@@ -143,7 +147,9 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())
     runners = {RUNNER_1_ID: runner}

View File

@@ -40,7 +40,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -86,7 +88,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -131,7 +135,9 @@ def test_plan_does_not_forward_tasks_for_other_instances():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=local_instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -175,7 +181,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -236,7 +244,9 @@ def test_plan_returns_none_when_nothing_to_do():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerRunning()
     )

View File

@@ -35,7 +35,9 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -75,7 +77,9 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -114,7 +118,9 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -153,7 +159,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )