4 Commits

Author SHA1 Message Date
Jake Hillion
305b02eeb6 localhost ring fixes 2025-12-23 23:32:04 +00:00
Jake Hillion
a25fd21b49 tmp: add lots of RING3DBG logging 2025-12-23 22:52:40 +00:00
Jake Hillion
9e0c1ac8c8 placement: generate per-node host lists for MLX ring backend
Pipeline + MLX Ring worked with 2 nodes but failed to initialize with
3 or more nodes. The MLX ring backend requires each node to know its
specific left and right neighbors in the ring, but the previous
implementation provided a single flat host list shared by all nodes.

With 2 nodes, a flat list [host0, host1] accidentally worked because
each node could find its only neighbor. With 3+ nodes, each node needs
a customized view. For 3 nodes, where every other node is still a ring
neighbor:
- Rank 0: [self, right_neighbor, left_neighbor]
- Rank 1: [left_neighbor, self, right_neighbor]
- Rank 2: [right_neighbor, left_neighbor, self]
Placeholder entries stand in for non-neighbors, which only appear in
rings of 4 or more nodes (see the sketch after this commit).

Changed MlxRingInstance from `hosts: list[Host]` to
`hosts_by_node: dict[NodeId, list[Host]]`, plus a shared `ephemeral_port: int`.

Added `get_mlx_ring_hosts_by_node()` which generates per-node host
lists where:
- Self position uses 0.0.0.0 for local binding
- Left/right neighbors use actual connection IPs
- Non-neighbors use 198.51.100.1 (RFC 5737 TEST-NET-2 placeholder)

Also added IP prioritization (en0 > en1 > non-Thunderbolt > any) to
prefer stable network interfaces.
2025-12-23 22:26:25 +00:00
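For concreteness, a hedged sketch of the `hosts_by_node` mapping described above, using a 4-node ring so that placeholder entries actually appear. The node IDs, IP addresses, and port are invented, and the `Host` stand-in only mirrors the real `exo.shared.types.common.Host`; the authoritative construction is `get_mlx_ring_hosts_by_node()` in the diff below.

from typing import NamedTuple

class Host(NamedTuple):
    """Minimal stand-in for exo.shared.types.common.Host."""
    ip: str
    port: int

# Hypothetical output for a 4-node ring A -> B -> C -> D -> A with
# ephemeral_port=50123. Each rank binds its own slot to 0.0.0.0, sees
# real connection IPs for its left/right ring neighbors, and gets the
# RFC 5737 TEST-NET-2 placeholder (198.51.100.1:0) for non-neighbors.
hosts_by_node = {
    "A": [  # rank 0: right neighbor B, left neighbor D (wraparound)
        Host(ip="0.0.0.0", port=50123),    # self, local bind
        Host(ip="10.0.0.2", port=50123),   # B, right neighbor
        Host(ip="198.51.100.1", port=0),   # C, non-neighbor placeholder
        Host(ip="10.0.0.4", port=50123),   # D, left neighbor
    ],
    "B": [  # rank 1: left neighbor A, right neighbor C
        Host(ip="10.0.0.1", port=50123),   # A, left neighbor
        Host(ip="0.0.0.0", port=50123),    # self, local bind
        Host(ip="10.0.0.3", port=50123),   # C, right neighbor
        Host(ip="198.51.100.1", port=0),   # D, non-neighbor placeholder
    ],
    # ranks 2 and 3 follow the same pattern
}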
Jake Hillion
6e76212cac mlx: update to 0.30.1 and align coordinator naming with MLX conventions
The Jaccl distributed backend requires MLX 0.30.1+, which adds
RDMA-over-Thunderbolt support. The previous minimum version (0.29.3)
would fail at runtime with "The only valid values for backend are
'any', 'mpi' and 'ring' but 'jaccl' was provided."

Bump MLX dependency to >=0.30.1 and rename ibv_coordinators to
jaccl_coordinators to match MLX's naming conventions. This includes
the environment variable change from MLX_IBV_COORDINATOR to
MLX_JACCL_COORDINATOR.

Test plan:

Hardware setup: 3x Mac Studio M3 Ultra connected all-to-all with TB5

- Built a DMG [0]
- Installed on all Macs and started cluster.
- Requested a 2-node Tensor + MLX RDMA instance of Llama 3.3 70B (FP16).
- It started successfully.
- Queried the chat a few times. All was good. This didn't work
  previously.
- Killed the instance and spawned Pipeline + MLX Ring Llama 3.3 70B (FP16).
  It also started successfully on two nodes and could be queried.

Still not working:
- Pipeline + MLX Ring on 3 nodes is failing. Haven't debugged that yet.

[0] https://github.com/exo-explore/exo/actions/runs/20467656904/job/58815275013
2025-12-23 19:28:42 +00:00
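Since the symptom on older MLX is only a runtime ValueError, a minimal pre-flight guard is sketched below. It is not part of this change, and the importlib.metadata lookup is an assumption about how one might check; the actual fix is the dependency bump in the diff below.

# Hypothetical guard, not in this diff: fail fast when the installed
# MLX predates the jaccl backend instead of surfacing MLX's runtime
# "only valid values for backend" error. Assumes a plain X.Y.Z version.
from importlib.metadata import version

def assert_jaccl_capable_mlx() -> None:
    installed = version("mlx")  # e.g. "0.30.1"
    parts = tuple(int(p) for p in installed.split(".")[:3])
    if parts < (0, 30, 1):
        raise RuntimeError(
            f"MLX {installed} lacks the 'jaccl' backend; need >= 0.30.1"
        )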
14 changed files with 433 additions and 246 deletions

View File

@@ -29,7 +29,7 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.29.3",
"mlx>=0.30.1",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -7,9 +7,9 @@ from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,7 +19,6 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
@@ -118,7 +117,7 @@ def place_instance(
selected_cycle,
cycle_digraph,
)
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
@@ -127,20 +126,20 @@ def place_instance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
ephemeral_port = random_ephemeral_port()
hosts_by_node = get_mlx_ring_hosts_by_node(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts=[
Host(
ip=host.ip,
port=random_ephemeral_port(),
)
for host in hosts
],
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
return target_instances

View File

@@ -215,7 +215,7 @@ def get_mlx_ibv_devices_matrix(
continue
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
@@ -238,14 +238,14 @@ def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_interface_name_for_ip(
@@ -269,12 +269,134 @@ def _find_interface_name_for_ip(
return None
def get_mlx_ibv_coordinators(
def _find_general_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
interface_names = [
_find_general_interface_name_for_ip(ip, other_node) for ip, _ in ips
]
en0_ip = next(
(
ip
for (ip, _), interface_name in zip(ips, interface_names)
if interface_name == "en0"
),
None,
)
if en0_ip:
return en0_ip
en1_ip = next(
(
ip
for (ip, _), interface_name in zip(ips, interface_names)
if interface_name == "en1"
),
None,
)
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
Each node gets a list where:
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
- Left/right neighbors: actual connection IPs
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
"""
world_size = len(selected_cycle)
if world_size == 0:
return {}
logger.info(f"[RING3DBG] get_mlx_ring_hosts_by_node: world_size={world_size}, ephemeral_port={ephemeral_port}")
logger.info(f"[RING3DBG] cycle node_ids: {[n.node_id for n in selected_cycle]}")
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
logger.info(f"[RING3DBG] rank={rank} node_id={node_id} left_rank={left_rank} right_rank={right_rank}")
hosts_for_node: list[Host] = []
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
if idx not in {left_rank, right_rank}:
# Placeholder IP from RFC 5737 TEST-NET-2
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
logger.info(f"[RING3DBG] rank={rank} idx={idx} connection_ip={connection_ip}")
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
logger.info(f"[RING3DBG] rank={rank} final hosts_for_node={hosts_for_node}")
hosts_by_node[node_id] = hosts_for_node
return hosts_by_node
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX IBV (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
@@ -286,7 +408,7 @@ def get_mlx_ibv_coordinators(
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip
logger.warning(

View File

@@ -163,32 +163,36 @@ async def test_master():
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
expected_shard_assignments = ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
)
assert created_instance.shard_assignments == expected_shard_assignments
# For single-node, hosts_by_node should have one entry with self-binding
assert len(created_instance.hosts_by_node) == 1
assert node_id in created_instance.hosts_by_node
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -38,7 +38,8 @@ def instance() -> Instance:
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)
@@ -92,9 +93,13 @@ def test_get_instance_placements_create_instance(
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {})
@@ -234,17 +239,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
@@ -258,11 +261,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
node_id_e = NodeId()
node_id_f = NodeId()
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
@@ -273,24 +271,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
@@ -299,18 +293,17 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
# Act
placements = place_instance(cic, topology, {})
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
@@ -437,7 +430,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
assert len(matrix) == 3
@@ -459,10 +452,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.ibv_coordinators) == 3
assert len(instance.jaccl_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -5,7 +5,7 @@ import pytest
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
@@ -265,7 +265,7 @@ def test_get_hosts_from_subgraph(
assert expected_host in hosts
def test_get_mlx_ibv_coordinators(
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
@@ -357,7 +357,7 @@ def test_get_mlx_ibv_coordinators(
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_ibv_coordinators(
coordinators = get_mlx_jaccl_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
)

View File

@@ -25,12 +25,13 @@ class BaseInstance(TaggedModel):
class MlxRingInstance(BaseInstance):
hosts: list[Host]
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
jaccl_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -22,6 +22,7 @@ from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.worker.runner.bootstrap import logger
class _LayerCallable(Protocol):
@@ -170,6 +171,8 @@ def pipeline_auto_parallel(
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
logger.info(f"[RING3DBG] pipeline_auto_parallel: device_rank={device_rank} world_size={world_size}")
logger.info(f"[RING3DBG] layers: start={start_layer} end={end_layer} count={len(layers)}")
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)

View File

@@ -111,12 +111,17 @@ def mlx_distributed_init(
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
logger.info(f"[RING3DBG] mlx_distributed_init: bound_node_id={bound_instance.bound_node_id}")
logger.info(f"[RING3DBG] device_rank={rank}, world_size={bound_instance.bound_shard.world_size}")
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
logger.info(f"[RING3DBG] hosts_by_node keys: {list(hosts_by_node.keys())}")
logger.info(f"[RING3DBG] hosts_for_node (len={len(hosts_for_node)}): {hosts_for_node}")
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
@@ -129,7 +134,7 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
@@ -138,16 +143,17 @@ def mlx_distributed_init(
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"[RING3DBG] ring init complete: group.rank()={group.rank()} group.size()={group.size()}")
return group
@@ -228,6 +234,8 @@ def shard_and_load(
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
logger.info(f"[RING3DBG] shard_and_load: expected device_rank={shard_metadata.device_rank} world_size={shard_metadata.world_size}")
logger.info(f"[RING3DBG] actual group.rank()={group.rank()} group.size()={group.size()}")
match shard_metadata:
case TensorShardMetadata():

View File

@@ -414,9 +414,14 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology)
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,

View File

@@ -67,5 +67,6 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)

View File

@@ -21,134 +21,11 @@ from exo.worker.tests.unittests.conftest import (
)
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_rank_zero_after_others_warming():
"""
For device_rank == 0, StartWarmup should only be emitted once all the
other runners in the instance are already warming up.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
"""
Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -159,6 +36,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -173,6 +51,93 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
"""
For connecting rank (device_rank == world_size - 1), StartWarmup should
only be emitted once all the other runners are already warming up.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWarmingUp(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
"""
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWaitingForModel(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
@@ -184,3 +149,46 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
)
assert result is None
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -1,33 +1,66 @@
import socket
import http.client
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> str | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return body or None
except OSError:
return None
finally:
connection.close()
remote_node_id_raw = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id_raw is None:
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
remote_node_id = NodeId(remote_node_id_raw)
if remote_node_id == self_node_id:
# Connected to ourselves via loopback - skip
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
@@ -35,7 +68,11 @@ async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
)
return reachable
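The probe above assumes each node answers GET /node_id on port 52415 with its node ID, possibly as a JSON-quoted string (hence the client-side quote stripping). Exo's actual endpoint is not shown in this diff; a minimal hypothetical counterpart, standard library only, just to illustrate the contract:

import json
from http.server import BaseHTTPRequestHandler, HTTPServer

SELF_NODE_ID = "node-1234"  # invented value for illustration

class NodeIdHandler(BaseHTTPRequestHandler):
    """Serves GET /node_id the way check_reachability expects."""

    def do_GET(self) -> None:
        if self.path != "/node_id":
            self.send_error(404)
            return
        # JSON-encode so the body is a quoted string, matching the
        # quote-stripping in _fetch_remote_node_id above.
        body = json.dumps(SELF_NODE_ID).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

if __name__ == "__main__":
    HTTPServer(("0.0.0.0", 52415), NodeIdHandler).serve_forever()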

uv.lock generated
View File

@@ -374,7 +374,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", specifier = ">=0.29.3" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "protobuf", specifier = ">=6.32.0" },
@@ -783,16 +783,22 @@ wheels = [
[[package]]
name = "mlx"
version = "0.29.3"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fe/a2/078152b45aa8a23949a1b09601d0044f8bb4ab85e909e4475a440c21aaea/mlx-0.29.3-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:d59eccf6a1e1e131becc5a3910504507862da3a4e9b7bd9e73a625515d767844", size = 549585, upload-time = "2025-10-17T19:17:01.872Z" },
{ url = "https://files.pythonhosted.org/packages/ae/bb/869eaac4efaae033c13db5fddd6a8907b5d667d135a35a2e482b1af402ee/mlx-0.29.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6642aa0a6dc2242c024fb8274d00631a7e7ffbdcef26148afd299b877c1e6a4a", size = 549586, upload-time = "2025-10-17T19:16:57.844Z" },
{ url = "https://files.pythonhosted.org/packages/ad/76/196c248c2b2a471f795356564ad1d7dc40284160c8b66370ffadfd991fa1/mlx-0.29.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ec0aef311fab10cb5f2c274afa6edf6c482636096a5f7886aba43676454aa462", size = 549586, upload-time = "2025-10-17T19:16:39.912Z" },
{ url = "https://files.pythonhosted.org/packages/f2/90/d481dd70b351e28718cfc9a0deb229a75e140abda3ed59284cf635f93f12/mlx-0.29.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:e217a99ece66832a2e631131df32e9feb047276b68ac59ca0ad63735842f6dd0", size = 649781, upload-time = "2025-10-17T19:21:26.075Z" },
{ url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" },
{ url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" },
{ url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" },
{ url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" },
{ url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" },
{ url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" },
{ url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" },
{ url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" },
{ url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" },
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[[package]]
@@ -814,12 +820,12 @@ wheels = [
[[package]]
name = "mlx-metal"
version = "0.29.3"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/41/95/a00054a006df82bb1b5b8f666ae44a676b259146fadbff90fe654309fefc/mlx_metal-0.29.3-py3-none-macosx_13_0_arm64.whl", hash = "sha256:27b5a4d905202a71e84d9fd559ea0236813f6f960ef494e5cafe9c45df4c9d7c", size = 36817352, upload-time = "2025-10-17T19:19:25.801Z" },
{ url = "https://files.pythonhosted.org/packages/c0/d8/5ee91eac16dfcf0334103120b47d4abd8c890ccc0d73d3eee4770ce8810f/mlx_metal-0.29.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:f426d4b67f96b4d6f0ed50d5992933595aadb370dc3e9ed2410bafbc16229882", size = 36555573, upload-time = "2025-10-17T19:18:42.098Z" },
{ url = "https://files.pythonhosted.org/packages/cd/9a/39b7ecdf21cf2a39ced8d7933eed65c6cb38295cadfd0907dd1abd4d1ded/mlx_metal-0.29.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:106616f7f825851043c53d3dc186965c003985da9cbb6e5c034f35108fc1fc27", size = 36549163, upload-time = "2025-10-17T19:18:37.701Z" },
{ url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" },
{ url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" },
{ url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" },
]
[[package]]