2 Commits

Author SHA1 Message Date
Evan
eb16352c85 stop pinging loopback addresses 2025-12-26 20:01:46 +00:00
Jake Hillion
1c1792f5e8 mlx: update to 0.30.1 and align coordinator naming with MLX conventions
The Jaccl distributed backend requires MLX 0.30.1+, which includes
RDMA-over-Thunderbolt support. The previous minimum version (0.29.3)
would fail at runtime with "The only valid values for backend are
'any', 'mpi' and 'ring' but 'jaccl' was provided."

Bump MLX dependency to >=0.30.1 and rename ibv_coordinators to
jaccl_coordinators to match MLX's naming conventions. This includes
the environment variable change from MLX_IBV_COORDINATOR to
MLX_JACCL_COORDINATOR.
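
For reference, a minimal sketch of what the new naming looks like on a
runner, assuming the coordinator address is exported via the environment
before init (the address value comes from get_mlx_jaccl_coordinators and
its exact format is not shown here; coordinator_addr is a hypothetical
variable):

    import os
    import mlx.core as mx  # requires mlx>=0.30.1; 0.29.3 rejects backend="jaccl"

    # Renamed from MLX_IBV_COORDINATOR as part of this change.
    os.environ["MLX_JACCL_COORDINATOR"] = coordinator_addr

    group = mx.distributed.init(backend="jaccl", strict=True)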

Test plan:

Hardware setup: 3x Mac Studio M3 Ultra connected all-to-all with TB5

- Built a DMG [0]
- Installed it on all three Macs and started the cluster.
- Requested a 2-node Tensor + MLX RDMA instance of Llama 3.3 70B (FP16).
- It started successfully.
- Queried the chat a few times. All was good. This didn't work
  previously.
- Killed the instance and spawned Pipeline + MLX Ring Llama 3.3 70B (FP16).
  Also started successfully on two nodes and could be queried.

Still not working:
- Pipeline + MLX Ring on 3 nodes is failing. Haven't debugged that yet.

[0] https://github.com/exo-explore/exo/actions/runs/20467656904/job/58815275013
2025-12-24 16:47:01 +00:00
11 changed files with 158 additions and 328 deletions

View File

@@ -7,9 +7,9 @@ from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,6 +19,7 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
@@ -129,17 +130,17 @@ def place_instance(
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
hosts_by_node = get_mlx_ring_hosts_by_node(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
)
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
hosts=[
Host(
ip=host.ip,
port=random_ephemeral_port(),
)
for host in hosts
],
)
return target_instances

View File

@@ -215,7 +215,7 @@ def get_mlx_ibv_devices_matrix(
continue
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
@@ -238,14 +238,14 @@ def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
yield connection.send_back_multiaddr.ip_address
def _find_interface_name_for_ip(
@@ -269,128 +269,6 @@ def _find_interface_name_for_ip(
return None
def _find_general_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
interface_names = [
_find_general_interface_name_for_ip(ip, other_node) for ip, _ in ips
]
en0_ip = next(
(
ip
for (ip, _), interface_name in zip(ips, interface_names)
if interface_name == "en0"
),
None,
)
if en0_ip:
return en0_ip
en1_ip = next(
(
ip
for (ip, _), interface_name in zip(ips, interface_names)
if interface_name == "en1"
),
None,
)
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
Each node gets a list where:
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
- Left/right neighbors: actual connection IPs
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
"""
world_size = len(selected_cycle)
if world_size == 0:
return {}
logger.info(f"[RING3DBG] get_mlx_ring_hosts_by_node: world_size={world_size}, ephemeral_port={ephemeral_port}")
logger.info(f"[RING3DBG] cycle node_ids: {[n.node_id for n in selected_cycle]}")
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
logger.info(f"[RING3DBG] rank={rank} node_id={node_id} left_rank={left_rank} right_rank={right_rank}")
hosts_for_node: list[Host] = []
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
if idx not in {left_rank, right_rank}:
# Placeholder IP from RFC 5737 TEST-NET-2
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
logger.info(f"[RING3DBG] rank={rank} idx={idx} connection_ip={connection_ip}")
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
logger.info(f"[RING3DBG] rank={rank} final hosts_for_node={hosts_for_node}")
hosts_by_node[node_id] = hosts_for_node
return hosts_by_node
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
@@ -408,7 +286,7 @@ def get_mlx_jaccl_coordinators(
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip
logger.warning(

View File

@@ -163,36 +163,32 @@ async def test_master():
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
expected_shard_assignments = ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
)
assert created_instance.shard_assignments == expected_shard_assignments
# For single-node, hosts_by_node should have one entry with self-binding
assert len(created_instance.hosts_by_node) == 1
assert node_id in created_instance.hosts_by_node
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -38,8 +38,7 @@ def instance() -> Instance:
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts_by_node={},
ephemeral_port=50000,
hosts=[],
)
@@ -93,13 +92,9 @@ def test_get_instance_placements_create_instance(
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {})
@@ -239,15 +234,17 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_cycle_with_most_memory(
def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
@@ -261,6 +258,11 @@ def test_placement_selects_cycle_with_most_memory(
node_id_e = NodeId()
node_id_f = NodeId()
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
@@ -271,20 +273,24 @@ def test_placement_selects_cycle_with_most_memory(
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Build bidirectional cycles for ring topology
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
cic = place_instance_command(
model_meta=model_meta,
@@ -293,17 +299,18 @@ def test_placement_selects_cycle_with_most_memory(
# Act
placements = place_instance(cic, topology, {})
# Assert: D-E-F cycle should be selected as it has more total memory
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(

View File

@@ -25,8 +25,7 @@ class BaseInstance(TaggedModel):
class MlxRingInstance(BaseInstance):
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
hosts: list[Host]
class MlxJacclInstance(BaseInstance):

View File

@@ -22,7 +22,6 @@ from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.worker.runner.bootstrap import logger
class _LayerCallable(Protocol):
@@ -171,8 +170,6 @@ def pipeline_auto_parallel(
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
logger.info(f"[RING3DBG] pipeline_auto_parallel: device_rank={device_rank} world_size={world_size}")
logger.info(f"[RING3DBG] layers: start={start_layer} end={end_layer} count={len(layers)}")
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)

View File

@@ -111,17 +111,12 @@ def mlx_distributed_init(
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
logger.info(f"[RING3DBG] mlx_distributed_init: bound_node_id={bound_instance.bound_node_id}")
logger.info(f"[RING3DBG] device_rank={rank}, world_size={bound_instance.bound_shard.world_size}")
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
logger.info(f"[RING3DBG] hosts_by_node keys: {list(hosts_by_node.keys())}")
logger.info(f"[RING3DBG] hosts_for_node (len={len(hosts_for_node)}): {hosts_for_node}")
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
hosts_json = HostList.from_hosts(hosts).model_dump_json()
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
@@ -153,7 +148,6 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"[RING3DBG] ring init complete: group.rank()={group.rank()} group.size()={group.size()}")
return group
@@ -234,8 +228,6 @@ def shard_and_load(
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
logger.info(f"[RING3DBG] shard_and_load: expected device_rank={shard_metadata.device_rank} world_size={shard_metadata.world_size}")
logger.info(f"[RING3DBG] actual group.rank()={group.rank()} group.size()={group.size()}")
match shard_metadata:
case TensorShardMetadata():

View File

@@ -414,14 +414,9 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology, self.node_id)
conns = await check_reachable(self.node_id, self.state.topology)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,

View File

@@ -67,6 +67,5 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts_by_node={},
ephemeral_port=50000,
hosts=[],
)

View File

@@ -21,11 +21,52 @@ from exo.worker.tests.unittests.conftest import (
)
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
"""
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_rank_zero_after_others_warming():
"""
For device_rank == 0, StartWarmup should only be emitted once all the
other runners in the instance are already warming up.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -36,7 +77,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -48,7 +88,7 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
@@ -65,11 +105,9 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
"""
For connecting rank (device_rank == world_size - 1), StartWarmup should
only be emitted once all the other runners are already warming up.
In a 2-node setup, rank 1 is the connecting rank.
Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -80,7 +118,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
@@ -91,7 +128,7 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWarmingUp(),
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}
@@ -99,20 +136,19 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
assert result is None
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -123,7 +159,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -135,7 +170,7 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
@@ -149,46 +184,3 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
)
assert result is None
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
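
Taken together, these tests pin down a two-phase warmup order: non-zero
ranks may start warming up as soon as every shard in the instance is
Loaded or WarmingUp, while rank 0 waits until all the other ranks are
already WarmingUp. A condensed sketch of that gating, for illustration
only (should_start_warmup and the string statuses are hypothetical, not
exo's actual plan() implementation):

    def should_start_warmup(
        device_rank: int, own_status: str, other_statuses: list[str]
    ) -> bool:
        # Sketch of the ordering encoded by the tests above; not exo's plan().
        if own_status != "loaded":
            return False
        if device_rank != 0:
            # Non-zero ranks: every other shard must be loaded or warming up.
            return all(s in ("loaded", "warming_up") for s in other_statuses)
        # Rank 0: wait until every other rank is already warming up.
        return all(s == "warming_up" for s in other_statuses)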

View File

@@ -1,78 +1,52 @@
import http.client
import socket
from ipaddress import ip_address
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
# TODO: ref. api port
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> str | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return body or None
except OSError:
return None
finally:
connection.close()
remote_node_id_raw = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id_raw is None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
remote_node_id = NodeId(remote_node_id_raw)
if remote_node_id == self_node_id:
# Connected to ourselves via loopback - skip
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
async def check_reachable(our_node_id: NodeId, topology: Topology) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = topology.get_node_profile(our_node_id)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
if node.node_id == our_node_id or node.node_profile is None:
continue
for iface in node.node_profile.network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
check_reachability, iface.ip_address, node.node_id, reachable
)
return reachable