Files
exo/master/utils/placement_utils.py
Matt Beton 1fe4ed3442 Worker Exception & Timeout Refactor
Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Seth Howes <sethshowes@gmail.com>
2025-08-02 08:28:37 -07:00

108 lines
3.7 KiB
Python

from typing import TypeGuard, cast
from pydantic import BaseModel
from shared.topology import Topology
from shared.types.common import Host, NodeId
from shared.types.models import ModelMetadata
from shared.types.profiling import NodePerformanceProfile
from shared.types.topology import Node
from shared.types.worker.common import RunnerId
from shared.types.worker.runners import ShardAssignments
from shared.types.worker.shards import PipelineShardMetadata
class NodeWithProfile(BaseModel):
    """Static-typing view of a ``Node`` whose ``node_profile`` is known to be set.

    ``narrow_all_nodes`` returns ``TypeGuard[list[NodeWithProfile]]``. After that
    check passes, type checkers treat each node's ``node_profile`` as a
    ``NodePerformanceProfile`` instead of a possibly-``None`` value. The
    ``is not None`` test in ``narrow_all_nodes`` implies ``Node.node_profile`` is
    optional.

    No ``NodeWithProfile`` instances are constructed in this module. The
    narrowed objects are still the original ``Node`` instances at runtime.
    """

    node_id: NodeId  # identifier of the node
    node_profile: NodePerformanceProfile  # non-optional here, unlike on Node
def narrow_all_nodes(nodes: list[Node]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(cycles: list[list[Node]], required_memory: int) -> list[list[Node]]:
    """Keep only the cycles whose combined available RAM meets ``required_memory``.

    A cycle is kept when both conditions hold:

    * every node in it has a performance profile;
    * the sum of ``node_profile.memory.ram_available`` over its nodes is at
      least ``required_memory``.

    Cycles that contain an unprofiled node are dropped. Surviving cycles are
    returned in their original order; they are the same list objects that were
    passed in.
    """

    def _fits(cycle: list[Node]) -> bool:
        # narrow_all_nodes is a TypeGuard, so node_profile is non-None past this check.
        if not narrow_all_nodes(cycle):
            return False
        available = sum(node.node_profile.memory.ram_available for node in cycle)
        return available >= required_memory

    return [cycle for cycle in cycles if _fits(cycle)]
def get_smallest_cycles(cycles: list[list[Node]]) -> list[list[Node]]:
    """Return every cycle that has the minimum number of nodes.

    All cycles tied for the smallest length are kept, in their original order.
    An empty ``cycles`` list yields an empty list. Previously this case leaked
    a bare ``min()`` "empty sequence" ValueError.
    """
    if not cycles:
        return []
    min_nodes = min(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == min_nodes]
def get_shard_assignments(
    model_meta: ModelMetadata,
    selected_cycle: list[Node],
) -> ShardAssignments:
    """Split the model's layers across the nodes of a cycle as a pipeline.

    Each node receives one contiguous layer range, in cycle order. The range is
    sized roughly in proportion to the node's share of the cycle's total
    ``ram_available``, using Python's ``round()`` (round-half-to-even).

    Every node gets at least one layer. The ranges are back-to-back and
    together cover exactly ``[0, model_meta.n_layers)``; the last node absorbs
    any remainder. A fresh ``RunnerId`` is created for each node.

    Args:
        model_meta: Metadata of the model to shard. ``n_layers`` is the number
            of layers to distribute.
        selected_cycle: Nodes to place shards on. List position becomes each
            shard's ``device_rank``; the list length is the ``world_size``.

    Returns:
        ShardAssignments that map each new runner to its
        PipelineShardMetadata and each node to its runner.

    Raises:
        ValueError: If any node lacks a performance profile, or if the model
            has fewer layers than there are nodes, so not every node could be
            given at least one layer.
    """
    if not narrow_all_nodes(selected_cycle):
        raise ValueError("All nodes must have profiles to create shard assignments")

    world_size = len(selected_cycle)
    total_layers = model_meta.n_layers
    if total_layers < world_size:
        raise ValueError(
            f"Cannot shard {total_layers} layers across {world_size} nodes: "
            "every node needs at least one layer"
        )

    cycle_memory = sum(node.node_profile.memory.ram_available for node in selected_cycle)

    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}
    layers_assigned = 0
    for i, node in enumerate(selected_cycle):
        nodes_after = world_size - i - 1
        if nodes_after == 0:
            # The last node takes whatever is left so the ranges cover every layer.
            node_layers = total_layers - layers_assigned
        else:
            if cycle_memory > 0:
                share = node.node_profile.memory.ram_available / cycle_memory
            else:
                # No RAM reported anywhere in the cycle: fall back to an even split
                # instead of dividing by zero.
                share = 1 / world_size
            node_layers = round(total_layers * share)
            # Rounding (and the 1-layer minimum) can over-allocate early nodes.
            # Cap the allocation so every later node still gets at least one layer
            # and the final end_layer never exceeds total_layers.
            max_layers = total_layers - layers_assigned - nodes_after
            node_layers = max(1, min(node_layers, max_layers))

        runner_id = RunnerId()
        runner_to_shard[runner_id] = PipelineShardMetadata(
            model_meta=model_meta,
            device_rank=i,
            world_size=world_size,
            start_layer=layers_assigned,
            end_layer=layers_assigned + node_layers,
            n_layers=total_layers,
        )
        node_to_runner[node.node_id] = runner_id
        layers_assigned += node_layers

    return ShardAssignments(
        model_id=model_meta.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
    """Collect one Host per hop around the first cycle of ``cycle_digraph``.

    The first cycle returned by ``cycle_digraph.get_cycles()`` is walked in
    order. Each node is paired with its successor, and the last node wraps back
    to the first. For each hop, the first connection whose ``local_node_id`` is
    the current node and whose ``send_back_node_id`` is the next node is chosen.
    Its ``send_back_multiaddr`` ip/port becomes a ``Host``.

    If ``is_thunderbolt_cycle`` reports the cycle as Thunderbolt, only
    Thunderbolt connections are considered. A hop with no matching connection
    contributes no host, so the result can be shorter than the cycle.

    Returns:
        The hosts in cycle order, or an empty list if the topology contains no
        cycles.
    """
    cycles = cycle_digraph.get_cycles()
    if not cycles:
        return []

    cycle = cycles[0]
    thunderbolt_only = bool(cycle_digraph.is_thunderbolt_cycle(cycle))
    # Fetch the connection list once rather than once per hop.
    connections = list(cycle_digraph.list_connections())

    hosts: list[Host] = []
    # Rotate by one so each node is paired with its successor (last -> first).
    for current_node, next_node in zip(cycle, cycle[1:] + cycle[:1]):
        for connection in connections:
            if connection.local_node_id != current_node.node_id:
                continue
            if connection.send_back_node_id != next_node.node_id:
                continue
            if thunderbolt_only and not connection.is_thunderbolt():
                continue
            hosts.append(
                Host(
                    ip=connection.send_back_multiaddr.ip_address,
                    port=connection.send_back_multiaddr.port,
                )
            )
            break
    return hosts