26 Commits

Author SHA1 Message Date
Sami Khan
f3e8ab2034 parsing api fix 2025-12-22 22:29:32 +05:00
Evan
2f98b51d0a code review followup 2025-12-22 03:26:31 +00:00
Evan
666af3ddc1 rename channel test 2025-12-22 02:41:43 +00:00
Evan
f742c1de34 move macmon test 2025-12-22 02:40:19 +00:00
Evan
af20f024de cleanup after rebase 2025-12-22 02:30:43 +00:00
Evan
b48fc01827 dedup connections 2025-12-22 02:29:36 +00:00
Evan
218fdc0188 freeze those models 2025-12-22 02:29:36 +00:00
Evan
d8018c1307 format 2025-12-22 02:29:36 +00:00
Evan
c97eda8404 tidy 2025-12-22 02:29:36 +00:00
Evan
8f59be2b6f all mastet tests pass 2025-12-22 02:29:36 +00:00
Evan
0a0ea37c64 ibv -> jaccl 2025-12-22 02:29:36 +00:00
Evan
646d553bc5 tidying some horrible logic 2025-12-22 02:29:36 +00:00
Evan
6c136306a4 fix download test 2025-12-22 02:29:36 +00:00
Evan
2c90e2e3bc fix all master tests except rdma placement 2025-12-22 02:29:36 +00:00
Evan
83037d289e fix topology tests 2025-12-22 02:29:36 +00:00
Evan
a30c13c4d9 bug 2025-12-22 02:29:36 +00:00
Evan
28bea3a377 actually update the topology 2025-12-22 02:29:36 +00:00
Evan
3edb452f5b incorrect log 2025-12-22 02:29:16 +00:00
Evan
3a1ab9ad36 handle an error 2025-12-22 02:28:21 +00:00
Evan
f7a2208694 fix pydantic validation 2025-12-22 02:28:21 +00:00
Evan
9451ced365 type checks outside of tests, time to test 2025-12-22 02:28:21 +00:00
Evan
2c2d958280 wuff 2025-12-22 02:28:21 +00:00
Evan
2112e273f5 rework topology 2025-12-22 02:28:09 +00:00
Evan
e54ab7aa2c update placement 2025-12-22 02:28:09 +00:00
Evan
69b60cf3d0 mvp 2025-12-22 02:28:09 +00:00
Evan
f2bb60b420 tidy config 2025-12-22 02:23:58 +00:00
37 changed files with 1122 additions and 1261 deletions

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -96,7 +96,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile: RawNodeProfile;
nodeProfile?: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -105,9 +105,13 @@ interface RawTopologyConnection {
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];
interface RawTopology {
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -198,9 +202,17 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === 'string' ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -238,7 +250,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
}
nodes[node.nodeId] = {
nodes[nodeId] = {
system_info: {
model_id: profile?.modelId ?? 'Unknown',
chip: profile?.chipId,
@@ -260,14 +272,34 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (typeof multi === 'string') {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -276,8 +308,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
edges.push({
source: conn.localNodeId,
target: conn.sendBackNodeId,
source: localNodeId,
target: sendBackNodeId,
sendBackIp
});
}

View File

@@ -207,6 +207,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -262,6 +263,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -426,9 +428,8 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
return total_available

View File

@@ -158,6 +158,7 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -200,9 +201,7 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,10 +6,11 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,10 +20,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -51,19 +52,18 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
cycles = topology.get_cycles() + [[node] for node in all_nodes]
logger.info(cycles)
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
candidate_cycles, node_profiles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
@@ -71,13 +71,15 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -86,11 +88,7 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
(node.node_profile.memory.ram_available for node in cycle),
start=Memory(),
),
)
@@ -99,14 +97,16 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -114,20 +114,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
cycle_digraph,
)
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
selected_cycle,
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import TypeGuard, cast
from collections.abc import Generator, Mapping
from loguru import logger
from pydantic import BaseModel
@@ -9,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -24,27 +23,32 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
if not all(node in node_profiles for node in cycle):
continue
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
return filtered_cycles
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -135,11 +139,9 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
selected_cycle: list[NodeWithProfile],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -176,17 +178,16 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
@@ -194,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_devices_matrix(
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -214,86 +215,55 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
"Current jaccl backend requires all-to-all RDMA connections"
)
return matrix
def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def get_mlx_ibv_coordinators(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX IBV (rank 0 device).
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
return "0.0.0.0"
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
return ip
logger.warning(
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}

View File

@@ -1,67 +1,36 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)

View File

@@ -19,15 +19,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -83,21 +81,14 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
NodeGatheredInfo(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
)
),
@@ -161,7 +152,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()

View File

@@ -1,5 +1,3 @@
from typing import Callable
import pytest
from loguru import logger
@@ -7,14 +5,20 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -26,11 +30,6 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -74,30 +73,33 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
# act
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1
@@ -123,12 +125,11 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -137,7 +138,7 @@ def test_get_instance_placements_one_node_exact_fit(
n_layers=10,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -148,12 +149,11 @@ def test_get_instance_placements_one_node_exact_fit(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -162,7 +162,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
n_layers=10,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -173,12 +173,11 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -189,7 +188,7 @@ def test_get_instance_placements_one_node_not_fit(
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {})
place_instance(cic, topology, {}, profiles)
def test_get_transition_events_no_change(instance: Instance):
@@ -235,190 +234,102 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size = Memory.from_bytes(1000)
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_d))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
logger.info(list(topology.list_connections()))
cic = place_instance_command(
model_meta=model_meta,
)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
instance = list(placements.values())[0]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -428,7 +339,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -436,10 +347,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
@@ -448,21 +359,21 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
# Verify coordinators are set for all nodes
assert len(instance.ibv_coordinators) == 3
assert len(instance.jaccl_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -1,56 +1,48 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
# assert
assert len(filtered_cycles) == 1
@@ -58,64 +50,65 @@ def test_filter_cycles_by_memory(
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_insufficient_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_multiple_cycles_by_memory():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
# assert
assert len(filtered_cycles) == 1
@@ -127,31 +120,38 @@ def test_filter_multiple_cycles_by_memory(
}
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(smallest_cycles) == 1
@@ -168,9 +168,6 @@ def test_get_smallest_cycles(
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -179,19 +176,25 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -199,7 +202,11 @@ def test_get_shard_assignments(
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
)
cycles = topology.get_cycles()
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
selected_cycle = cycles[0]
# act
@@ -228,28 +235,21 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -257,108 +257,47 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_ibv_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
# act
coordinators = get_mlx_ibv_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -387,11 +326,11 @@ def test_get_mlx_ibv_coordinators(
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)

View File

@@ -1,13 +1,14 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import SocketConnection
@pytest.fixture
@@ -16,20 +17,15 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -43,162 +39,85 @@ def node_profile() -> NodePerformanceProfile:
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
topology.add_node(node_id)
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
assert topology.node_is_leaf(node_id)
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_add_connection(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
data = topology.get_connection_profile(connection)
data = list(conn for _, _, conn in topology.list_connections())
# assert
assert data == connection.connection_profile
assert data == [connection]
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
topology.remove_connection(connection)
topology.remove_connection(node_a, node_b, connection)
# assert
assert topology.get_connection_profile(connection) is None
assert list(topology.get_all_connections_between(node_a, node_b)) == []
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
topology.remove_node(connection.local_node_id)
topology.remove_node(node_b)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
assert list(topology.out_edges(node_a)) == []
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_list_nodes(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}

View File

@@ -1,6 +1,7 @@
import copy
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import cast
from loguru import logger
@@ -11,10 +12,8 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -27,13 +26,23 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MemoryUsage,
MiscData,
NetworkInterfaceInfo,
NodeConfig,
StaticNodeInformation,
TBConnection,
TBIdentifier,
)
def event_apply(event: Event, state: State) -> State:
@@ -47,16 +56,12 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -188,7 +193,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
topology = copy.deepcopy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -196,8 +201,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -205,103 +214,78 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case Sequence():
if info == []:
return state
match info[0]:
case NetworkInterfaceInfo():
profile.network_interfaces = cast(
Sequence[NetworkInterfaceInfo], info
)
case TBIdentifier():
profile.tb_interfaces = cast(Sequence[TBIdentifier], info)
case TBConnection():
info = cast(Sequence[TBConnection], info)
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(
event.node_id, as_rdma_conns
)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
return state.model_copy(
update={
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
"topology": topology,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -7,29 +7,13 @@ EXO_HOME = Path.home() / EXO_HOME_RELATIVE_PATH
EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR")
EXO_MODELS_DIR = Path(EXO_MODELS_DIR_ENV) if EXO_MODELS_DIR_ENV else EXO_HOME / "models"
EXO_GLOBAL_EVENT_DB = EXO_HOME / "global_events.db"
EXO_WORKER_EVENT_DB = EXO_HOME / "worker_events.db"
EXO_MASTER_STATE = EXO_HOME / "master_state.json"
EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
EXO_MASTER_LOG = EXO_HOME / "master.log"
EXO_WORKER_LOG = EXO_HOME / "worker.log"
EXO_LOG = EXO_HOME / "exo.log"
EXO_TEST_LOG = EXO_HOME / "exo_test.log"
EXO_CONFIG_FILE = EXO_HOME / "config.toml"
EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair"
EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
EXO_IPC_DIR = EXO_HOME / "ipc"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3
LB_MEMBW_GBPS = 68
LB_DISK_GBPS = 1.5

View File

@@ -24,6 +24,8 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)
assert new_state == State(downloads={NodeId("node-1"): [event]})
assert new_state.downloads == {NodeId("node-1"): [event]}
def test_apply_two_node_download_progress():
@@ -39,7 +39,4 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
def test_state_serialization_roundtrip() -> None:
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(connection)
state.topology.add_connection(node_a, node_b, connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)

View File

@@ -1,27 +1,27 @@
import contextlib
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
class TopologySnapshot(BaseModel):
nodes: list[NodeInfo]
connections: list[Connection]
nodes: list[NodeId]
connections: list[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
model_config = ConfigDict(frozen=True, extra="forbid")
@dataclass
class Topology:
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
# the _graph can be used as a int -> NodeId map.
_graph = rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection]()
_vertex_indices = dict[NodeId, int]()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
@@ -33,171 +33,169 @@ class Topology:
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node in snapshot.nodes:
for node_id in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node)
topology.add_node(node_id)
for connection in snapshot.connections:
topology.add_connection(connection)
for source, sink, connection in snapshot.connections:
topology.add_connection(source, sink, connection)
return topology
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
return
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
self._graph[rx_id] = node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
]
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
return []
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
return node_id in self._vertex_indices
def add_connection(
self,
connection: Connection,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
if connection in self.get_all_connections_between(source, sink):
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
try:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._node_id_to_rx_id_map:
if node_id not in self._vertex_indices:
return
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
rx_idx = self._vertex_indices[node_id]
self._graph.remove_node(rx_idx)
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
del self._vertex_indices[node_id]
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
return
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[list[NodeInfo]]:
def get_cycles(self) -> list[list[NodeId]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeInfo]]:
def get_cycles_tb(self) -> list[list[NodeId]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
tb_graph.add_edge(u, v, conn)
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import SocketConnection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodePerformanceMeasured(BaseEvent):
# TODO: bikeshed this naem
class NodeGatheredInfo(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
class NodeDownloadProgress(BaseEvent):
@@ -107,11 +97,15 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
class TopologyEdgeDeleted(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
Event = (
@@ -125,10 +119,8 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeGatheredInfo
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -1,10 +1,11 @@
import re
from typing import ClassVar
from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryPerformanceProfile(CamelCaseModel):
class MemoryUsage(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -53,15 +54,16 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float
pass

View File

@@ -0,0 +1,64 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
# Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None
device_name_key: str
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if self.domain_uuid_key is None:
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
)
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel):
SPThunderboltDataType: list[TBConnectivityData]
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType

View File

@@ -1,37 +1,32 @@
from exo.shared.types.common import NodeId
from enum import Enum
from loguru import logger
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -1,43 +0,0 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)

View File

@@ -0,0 +1,216 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import CamelCaseModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(CamelCaseModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeConfig(CamelCaseModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(CamelCaseModel):
"""Node information that may change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| Sequence[NetworkInterfaceInfo]
| Sequence[TBIdentifier]
| Sequence[TBConnection]
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
while True:
await self.info_sender.send(await MiscData.gather())
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(idents)
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(conns)
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
await self.info_sender.send(nics)
old_nics = nics
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -0,0 +1,70 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(CamelCaseModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

View File

@@ -1,9 +1,11 @@
import socket
from collections.abc import Mapping
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
# TODO: ref. api port
@@ -27,15 +29,15 @@ async def check_reachability(
out[target_node_id].add(target_ip)
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(
topology: Topology, profiles: Mapping[NodeId, NodePerformanceProfile]
) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
if node not in profiles:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
)
for iface in profiles[node].network_interfaces:
tg.start_soon(check_reachability, iface.ip_address, node, reachable)
return reachable

View File

@@ -19,11 +19,20 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -28,9 +28,8 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_setup():
async def test_channel_ipc():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -101,13 +101,7 @@ def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
"""
Initialize the MLX distributed (runs in thread pool).
Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
Initialize the MLX distributed
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
@@ -129,22 +123,22 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
jaccl_devices_json = json.dumps(jaccl_devices)
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
_ = f.write(jaccl_devices_json)
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
logger.info(f"rank {rank} MLX_IBV_DEVICES: {jaccl_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
os.environ["MLX_IBV_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")

View File

@@ -16,15 +16,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -33,7 +31,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -44,14 +42,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -85,7 +83,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
self._tg: TaskGroup = create_task_group()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -97,37 +95,13 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -140,6 +114,17 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -159,7 +144,6 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -248,8 +232,7 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
if self._tg:
self._tg.cancel_scope.cancel()
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -266,24 +249,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
async def _nack_request(self, since_idx: int) -> None:
@@ -332,7 +315,6 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -391,7 +373,6 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -414,28 +395,33 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology)
conns = await check_reachable(self.state.topology, self.state.node_profiles)
for nid in conns:
for ip in conns[nid]:
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
edge = SocketConnection(
# nonsense multiaddr
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await anyio.sleep(10)

View File

@@ -22,7 +22,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
and len(bound_instance.instance.jaccl_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

View File

@@ -1,6 +0,0 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -1,114 +0,0 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)