mirror of https://github.com/exo-explore/exo.git
synced 2025-12-28 00:29:48 -05:00

Compare commits
4 Commits
dont-disco ... test-app

| Author | SHA1 | Date |
|---|---|---|
| | 305b02eeb6 | |
| | a25fd21b49 | |
| | 9e0c1ac8c8 | |
| | 6e76212cac | |
@@ -7,9 +7,9 @@ from loguru import logger
from exo.master.placement_utils import (
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
    get_mlx_ibv_devices_matrix,
    get_mlx_jaccl_coordinators,
    get_mlx_ring_hosts_by_node,
    get_shard_assignments,
    get_smallest_cycles,
)
@@ -19,7 +19,6 @@ from exo.shared.types.commands import (
    DeleteInstance,
    PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
@@ -130,17 +129,17 @@ def place_instance(
                jaccl_coordinators=mlx_jaccl_coordinators,
            )
        case InstanceMeta.MlxRing:
            hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
            ephemeral_port = random_ephemeral_port()
            hosts_by_node = get_mlx_ring_hosts_by_node(
                selected_cycle=selected_cycle,
                cycle_digraph=cycle_digraph,
                ephemeral_port=ephemeral_port,
            )
            target_instances[instance_id] = MlxRingInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
                hosts=[
                    Host(
                        ip=host.ip,
                        port=random_ephemeral_port(),
                    )
                    for host in hosts
                ],
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
            )

    return target_instances
@@ -215,7 +215,7 @@ def get_mlx_ibv_devices_matrix(
                continue

            # Find the IP J uses to talk to I
            for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
            for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
                # This is a local IP on I, which is attached to an interface: find that interface
                if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
                    matrix[i][j] = interface_name
@@ -238,14 +238,14 @@ def _find_connection_ip(
    node_i: NodeInfo,
    node_j: NodeInfo,
    cycle_digraph: Topology,
) -> Generator[str]:
    """Find all IP addresses that connect node i to node j."""
) -> Generator[tuple[str, bool]]:
    """Find all IP addresses that connect node i to node j, with thunderbolt flag."""
    for connection in cycle_digraph.list_connections():
        if (
            connection.local_node_id == node_i.node_id
            and connection.send_back_node_id == node_j.node_id
        ):
            yield connection.send_back_multiaddr.ip_address
            yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()


def _find_interface_name_for_ip(
@@ -269,6 +269,128 @@ def _find_interface_name_for_ip(
    return None


def _find_general_interface_name_for_ip(
    ip_address: str,
    node_info: NodeInfo,
) -> str | None:
    """Find the interface name for an IP address on a node (any interface)."""
    if node_info.node_profile is None:
        return None

    for interface in node_info.node_profile.network_interfaces:
        if interface.ip_address == ip_address:
            return interface.name

    return None


def _find_ip_prioritised(
    node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
    """Find an IP address between nodes with prioritization.

    Priority order:
    1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
    2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
    3. Non-Thunderbolt connections
    4. Any other IP address
    """
    ips = list(_find_connection_ip(node, other_node, cycle_digraph))
    interface_names = [
        _find_general_interface_name_for_ip(ip, other_node) for ip, _ in ips
    ]

    en0_ip = next(
        (
            ip
            for (ip, _), interface_name in zip(ips, interface_names)
            if interface_name == "en0"
        ),
        None,
    )
    if en0_ip:
        return en0_ip

    en1_ip = next(
        (
            ip
            for (ip, _), interface_name in zip(ips, interface_names)
            if interface_name == "en1"
        ),
        None,
    )
    if en1_ip:
        return en1_ip

    non_thunderbolt_ip = next(
        (ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
    )
    if non_thunderbolt_ip:
        return non_thunderbolt_ip

    if ips:
        return ips[0][0]

    return None


def get_mlx_ring_hosts_by_node(
    selected_cycle: list[NodeInfo],
    cycle_digraph: Topology,
    ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
    """Generate per-node host lists for MLX ring backend.

    Each node gets a list where:
    - Self position: Host(ip="0.0.0.0", port=ephemeral_port)
    - Left/right neighbors: actual connection IPs
    - Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
    """
    world_size = len(selected_cycle)
    if world_size == 0:
        return {}

    logger.info(f"[RING3DBG] get_mlx_ring_hosts_by_node: world_size={world_size}, ephemeral_port={ephemeral_port}")
    logger.info(f"[RING3DBG] cycle node_ids: {[n.node_id for n in selected_cycle]}")

    hosts_by_node: dict[NodeId, list[Host]] = {}

    for rank, node in enumerate(selected_cycle):
        node_id = node.node_id
        left_rank = (rank - 1) % world_size
        right_rank = (rank + 1) % world_size
        logger.info(f"[RING3DBG] rank={rank} node_id={node_id} left_rank={left_rank} right_rank={right_rank}")

        hosts_for_node: list[Host] = []

        for idx, other_node in enumerate(selected_cycle):
            if idx == rank:
                hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
                continue

            if idx not in {left_rank, right_rank}:
                # Placeholder IP from RFC 5737 TEST-NET-2
                hosts_for_node.append(Host(ip="198.51.100.1", port=0))
                continue

            connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
            logger.info(f"[RING3DBG] rank={rank} idx={idx} connection_ip={connection_ip}")
            if connection_ip is None:
                logger.warning(
                    f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
                )
                raise ValueError(
                    "MLX ring backend requires connectivity between neighbouring nodes"
                )

            hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))

        logger.info(f"[RING3DBG] rank={rank} final hosts_for_node={hosts_for_node}")
        hosts_by_node[node_id] = hosts_for_node

    return hosts_by_node


def get_mlx_jaccl_coordinators(
    selected_cycle: list[NodeInfo],
    coordinator_port: int,
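For illustration only (not part of the commit): a minimal sketch of the per-node host layout the docstring above describes, assuming a hypothetical 3-node ring, made-up node ids and LAN addresses, and an ephemeral port of 50000. In a 3-node ring every other node is a left or right neighbour, so no placeholder entries appear.

```python
from dataclasses import dataclass


@dataclass
class Host:  # stand-in for exo.shared.types.common.Host
    ip: str
    port: int


EPHEMERAL_PORT = 50000  # assumed value, shared by every rank in the ring

# Rough shape of get_mlx_ring_hosts_by_node(...) for a cycle n0 -> n1 -> n2 -> n0.
# Each node sees itself as 0.0.0.0 and its neighbours via the prioritised connection IP.
hosts_by_node = {
    "n0": [Host("0.0.0.0", EPHEMERAL_PORT), Host("10.0.0.2", EPHEMERAL_PORT), Host("10.0.0.3", EPHEMERAL_PORT)],
    "n1": [Host("10.0.0.1", EPHEMERAL_PORT), Host("0.0.0.0", EPHEMERAL_PORT), Host("10.0.0.3", EPHEMERAL_PORT)],
    "n2": [Host("10.0.0.1", EPHEMERAL_PORT), Host("10.0.0.2", EPHEMERAL_PORT), Host("0.0.0.0", EPHEMERAL_PORT)],
}

# With four or more nodes, each non-neighbour slot would instead hold the
# RFC 5737 TEST-NET-2 placeholder: Host("198.51.100.1", 0).
```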
@@ -286,7 +408,7 @@ get_mlx_jaccl_coordinators(
        if n.node_id == rank_0_node.node_id:
            return "0.0.0.0"

        for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
        for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
            return ip

        logger.warning(
@@ -163,32 +163,36 @@ async def test_master():
    assert events[2].idx == 2
    assert isinstance(events[0].event, NodePerformanceMeasured)
    assert isinstance(events[1].event, InstanceCreated)
    runner_id = list(
        events[1].event.instance.shard_assignments.runner_to_shard.keys()
    )[0]
    assert events[1].event.instance == MlxRingInstance(
        instance_id=events[1].event.instance.instance_id,
        shard_assignments=ShardAssignments(
            model_id=ModelId("llama-3.2-1b"),
            runner_to_shard={
                (runner_id): PipelineShardMetadata(
                    start_layer=0,
                    end_layer=16,
    created_instance = events[1].event.instance
    assert isinstance(created_instance, MlxRingInstance)
    runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
    # Validate the shard assignments
    expected_shard_assignments = ShardAssignments(
        model_id=ModelId("llama-3.2-1b"),
        runner_to_shard={
            (runner_id): PipelineShardMetadata(
                start_layer=0,
                end_layer=16,
                n_layers=16,
                model_meta=ModelMetadata(
                    model_id=ModelId("llama-3.2-1b"),
                    pretty_name="Llama 3.2 1B",
                    n_layers=16,
                    model_meta=ModelMetadata(
                        model_id=ModelId("llama-3.2-1b"),
                        pretty_name="Llama 3.2 1B",
                        n_layers=16,
                        storage_size=Memory.from_bytes(678948),
                    ),
                    device_rank=0,
                    world_size=1,
                )
            },
            node_to_runner={node_id: runner_id},
        ),
        hosts=[],
                    storage_size=Memory.from_bytes(678948),
                ),
                device_rank=0,
                world_size=1,
            )
        },
        node_to_runner={node_id: runner_id},
    )
    assert created_instance.shard_assignments == expected_shard_assignments
    # For single-node, hosts_by_node should have one entry with self-binding
    assert len(created_instance.hosts_by_node) == 1
    assert node_id in created_instance.hosts_by_node
    assert len(created_instance.hosts_by_node[node_id]) == 1
    assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
    assert created_instance.ephemeral_port > 0
    assert isinstance(events[2].event, TaskCreated)
    assert events[2].event.task.task_status == TaskStatus.Pending
    assert isinstance(events[2].event.task, ChatCompletionTask)
@@ -38,7 +38,8 @@ def instance() -> Instance:
        shard_assignments=ShardAssignments(
            model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
        ),
        hosts=[],
        hosts_by_node={},
        ephemeral_port=50000,
    )
@@ -92,9 +93,13 @@ def test_get_instance_placements_create_instance(
    topology.add_node(create_node(available_memory[0], node_id_a))
    topology.add_node(create_node(available_memory[1], node_id_b))
    topology.add_node(create_node(available_memory[2], node_id_c))
    # Add bidirectional connections for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    # act
    placements = place_instance(cic, topology, {})
@@ -234,17 +239,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
    assert events[0].instance_id == instance_id


def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_placement_selects_cycle_with_most_memory(
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
    # neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
    # non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
    # cycle having LESS total memory than D-E-F. The algorithm should still choose
    # the cycle that contains a leaf node.
    # Arrange two 3-node cycles with different total memory.
    # With bidirectional connections for ring topology, both cycles have non-leaf nodes.
    # The algorithm should select the cycle with the most available memory.

    # Model requires more than any single node but fits within a 3-node cycle
    model_meta.storage_size.in_bytes = 1500
@@ -258,11 +261,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
    node_id_e = NodeId()
    node_id_f = NodeId()

    # Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
    node_id_x = NodeId()
    node_id_y = NodeId()
    node_id_z = NodeId()

    # A-B-C cycle total memory = 1600 (< D-E-F total)
    topology.add_node(create_node(400, node_id_a))
    topology.add_node(create_node(400, node_id_b))
@@ -273,24 +271,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
    topology.add_node(create_node(600, node_id_e))
    topology.add_node(create_node(600, node_id_f))

    # Extra nodes with tiny memory so they can't form singleton placements
    topology.add_node(create_node(10, node_id_x))
    topology.add_node(create_node(10, node_id_y))
    topology.add_node(create_node(10, node_id_z))

    # Build directed cycles
    # Build bidirectional cycles for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    topology.add_connection(create_connection(node_id_d, node_id_e))
    topology.add_connection(create_connection(node_id_e, node_id_d))
    topology.add_connection(create_connection(node_id_e, node_id_f))
    topology.add_connection(create_connection(node_id_f, node_id_e))
    topology.add_connection(create_connection(node_id_f, node_id_d))

    # Add extra outgoing edges from D/E/F so none of them are leaves
    topology.add_connection(create_connection(node_id_d, node_id_x))
    topology.add_connection(create_connection(node_id_e, node_id_y))
    topology.add_connection(create_connection(node_id_f, node_id_z))
    topology.add_connection(create_connection(node_id_d, node_id_f))

    cic = place_instance_command(
        model_meta=model_meta,
@@ -299,18 +293,17 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
    # Act
    placements = place_instance(cic, topology, {})

    # Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
    # D-E-F has more total memory.
    # Assert: D-E-F cycle should be selected as it has more total memory
    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
    instance = placements[instance_id]

    assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
    expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
    non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
    less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
    more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}

    assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
    assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
    assert more_memory_cycle_nodes.issubset(assigned_nodes)
    assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)


def test_tensor_rdma_backend_connectivity_matrix(
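For illustration only (not part of the commit): the renamed test encodes the new selection rule — among candidate ring cycles that can hold the model, pick the one with the most total available memory. A minimal sketch of that criterion, with field and helper names assumed purely for illustration:

```python
# Hypothetical helper showing the rule the test asserts; not the actual exo code.
# `cycles` is a list of candidate cycles, each a list of nodes with an assumed
# `available_memory` attribute, already filtered for sufficient total memory.
def select_cycle_with_most_memory(cycles):
    return max(cycles, key=lambda cycle: sum(node.available_memory for node in cycle))
```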
@@ -25,7 +25,8 @@ class BaseInstance(TaggedModel):


class MlxRingInstance(BaseInstance):
    hosts: list[Host]
    hosts_by_node: dict[NodeId, list[Host]]
    ephemeral_port: int


class MlxJacclInstance(BaseInstance):
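For reference (illustrative, mirroring the test fixtures later in this diff): an MlxRingInstance now carries both the legacy flat hosts list and the new per-node mapping plus a shared ephemeral port. Ids and the port below are placeholder values.

```python
# Sketch of constructing the updated model; values are assumptions for illustration.
instance = MlxRingInstance(
    instance_id=instance_id,
    shard_assignments=shard_assignments,
    hosts=[],                                                     # legacy flat list, still present
    hosts_by_node={node_id: [Host(ip="0.0.0.0", port=50000)]},    # per-node ring view
    ephemeral_port=50000,                                         # one port shared by all ranks
)
```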
@@ -22,6 +22,7 @@ from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
)
from exo.worker.runner.bootstrap import logger


class _LayerCallable(Protocol):
@@ -170,6 +171,8 @@ def pipeline_auto_parallel(

    start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
    logger.info(f"[RING3DBG] pipeline_auto_parallel: device_rank={device_rank} world_size={world_size}")
    logger.info(f"[RING3DBG] layers: start={start_layer} end={end_layer} count={len(layers)}")

    layers = layers[start_layer:end_layer]
    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
@@ -111,12 +111,17 @@ def mlx_distributed_init(
    """
    rank = bound_instance.bound_shard.device_rank
    logger.info(f"Starting initialization for rank {rank}")
    logger.info(f"[RING3DBG] mlx_distributed_init: bound_node_id={bound_instance.bound_node_id}")
    logger.info(f"[RING3DBG] device_rank={rank}, world_size={bound_instance.bound_shard.world_size}")

    # TODO: singleton instances
    match bound_instance.instance:
        case MlxRingInstance(hosts=hosts):
        case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
            hostfile = f"./hosts_{rank}.json"
            hosts_json = HostList.from_hosts(hosts).model_dump_json()
            hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
            logger.info(f"[RING3DBG] hosts_by_node keys: {list(hosts_by_node.keys())}")
            logger.info(f"[RING3DBG] hosts_for_node (len={len(hosts_for_node)}): {hosts_for_node}")
            hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()

            with open(hostfile, "w") as f:
                _ = f.write(hosts_json)
@@ -148,6 +153,7 @@ def mlx_distributed_init(
            group = mx.distributed.init(backend="jaccl", strict=True)

    logger.info(f"Rank {rank} mlx distributed initialization complete")
    logger.info(f"[RING3DBG] ring init complete: group.rank()={group.rank()} group.size()={group.size()}")

    return group
@@ -228,6 +234,8 @@ def shard_and_load(
    tokenizer = get_tokenizer(model_path, shard_metadata)

    logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
    logger.info(f"[RING3DBG] shard_and_load: expected device_rank={shard_metadata.device_rank} world_size={shard_metadata.world_size}")
    logger.info(f"[RING3DBG] actual group.rank()={group.rank()} group.size()={group.size()}")

    match shard_metadata:
        case TensorShardMetadata():
@@ -414,9 +414,14 @@ class Worker:
        while True:
            # TODO: EdgeDeleted
            edges = set(self.state.topology.list_connections())
            conns = await check_reachable(self.node_id, self.state.topology)
            conns = await check_reachable(self.state.topology, self.node_id)
            for nid in conns:
                for ip in conns[nid]:
                    if "127.0.0.1" in ip or "localhost" in ip:
                        logger.warning(
                            f"Loopback connection should not happen: {ip=} for {nid=}"
                        )

                edge = Connection(
                    local_node_id=self.node_id,
                    send_back_node_id=nid,
@@ -67,5 +67,6 @@ def get_mlx_ring_instance(
        shard_assignments=get_shard_assignments(
            model_id, node_to_runner, runner_to_shard
        ),
        hosts=[],
        hosts_by_node={},
        ephemeral_port=50000,
    )
@@ -21,134 +21,11 @@ from exo.worker.tests.unittests.conftest import (
)


def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
    """
    For non-zero device_rank shards, StartWarmup should be emitted when all
    shards in the instance are Loaded/WarmingUp.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_starts_warmup_for_rank_zero_after_others_warming():
    """
    For device_rank == 0, StartWarmup should only be emitted once all the
    other runners in the instance are already warming up.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerWarmingUp(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():
    """
    Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerWaitingForModel(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None


def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
    """
    Rank-zero shard should not start warmup until all non-zero ranks are
    already WarmingUp.
    For accepting ranks (device_rank != world_size - 1), StartWarmup should be
    emitted when all shards in the instance are Loaded/WarmingUp.
    In a 2-node setup, rank 0 is the accepting rank.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -159,6 +36,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    # Rank 0 is the accepting rank
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
    )
@@ -173,6 +51,93 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_A: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
    """
    For connecting rank (device_rank == world_size - 1), StartWarmup should
    only be emitted once all the other runners are already warming up.
    In a 2-node setup, rank 1 is the connecting rank.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    # Rank 1 is the connecting rank
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerWarmingUp(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, StartWarmup)
    assert result.instance_id == INSTANCE_1_ID


def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
    """
    Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
    In a 2-node setup, rank 0 is the accepting rank.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    # Rank 0 is the accepting rank
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerWaitingForModel(),
    }

    result = plan_mod.plan(
        node_id=NODE_A,
        runners=runners, # type: ignore
@@ -184,3 +149,46 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
    )

    assert result is None


def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
    """
    Connecting rank (device_rank == world_size - 1) should not start warmup
    until all other ranks are already WarmingUp.
    In a 2-node setup, rank 1 is the connecting rank.
    """
    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
    )

    # Rank 1 is the connecting rank
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerLoaded()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerLoaded(),
        RUNNER_2_ID: RunnerLoaded(),
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners, # type: ignore
        download_status={},
        global_download_status={NODE_A: [], NODE_B: []},
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert result is None
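For illustration only (not part of the commit): taken together, these tests pin down a two-phase warmup order keyed on rank position rather than on rank 0. A condensed sketch of that ordering rule, with the helper name and simplified status checks assumed purely for illustration:

```python
# Hypothetical condensation of the ordering the tests assert; not the real plan() code.
def may_start_warmup(device_rank: int, world_size: int, statuses: dict) -> bool:
    others = [s for rank, s in statuses.items() if rank != device_rank]
    if device_rank == world_size - 1:
        # Connecting rank: wait until every other rank is already warming up.
        return all(isinstance(s, RunnerWarmingUp) for s in others)
    # Accepting ranks: go as soon as every shard is loaded or warming up.
    return all(isinstance(s, (RunnerLoaded, RunnerWarmingUp)) for s in statuses.values())
```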
@@ -1,52 +1,78 @@
import socket
from ipaddress import ip_address
import http.client

from anyio import create_task_group, to_thread
from loguru import logger

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId


# TODO: ref. api port
async def check_reachability(
    target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
    target_ip: str,
    expected_node_id: NodeId,
    self_node_id: NodeId,
    out: dict[NodeId, set[str]],
) -> None:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(1) # 1 second timeout
    try:
        result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
    except socket.gaierror:
        # seems to throw on ipv6 loopback. oh well
        # logger.warning(f"invalid {target_ip=}")
    """Check if a node is reachable at the given IP and verify its identity."""

    def _fetch_remote_node_id() -> str | None:
        connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
        try:
            connection.request("GET", "/node_id")
            response = connection.getresponse()
            if response.status != 200:
                return None

            body = response.read().decode("utf-8").strip()

            # Strip quotes if present (JSON string response)
            if body.startswith('"') and body.endswith('"') and len(body) >= 2:
                body = body[1:-1]

            return body or None
        except OSError:
            return None
        finally:
            connection.close()

    remote_node_id_raw = await to_thread.run_sync(_fetch_remote_node_id)
    if remote_node_id_raw is None:
        return
    finally:
        sock.close()

    if result == 0:
        if target_node_id not in out:
            out[target_node_id] = set()
        out[target_node_id].add(target_ip)
    remote_node_id = NodeId(remote_node_id_raw)
    if remote_node_id == self_node_id:
        # Connected to ourselves via loopback - skip
        return

    if remote_node_id != expected_node_id:
        logger.warning(
            f"Discovered node with unexpected node_id; "
            f"ip={target_ip}, expected_node_id={expected_node_id}, "
            f"remote_node_id={remote_node_id}"
        )
        return

    if remote_node_id not in out:
        out[remote_node_id] = set()
    out[remote_node_id].add(target_ip)


async def check_reachable(our_node_id: NodeId, topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(
    topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
    """Check which nodes are reachable and return their IPs."""
    reachable: dict[NodeId, set[str]] = {}
    our_profile = topology.get_node_profile(our_node_id)
    if our_profile is None:
        return {}
    our_interfaces = our_profile.network_interfaces
    async with create_task_group() as tg:
        for node in topology.list_nodes():
            if node.node_id == our_node_id or node.node_profile is None:
            if not node.node_profile:
                continue
            for iface in node.node_profile.network_interfaces:
                if ip_address(iface.ip_address).is_loopback:
                    # Definitely a loopback address
                    continue
                if iface in our_interfaces:
                    # Skip duplicates with our own interfaces
                    continue
                tg.start_soon(
                    check_reachability, iface.ip_address, node.node_id, reachable
                    check_reachability,
                    iface.ip_address,
                    node.node_id,
                    self_node_id,
                    reachable,
                )

    return reachable
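For illustration only (not part of the commit): how a caller consumes the new signature — the topology comes first, the caller's own node id second, and the result maps each verified remote node id to the set of IPs that answered /node_id with the expected identity.

```python
# Hedged usage sketch; `topology` and `my_node_id` stand in for the Worker's
# self.state.topology and self.node_id shown earlier in this diff.
async def log_reachable_peers(topology, my_node_id) -> None:
    conns = await check_reachable(topology, my_node_id)
    for node_id, ips in conns.items():
        # Every IP here answered GET /node_id on port 52415 with the expected identity.
        logger.debug(f"{node_id} reachable via {sorted(ips)}")
```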