mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-06 13:11:16 -05:00
Compare commits
30 Commits
test-app
...
linux-cpu-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d2e828aba | ||
|
|
b5319d6b03 | ||
|
|
b988e08d69 | ||
|
|
9bf5979f8a | ||
|
|
91944383d3 | ||
|
|
dcc6872724 | ||
|
|
dccc2709c5 | ||
|
|
20d1246600 | ||
|
|
81bad9e01a | ||
|
|
7ff67d0a28 | ||
|
|
c888b13d3f | ||
|
|
1f80705b56 | ||
|
|
b349330404 | ||
|
|
812ce47194 | ||
|
|
643c6b8d28 | ||
|
|
4754f56bd4 | ||
|
|
66d01369b4 | ||
|
|
d20d9e5fc8 | ||
|
|
e67282282c | ||
|
|
54daa9e2db | ||
|
|
06125d1503 | ||
|
|
505e756872 | ||
|
|
4cd3db0f6e | ||
|
|
8b137a1e64 | ||
|
|
4176c7ec25 | ||
|
|
dbce607911 | ||
|
|
9949b93517 | ||
|
|
f4feeff077 | ||
|
|
f529884344 | ||
|
|
df4c6ce24e |
1
TODO.md
1
TODO.md
@@ -19,6 +19,7 @@
|
|||||||
25. Rethink retry logic
|
25. Rethink retry logic
|
||||||
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||||
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
||||||
|
28. Validate RDMA connections with ibv_devinfo in the info gatherer
|
||||||
|
|
||||||
Potential refactors:
|
Potential refactors:
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ interface RawNodeProfile {
|
|||||||
|
|
||||||
interface RawTopologyNode {
|
interface RawTopologyNode {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
nodeProfile: RawNodeProfile;
|
nodeProfile?: RawNodeProfile;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface RawTopologyConnection {
|
interface RawTopologyConnection {
|
||||||
@@ -105,9 +105,13 @@ interface RawTopologyConnection {
|
|||||||
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
|
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connection can be an object or a tuple [source, target, metadata]
|
||||||
|
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];
|
||||||
|
|
||||||
interface RawTopology {
|
interface RawTopology {
|
||||||
nodes: RawTopologyNode[];
|
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
|
||||||
connections?: RawTopologyConnection[];
|
nodes: (string | RawTopologyNode)[];
|
||||||
|
connections?: RawConnectionItem[];
|
||||||
}
|
}
|
||||||
|
|
||||||
type RawNodeProfiles = Record<string, RawNodeProfile>;
|
type RawNodeProfiles = Record<string, RawNodeProfile>;
|
||||||
@@ -198,9 +202,17 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
|
|||||||
const nodes: Record<string, NodeInfo> = {};
|
const nodes: Record<string, NodeInfo> = {};
|
||||||
const edges: TopologyEdge[] = [];
|
const edges: TopologyEdge[] = [];
|
||||||
|
|
||||||
|
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
|
||||||
for (const node of raw.nodes || []) {
|
for (const node of raw.nodes || []) {
|
||||||
const mergedProfile = profiles?.[node.nodeId];
|
// Determine the node ID - could be a string or an object with nodeId property
|
||||||
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
|
const nodeId = typeof node === 'string' ? node : node.nodeId;
|
||||||
|
if (!nodeId) continue;
|
||||||
|
|
||||||
|
// Get the profile - from the separate profiles map or from the node object itself
|
||||||
|
const profileFromMap = profiles?.[nodeId];
|
||||||
|
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
|
||||||
|
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
|
||||||
|
|
||||||
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
|
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
|
||||||
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
|
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
|
||||||
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
|
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
|
||||||
@@ -238,7 +250,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes[node.nodeId] = {
|
nodes[nodeId] = {
|
||||||
system_info: {
|
system_info: {
|
||||||
model_id: profile?.modelId ?? 'Unknown',
|
model_id: profile?.modelId ?? 'Unknown',
|
||||||
chip: profile?.chipId,
|
chip: profile?.chipId,
|
||||||
@@ -260,14 +272,34 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
|
||||||
for (const conn of raw.connections || []) {
|
for (const conn of raw.connections || []) {
|
||||||
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
|
let localNodeId: string | undefined;
|
||||||
if (conn.localNodeId === conn.sendBackNodeId) continue;
|
let sendBackNodeId: string | undefined;
|
||||||
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
|
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;
|
||||||
|
|
||||||
|
// Check if it's a tuple format [source, target, metadata]
|
||||||
|
if (Array.isArray(conn)) {
|
||||||
|
localNodeId = conn[0] as string;
|
||||||
|
sendBackNodeId = conn[1] as string;
|
||||||
|
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
|
||||||
|
if (metadata?.sinkMultiaddr) {
|
||||||
|
sendBackMultiaddr = metadata.sinkMultiaddr;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Object format with localNodeId/sendBackNodeId
|
||||||
|
localNodeId = conn.localNodeId;
|
||||||
|
sendBackNodeId = conn.sendBackNodeId;
|
||||||
|
sendBackMultiaddr = conn.sendBackMultiaddr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!localNodeId || !sendBackNodeId) continue;
|
||||||
|
if (localNodeId === sendBackNodeId) continue;
|
||||||
|
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
|
||||||
|
|
||||||
let sendBackIp: string | undefined;
|
let sendBackIp: string | undefined;
|
||||||
if (conn.sendBackMultiaddr) {
|
if (sendBackMultiaddr) {
|
||||||
const multi = conn.sendBackMultiaddr;
|
const multi = sendBackMultiaddr;
|
||||||
if (typeof multi === 'string') {
|
if (typeof multi === 'string') {
|
||||||
sendBackIp = extractIpFromMultiaddr(multi);
|
sendBackIp = extractIpFromMultiaddr(multi);
|
||||||
} else {
|
} else {
|
||||||
@@ -276,8 +308,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
|
|||||||
}
|
}
|
||||||
|
|
||||||
edges.push({
|
edges.push({
|
||||||
source: conn.localNodeId,
|
source: localNodeId,
|
||||||
target: conn.sendBackNodeId,
|
target: sendBackNodeId,
|
||||||
sendBackIp
|
sendBackIp
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -207,6 +207,7 @@ class API:
|
|||||||
instance_meta=instance_meta,
|
instance_meta=instance_meta,
|
||||||
min_nodes=min_nodes,
|
min_nodes=min_nodes,
|
||||||
),
|
),
|
||||||
|
node_profiles=self.state.node_profiles,
|
||||||
topology=self.state.topology,
|
topology=self.state.topology,
|
||||||
current_instances=self.state.instances,
|
current_instances=self.state.instances,
|
||||||
)
|
)
|
||||||
@@ -262,6 +263,7 @@ class API:
|
|||||||
instance_meta=instance_meta,
|
instance_meta=instance_meta,
|
||||||
min_nodes=min_nodes,
|
min_nodes=min_nodes,
|
||||||
),
|
),
|
||||||
|
node_profiles=self.state.node_profiles,
|
||||||
topology=self.state.topology,
|
topology=self.state.topology,
|
||||||
current_instances=self.state.instances,
|
current_instances=self.state.instances,
|
||||||
)
|
)
|
||||||
@@ -426,9 +428,8 @@ class API:
|
|||||||
"""Calculate total available memory across all nodes in bytes."""
|
"""Calculate total available memory across all nodes in bytes."""
|
||||||
total_available = Memory()
|
total_available = Memory()
|
||||||
|
|
||||||
for node in self.state.topology.list_nodes():
|
for profile in self.state.node_profiles.values():
|
||||||
if node.node_profile is not None:
|
total_available += profile.memory.ram_available
|
||||||
total_available += node.node_profile.memory.ram_available
|
|
||||||
|
|
||||||
return total_available
|
return total_available
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ class Master:
|
|||||||
command,
|
command,
|
||||||
self.state.topology,
|
self.state.topology,
|
||||||
self.state.instances,
|
self.state.instances,
|
||||||
|
self.state.node_profiles,
|
||||||
)
|
)
|
||||||
transition_events = get_transition_events(
|
transition_events = get_transition_events(
|
||||||
self.state.instances, placement
|
self.state.instances, placement
|
||||||
@@ -200,9 +201,7 @@ class Master:
|
|||||||
async def _plan(self) -> None:
|
async def _plan(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
# kill broken instances
|
# kill broken instances
|
||||||
connected_node_ids = set(
|
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
|
||||||
[x.node_id for x in self.state.topology.list_nodes()]
|
|
||||||
)
|
|
||||||
for instance_id, instance in self.state.instances.items():
|
for instance_id, instance in self.state.instances.items():
|
||||||
for node_id in instance.shard_assignments.node_to_runner:
|
for node_id in instance.shard_assignments.node_to_runner:
|
||||||
if node_id not in connected_node_ids:
|
if node_id not in connected_node_ids:
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ from typing import Sequence
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from exo.master.placement_utils import (
|
from exo.master.placement_utils import (
|
||||||
|
NodeWithProfile,
|
||||||
filter_cycles_by_memory,
|
filter_cycles_by_memory,
|
||||||
get_hosts_from_subgraph,
|
get_hosts_from_subgraph,
|
||||||
get_mlx_ibv_devices_matrix,
|
|
||||||
get_mlx_jaccl_coordinators,
|
get_mlx_jaccl_coordinators,
|
||||||
|
get_mlx_jaccl_devices_matrix,
|
||||||
get_shard_assignments,
|
get_shard_assignments,
|
||||||
get_smallest_cycles,
|
get_smallest_cycles,
|
||||||
)
|
)
|
||||||
@@ -19,10 +20,10 @@ from exo.shared.types.commands import (
|
|||||||
DeleteInstance,
|
DeleteInstance,
|
||||||
PlaceInstance,
|
PlaceInstance,
|
||||||
)
|
)
|
||||||
from exo.shared.types.common import Host
|
from exo.shared.types.common import Host, NodeId
|
||||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
from exo.shared.types.topology import NodeInfo
|
from exo.shared.types.profiling import NodePerformanceProfile
|
||||||
from exo.shared.types.worker.instances import (
|
from exo.shared.types.worker.instances import (
|
||||||
Instance,
|
Instance,
|
||||||
InstanceId,
|
InstanceId,
|
||||||
@@ -51,19 +52,16 @@ def place_instance(
|
|||||||
command: PlaceInstance,
|
command: PlaceInstance,
|
||||||
topology: Topology,
|
topology: Topology,
|
||||||
current_instances: Mapping[InstanceId, Instance],
|
current_instances: Mapping[InstanceId, Instance],
|
||||||
|
node_profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||||
) -> dict[InstanceId, Instance]:
|
) -> dict[InstanceId, Instance]:
|
||||||
all_nodes = list(topology.list_nodes())
|
all_nodes = list(topology.list_nodes())
|
||||||
|
|
||||||
logger.info("finding cycles:")
|
cycles = topology.get_cycles() + [[node] for node in all_nodes]
|
||||||
cycles = topology.get_cycles()
|
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
|
||||||
singleton_cycles = [[node] for node in all_nodes]
|
|
||||||
candidate_cycles = list(
|
|
||||||
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
|
|
||||||
)
|
|
||||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||||
candidate_cycles, command.model_meta.storage_size
|
candidate_cycles, node_profiles, command.model_meta.storage_size
|
||||||
)
|
)
|
||||||
if not cycles_with_sufficient_memory:
|
if len(cycles_with_sufficient_memory) == 0:
|
||||||
raise ValueError("No cycles found with sufficient memory")
|
raise ValueError("No cycles found with sufficient memory")
|
||||||
|
|
||||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||||
@@ -71,13 +69,15 @@ def place_instance(
|
|||||||
smallest_tb_cycles = [
|
smallest_tb_cycles = [
|
||||||
cycle
|
cycle
|
||||||
for cycle in smallest_cycles
|
for cycle in smallest_cycles
|
||||||
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
|
if topology.get_subgraph_from_nodes(
|
||||||
|
[node.node_id for node in cycle]
|
||||||
|
).is_thunderbolt_cycle([node.node_id for node in cycle])
|
||||||
]
|
]
|
||||||
|
|
||||||
if smallest_tb_cycles != []:
|
if smallest_tb_cycles != []:
|
||||||
smallest_cycles = smallest_tb_cycles
|
smallest_cycles = smallest_tb_cycles
|
||||||
|
|
||||||
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
|
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
|
||||||
cycle
|
cycle
|
||||||
for cycle in smallest_cycles
|
for cycle in smallest_cycles
|
||||||
if any(topology.node_is_leaf(node.node_id) for node in cycle)
|
if any(topology.node_is_leaf(node.node_id) for node in cycle)
|
||||||
@@ -86,11 +86,7 @@ def place_instance(
|
|||||||
selected_cycle = max(
|
selected_cycle = max(
|
||||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
||||||
key=lambda cycle: sum(
|
key=lambda cycle: sum(
|
||||||
(
|
(node.node_profile.memory.ram_available for node in cycle),
|
||||||
node.node_profile.memory.ram_available
|
|
||||||
for node in cycle
|
|
||||||
if node.node_profile is not None
|
|
||||||
),
|
|
||||||
start=Memory(),
|
start=Memory(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -99,14 +95,16 @@ def place_instance(
|
|||||||
command.model_meta, selected_cycle, command.sharding
|
command.model_meta, selected_cycle, command.sharding
|
||||||
)
|
)
|
||||||
|
|
||||||
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
|
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
|
||||||
|
[node.node_id for node in selected_cycle]
|
||||||
|
)
|
||||||
|
|
||||||
instance_id = InstanceId()
|
instance_id = InstanceId()
|
||||||
target_instances = dict(deepcopy(current_instances))
|
target_instances = dict(deepcopy(current_instances))
|
||||||
|
|
||||||
if len(selected_cycle) == 1:
|
if len(selected_cycle) == 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"You have likely selected ibv for a single node instance; falling back to MlxRing"
|
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
|
||||||
)
|
)
|
||||||
|
|
||||||
command.instance_meta = InstanceMeta.MlxRing
|
command.instance_meta = InstanceMeta.MlxRing
|
||||||
@@ -114,19 +112,18 @@ def place_instance(
|
|||||||
# TODO: Single node instances
|
# TODO: Single node instances
|
||||||
match command.instance_meta:
|
match command.instance_meta:
|
||||||
case InstanceMeta.MlxJaccl:
|
case InstanceMeta.MlxJaccl:
|
||||||
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
|
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
|
||||||
selected_cycle,
|
|
||||||
cycle_digraph,
|
cycle_digraph,
|
||||||
)
|
)
|
||||||
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
||||||
selected_cycle,
|
coordinator=selected_cycle[0].node_id,
|
||||||
coordinator_port=random_ephemeral_port(),
|
coordinator_port=random_ephemeral_port(),
|
||||||
cycle_digraph=cycle_digraph,
|
cycle_digraph=cycle_digraph,
|
||||||
)
|
)
|
||||||
target_instances[instance_id] = MlxJacclInstance(
|
target_instances[instance_id] = MlxJacclInstance(
|
||||||
instance_id=instance_id,
|
instance_id=instance_id,
|
||||||
shard_assignments=shard_assignments,
|
shard_assignments=shard_assignments,
|
||||||
ibv_devices=mlx_ibv_devices,
|
jaccl_devices=mlx_jaccl_devices,
|
||||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||||
)
|
)
|
||||||
case InstanceMeta.MlxRing:
|
case InstanceMeta.MlxRing:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import TypeGuard, cast
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -9,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
|
|||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
from exo.shared.types.models import ModelMetadata
|
from exo.shared.types.models import ModelMetadata
|
||||||
from exo.shared.types.profiling import NodePerformanceProfile
|
from exo.shared.types.profiling import NodePerformanceProfile
|
||||||
from exo.shared.types.topology import NodeInfo
|
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||||
from exo.shared.types.worker.shards import (
|
from exo.shared.types.worker.shards import (
|
||||||
PipelineShardMetadata,
|
PipelineShardMetadata,
|
||||||
@@ -24,27 +23,32 @@ class NodeWithProfile(BaseModel):
|
|||||||
node_profile: NodePerformanceProfile
|
node_profile: NodePerformanceProfile
|
||||||
|
|
||||||
|
|
||||||
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
|
|
||||||
return all(node.node_profile is not None for node in nodes)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_cycles_by_memory(
|
def filter_cycles_by_memory(
|
||||||
cycles: list[list[NodeInfo]], required_memory: Memory
|
cycles: list[list[NodeId]],
|
||||||
) -> list[list[NodeInfo]]:
|
node_profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||||
filtered_cycles: list[list[NodeInfo]] = []
|
required_memory: Memory,
|
||||||
|
) -> list[list[NodeWithProfile]]:
|
||||||
|
filtered_cycles: list[list[NodeWithProfile]] = []
|
||||||
for cycle in cycles:
|
for cycle in cycles:
|
||||||
if not narrow_all_nodes(cycle):
|
if not all(node in node_profiles for node in cycle):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
total_mem = sum(
|
total_mem = sum(
|
||||||
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
|
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
|
||||||
)
|
)
|
||||||
if total_mem >= required_memory:
|
if total_mem >= required_memory:
|
||||||
filtered_cycles.append(cast(list[NodeInfo], cycle))
|
filtered_cycles.append(
|
||||||
|
[
|
||||||
|
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
|
||||||
|
for node in cycle
|
||||||
|
]
|
||||||
|
)
|
||||||
return filtered_cycles
|
return filtered_cycles
|
||||||
|
|
||||||
|
|
||||||
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
|
def get_smallest_cycles(
|
||||||
|
cycles: list[list[NodeWithProfile]],
|
||||||
|
) -> list[list[NodeWithProfile]]:
|
||||||
min_nodes = min(len(cycle) for cycle in cycles)
|
min_nodes = min(len(cycle) for cycle in cycles)
|
||||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||||
|
|
||||||
@@ -135,11 +139,9 @@ def get_shard_assignments_for_tensor_parallel(
|
|||||||
|
|
||||||
def get_shard_assignments(
|
def get_shard_assignments(
|
||||||
model_meta: ModelMetadata,
|
model_meta: ModelMetadata,
|
||||||
selected_cycle: list[NodeInfo],
|
selected_cycle: list[NodeWithProfile],
|
||||||
sharding: Sharding,
|
sharding: Sharding,
|
||||||
) -> ShardAssignments:
|
) -> ShardAssignments:
|
||||||
if not narrow_all_nodes(selected_cycle):
|
|
||||||
raise ValueError("All nodes must have profiles to create shard assignments")
|
|
||||||
match sharding:
|
match sharding:
|
||||||
case Sharding.Pipeline:
|
case Sharding.Pipeline:
|
||||||
return get_shard_assignments_for_pipeline_parallel(
|
return get_shard_assignments_for_pipeline_parallel(
|
||||||
@@ -176,17 +178,16 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
|||||||
current_node = cycle[i]
|
current_node = cycle[i]
|
||||||
next_node = cycle[(i + 1) % len(cycle)]
|
next_node = cycle[(i + 1) % len(cycle)]
|
||||||
|
|
||||||
for connection in cycle_digraph.list_connections():
|
for src, sink, connection in cycle_digraph.list_connections():
|
||||||
if (
|
if not isinstance(connection, SocketConnection):
|
||||||
connection.local_node_id == current_node.node_id
|
continue
|
||||||
and connection.send_back_node_id == next_node.node_id
|
|
||||||
):
|
if src == current_node and sink == next_node:
|
||||||
if get_thunderbolt and not connection.is_thunderbolt():
|
if get_thunderbolt and not connection.is_thunderbolt():
|
||||||
continue
|
continue
|
||||||
assert connection.send_back_multiaddr is not None
|
|
||||||
host = Host(
|
host = Host(
|
||||||
ip=connection.send_back_multiaddr.ip_address,
|
ip=connection.sink_multiaddr.ip_address,
|
||||||
port=connection.send_back_multiaddr.port,
|
port=connection.sink_multiaddr.port,
|
||||||
)
|
)
|
||||||
hosts.append(host)
|
hosts.append(host)
|
||||||
break
|
break
|
||||||
@@ -194,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
|||||||
return hosts
|
return hosts
|
||||||
|
|
||||||
|
|
||||||
def get_mlx_ibv_devices_matrix(
|
def get_mlx_jaccl_devices_matrix(
|
||||||
selected_cycle: list[NodeInfo],
|
|
||||||
cycle_digraph: Topology,
|
cycle_digraph: Topology,
|
||||||
) -> list[list[str | None]]:
|
) -> list[list[str | None]]:
|
||||||
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
|
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
|
||||||
@@ -204,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
|
|||||||
to device j, or None if no connection exists or no interface name is found.
|
to device j, or None if no connection exists or no interface name is found.
|
||||||
Diagonal elements are always None.
|
Diagonal elements are always None.
|
||||||
"""
|
"""
|
||||||
|
selected_cycle = list(cycle_digraph.list_nodes())
|
||||||
num_nodes = len(selected_cycle)
|
num_nodes = len(selected_cycle)
|
||||||
matrix: list[list[str | None]] = [
|
matrix: list[list[str | None]] = [
|
||||||
[None for _ in range(num_nodes)] for _ in range(num_nodes)
|
[None for _ in range(num_nodes)] for _ in range(num_nodes)
|
||||||
@@ -214,86 +215,55 @@ def get_mlx_ibv_devices_matrix(
|
|||||||
if i == j:
|
if i == j:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Find the IP J uses to talk to I
|
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||||
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
|
if isinstance(conn, RDMAConnection):
|
||||||
# This is a local IP on I, which is attached to an interface: find that interface
|
matrix[i][j] = conn.source_rdma_iface
|
||||||
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
|
|
||||||
matrix[i][j] = interface_name
|
|
||||||
logger.info(
|
|
||||||
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
|
||||||
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Current ibv backend requires all-to-all rdma connections"
|
"Current jaccl backend requires all-to-all RDMA connections"
|
||||||
)
|
)
|
||||||
|
|
||||||
return matrix
|
return matrix
|
||||||
|
|
||||||
|
|
||||||
def _find_connection_ip(
|
def _find_connection_ip(
|
||||||
node_i: NodeInfo,
|
node_i: NodeId,
|
||||||
node_j: NodeInfo,
|
node_j: NodeId,
|
||||||
cycle_digraph: Topology,
|
cycle_digraph: Topology,
|
||||||
) -> Generator[str]:
|
) -> Generator[str]:
|
||||||
"""Find all IP addresses that connect node i to node j."""
|
"""Find all IP addresses that connect node i to node j."""
|
||||||
for connection in cycle_digraph.list_connections():
|
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
|
||||||
if (
|
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||||
connection.local_node_id == node_i.node_id
|
if isinstance(connection, SocketConnection):
|
||||||
and connection.send_back_node_id == node_j.node_id
|
yield connection.sink_multiaddr.ip_address
|
||||||
):
|
|
||||||
yield connection.send_back_multiaddr.ip_address
|
|
||||||
|
|
||||||
|
|
||||||
def _find_interface_name_for_ip(
|
|
||||||
ip_address: str,
|
|
||||||
node_info: NodeInfo,
|
|
||||||
) -> str | None:
|
|
||||||
if node_info.node_profile is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
|
|
||||||
for interface in node_info.node_profile.network_interfaces:
|
|
||||||
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
|
|
||||||
continue
|
|
||||||
logger.info(f" | {interface.name}: {interface.ip_address}")
|
|
||||||
if interface.ip_address != ip_address:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info("Found")
|
|
||||||
return f"rdma_{interface.name}"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_mlx_jaccl_coordinators(
|
def get_mlx_jaccl_coordinators(
|
||||||
selected_cycle: list[NodeInfo],
|
coordinator: NodeId,
|
||||||
coordinator_port: int,
|
coordinator_port: int,
|
||||||
cycle_digraph: Topology,
|
cycle_digraph: Topology,
|
||||||
) -> dict[NodeId, str]:
|
) -> dict[NodeId, str]:
|
||||||
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
|
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
|
||||||
|
|
||||||
Select an IP address that each node can reach for the rank 0 node. Returns
|
Select an IP address that each node can reach for the rank 0 node. Returns
|
||||||
address in format "X.X.X.X:PORT" per node.
|
address in format "X.X.X.X:PORT" per node.
|
||||||
"""
|
"""
|
||||||
rank_0_node = selected_cycle[0]
|
selected_cycle = list(cycle_digraph.list_nodes())
|
||||||
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
logger.info(f"Selecting coordinator: {coordinator}")
|
||||||
|
|
||||||
def get_ip_for_node(n: NodeInfo) -> str:
|
def get_ip_for_node(n: NodeId) -> str:
|
||||||
if n.node_id == rank_0_node.node_id:
|
if n == coordinator:
|
||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
|
|
||||||
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
|
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Current jaccl backend requires all participating devices to be able to communicate"
|
||||||
)
|
)
|
||||||
raise ValueError("Current ibv backend requires all-to-all rdma connections")
|
|
||||||
|
|
||||||
return {
|
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
|
||||||
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,67 +1,36 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from exo.shared.types.common import NodeId
|
|
||||||
from exo.shared.types.multiaddr import Multiaddr
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.profiling import (
|
from exo.shared.types.profiling import (
|
||||||
MemoryPerformanceProfile,
|
MemoryUsage,
|
||||||
NodePerformanceProfile,
|
NodePerformanceProfile,
|
||||||
SystemPerformanceProfile,
|
SystemPerformanceProfile,
|
||||||
)
|
)
|
||||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def create_node_profile(memory: int) -> NodePerformanceProfile:
|
||||||
def create_node():
|
return NodePerformanceProfile(
|
||||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
model_id="test",
|
||||||
if node_id is None:
|
chip_id="test",
|
||||||
node_id = NodeId()
|
friendly_name="test",
|
||||||
return NodeInfo(
|
memory=MemoryUsage.from_bytes(
|
||||||
node_id=node_id,
|
ram_total=1000,
|
||||||
node_profile=NodePerformanceProfile(
|
ram_available=memory,
|
||||||
model_id="test",
|
swap_total=1000,
|
||||||
chip_id="test",
|
swap_available=1000,
|
||||||
friendly_name="test",
|
),
|
||||||
memory=MemoryPerformanceProfile.from_bytes(
|
network_interfaces=[],
|
||||||
ram_total=1000,
|
system=SystemPerformanceProfile(),
|
||||||
ram_available=memory,
|
)
|
||||||
swap_total=1000,
|
|
||||||
swap_available=1000,
|
|
||||||
),
|
|
||||||
network_interfaces=[],
|
|
||||||
system=SystemPerformanceProfile(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return _create_node
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||||
@pytest.fixture
|
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
return SocketConnection(
|
||||||
port_counter = 1235
|
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||||
ip_counter = 1
|
)
|
||||||
|
|
||||||
def _create_connection(
|
|
||||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
|
||||||
) -> Connection:
|
|
||||||
nonlocal port_counter
|
|
||||||
nonlocal ip_counter
|
|
||||||
# assign unique ips
|
|
||||||
ip_counter += 1
|
|
||||||
if send_back_port is None:
|
|
||||||
send_back_port = port_counter
|
|
||||||
port_counter += 1
|
|
||||||
return Connection(
|
|
||||||
local_node_id=source_node_id,
|
|
||||||
send_back_node_id=sink_node_id,
|
|
||||||
send_back_multiaddr=Multiaddr(
|
|
||||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
|
||||||
),
|
|
||||||
connection_profile=ConnectionProfile(
|
|
||||||
throughput=1000, latency=1000, jitter=1000
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return _create_connection
|
def create_rdma_connection(iface: int) -> RDMAConnection:
|
||||||
|
return RDMAConnection(
|
||||||
|
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -19,15 +19,13 @@ from exo.shared.types.events import (
|
|||||||
ForwarderEvent,
|
ForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InstanceCreated,
|
InstanceCreated,
|
||||||
NodePerformanceMeasured,
|
NodeGatheredInfo,
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
)
|
)
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
from exo.shared.types.models import ModelId, ModelMetadata
|
from exo.shared.types.models import ModelId, ModelMetadata
|
||||||
from exo.shared.types.profiling import (
|
from exo.shared.types.profiling import (
|
||||||
MemoryPerformanceProfile,
|
MemoryUsage,
|
||||||
NodePerformanceProfile,
|
|
||||||
SystemPerformanceProfile,
|
|
||||||
)
|
)
|
||||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||||
from exo.shared.types.tasks import TaskStatus
|
from exo.shared.types.tasks import TaskStatus
|
||||||
@@ -83,21 +81,14 @@ async def test_master():
|
|||||||
origin=sender_node_id,
|
origin=sender_node_id,
|
||||||
session=session_id,
|
session=session_id,
|
||||||
event=(
|
event=(
|
||||||
NodePerformanceMeasured(
|
NodeGatheredInfo(
|
||||||
when=str(datetime.now(tz=timezone.utc)),
|
when=str(datetime.now(tz=timezone.utc)),
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_profile=NodePerformanceProfile(
|
info=MemoryUsage(
|
||||||
model_id="maccy",
|
ram_total=Memory.from_bytes(678948 * 1024),
|
||||||
chip_id="arm",
|
ram_available=Memory.from_bytes(678948 * 1024),
|
||||||
friendly_name="test",
|
swap_total=Memory.from_bytes(0),
|
||||||
memory=MemoryPerformanceProfile(
|
swap_available=Memory.from_bytes(0),
|
||||||
ram_total=Memory.from_bytes(678948 * 1024),
|
|
||||||
ram_available=Memory.from_bytes(678948 * 1024),
|
|
||||||
swap_total=Memory.from_bytes(0),
|
|
||||||
swap_available=Memory.from_bytes(0),
|
|
||||||
),
|
|
||||||
network_interfaces=[],
|
|
||||||
system=SystemPerformanceProfile(),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@@ -161,7 +152,7 @@ async def test_master():
|
|||||||
assert events[0].idx == 0
|
assert events[0].idx == 0
|
||||||
assert events[1].idx == 1
|
assert events[1].idx == 1
|
||||||
assert events[2].idx == 2
|
assert events[2].idx == 2
|
||||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
assert isinstance(events[0].event, NodeGatheredInfo)
|
||||||
assert isinstance(events[1].event, InstanceCreated)
|
assert isinstance(events[1].event, InstanceCreated)
|
||||||
runner_id = list(
|
runner_id = list(
|
||||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@@ -7,14 +5,20 @@ from exo.master.placement import (
|
|||||||
get_transition_events,
|
get_transition_events,
|
||||||
place_instance,
|
place_instance,
|
||||||
)
|
)
|
||||||
|
from exo.master.tests.conftest import (
|
||||||
|
create_connection,
|
||||||
|
create_node_profile,
|
||||||
|
create_rdma_connection,
|
||||||
|
)
|
||||||
from exo.shared.topology import Topology
|
from exo.shared.topology import Topology
|
||||||
from exo.shared.types.commands import PlaceInstance
|
from exo.shared.types.commands import PlaceInstance
|
||||||
from exo.shared.types.common import CommandId, NodeId
|
from exo.shared.types.common import CommandId, NodeId
|
||||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
from exo.shared.types.models import ModelId, ModelMetadata
|
from exo.shared.types.models import ModelId, ModelMetadata
|
||||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.topology import Connection, NodeInfo
|
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||||
|
from exo.shared.types.topology import SocketConnection
|
||||||
from exo.shared.types.worker.instances import (
|
from exo.shared.types.worker.instances import (
|
||||||
Instance,
|
Instance,
|
||||||
InstanceId,
|
InstanceId,
|
||||||
@@ -26,11 +30,6 @@ from exo.shared.types.worker.runners import ShardAssignments
|
|||||||
from exo.shared.types.worker.shards import Sharding
|
from exo.shared.types.worker.shards import Sharding
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def topology() -> Topology:
|
|
||||||
return Topology()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def instance() -> Instance:
|
def instance() -> Instance:
|
||||||
return MlxRingInstance(
|
return MlxRingInstance(
|
||||||
@@ -74,30 +73,33 @@ def test_get_instance_placements_create_instance(
|
|||||||
available_memory: tuple[int, int, int],
|
available_memory: tuple[int, int, int],
|
||||||
total_layers: int,
|
total_layers: int,
|
||||||
expected_layers: tuple[int, int, int],
|
expected_layers: tuple[int, int, int],
|
||||||
topology: Topology,
|
|
||||||
model_meta: ModelMetadata,
|
model_meta: ModelMetadata,
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
):
|
||||||
# arrange
|
# arrange
|
||||||
model_meta.n_layers = total_layers
|
model_meta.n_layers = total_layers
|
||||||
model_meta.storage_size.in_bytes = sum(
|
model_meta.storage_size.in_bytes = sum(
|
||||||
available_memory
|
available_memory
|
||||||
) # make it exactly fit across all nodes
|
) # make it exactly fit across all nodes
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
cic = place_instance_command(model_meta)
|
cic = place_instance_command(model_meta)
|
||||||
node_id_a = NodeId()
|
node_id_a = NodeId()
|
||||||
node_id_b = NodeId()
|
node_id_b = NodeId()
|
||||||
node_id_c = NodeId()
|
node_id_c = NodeId()
|
||||||
topology.add_node(create_node(available_memory[0], node_id_a))
|
profiles = {
|
||||||
topology.add_node(create_node(available_memory[1], node_id_b))
|
node_id_a: create_node_profile(available_memory[0]),
|
||||||
topology.add_node(create_node(available_memory[2], node_id_c))
|
node_id_b: create_node_profile(available_memory[1]),
|
||||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
node_id_c: create_node_profile(available_memory[2]),
|
||||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
}
|
||||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
topology.add_node(node_id_a)
|
||||||
|
topology.add_node(node_id_b)
|
||||||
|
topology.add_node(node_id_c)
|
||||||
|
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||||
|
topology.add_connection(node_id_b, node_id_c, create_connection(2))
|
||||||
|
topology.add_connection(node_id_c, node_id_a, create_connection(3))
|
||||||
|
|
||||||
# act
|
# act
|
||||||
placements = place_instance(cic, topology, {})
|
placements = place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(placements) == 1
|
assert len(placements) == 1
|
||||||
@@ -123,12 +125,11 @@ def test_get_instance_placements_create_instance(
|
|||||||
assert shards_sorted[-1].end_layer == total_layers
|
assert shards_sorted[-1].end_layer == total_layers
|
||||||
|
|
||||||
|
|
||||||
def test_get_instance_placements_one_node_exact_fit(
|
def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
) -> None:
|
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
node_id = NodeId()
|
node_id = NodeId()
|
||||||
topology.add_node(create_node(1000 * 1024, node_id))
|
topology.add_node(node_id)
|
||||||
|
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||||
cic = place_instance_command(
|
cic = place_instance_command(
|
||||||
ModelMetadata(
|
ModelMetadata(
|
||||||
model_id=ModelId("test-model"),
|
model_id=ModelId("test-model"),
|
||||||
@@ -137,7 +138,7 @@ def test_get_instance_placements_one_node_exact_fit(
|
|||||||
n_layers=10,
|
n_layers=10,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
placements = place_instance(cic, topology, {})
|
placements = place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
assert len(placements) == 1
|
assert len(placements) == 1
|
||||||
instance_id = list(placements.keys())[0]
|
instance_id = list(placements.keys())[0]
|
||||||
@@ -148,12 +149,11 @@ def test_get_instance_placements_one_node_exact_fit(
|
|||||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_get_instance_placements_one_node_fits_with_extra_memory(
|
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
) -> None:
|
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
node_id = NodeId()
|
node_id = NodeId()
|
||||||
topology.add_node(create_node(1001 * 1024, node_id))
|
topology.add_node(node_id)
|
||||||
|
profiles = {node_id: create_node_profile(1001 * 1024)}
|
||||||
cic = place_instance_command(
|
cic = place_instance_command(
|
||||||
ModelMetadata(
|
ModelMetadata(
|
||||||
model_id=ModelId("test-model"),
|
model_id=ModelId("test-model"),
|
||||||
@@ -162,7 +162,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
|||||||
n_layers=10,
|
n_layers=10,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
placements = place_instance(cic, topology, {})
|
placements = place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
assert len(placements) == 1
|
assert len(placements) == 1
|
||||||
instance_id = list(placements.keys())[0]
|
instance_id = list(placements.keys())[0]
|
||||||
@@ -173,12 +173,11 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
|||||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_get_instance_placements_one_node_not_fit(
|
def test_get_instance_placements_one_node_not_fit() -> None:
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
) -> None:
|
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
node_id = NodeId()
|
node_id = NodeId()
|
||||||
topology.add_node(create_node(1000 * 1024, node_id))
|
topology.add_node(node_id)
|
||||||
|
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||||
cic = place_instance_command(
|
cic = place_instance_command(
|
||||||
model_meta=ModelMetadata(
|
model_meta=ModelMetadata(
|
||||||
model_id=ModelId("test-model"),
|
model_id=ModelId("test-model"),
|
||||||
@@ -189,7 +188,7 @@ def test_get_instance_placements_one_node_not_fit(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
|
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
|
||||||
place_instance(cic, topology, {})
|
place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
|
|
||||||
def test_get_transition_events_no_change(instance: Instance):
|
def test_get_transition_events_no_change(instance: Instance):
|
||||||
@@ -235,190 +234,102 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
|||||||
|
|
||||||
|
|
||||||
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||||
topology: Topology,
|
|
||||||
model_meta: ModelMetadata,
|
model_meta: ModelMetadata,
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
):
|
||||||
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
|
# arrange
|
||||||
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
|
topology = Topology()
|
||||||
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
|
|
||||||
# cycle having LESS total memory than D-E-F. The algorithm should still choose
|
|
||||||
# the cycle that contains a leaf node.
|
|
||||||
|
|
||||||
# Model requires more than any single node but fits within a 3-node cycle
|
model_meta.storage_size = Memory.from_bytes(1000)
|
||||||
model_meta.storage_size.in_bytes = 1500
|
|
||||||
model_meta.n_layers = 12
|
|
||||||
|
|
||||||
# Create node ids
|
|
||||||
node_id_a = NodeId()
|
node_id_a = NodeId()
|
||||||
node_id_b = NodeId()
|
node_id_b = NodeId()
|
||||||
node_id_c = NodeId()
|
node_id_c = NodeId()
|
||||||
node_id_d = NodeId()
|
node_id_d = NodeId()
|
||||||
node_id_e = NodeId()
|
|
||||||
node_id_f = NodeId()
|
|
||||||
|
|
||||||
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
|
profiles = {
|
||||||
node_id_x = NodeId()
|
node_id_a: create_node_profile(500),
|
||||||
node_id_y = NodeId()
|
node_id_b: create_node_profile(600),
|
||||||
node_id_z = NodeId()
|
node_id_c: create_node_profile(600),
|
||||||
|
node_id_d: create_node_profile(500),
|
||||||
|
}
|
||||||
|
|
||||||
# A-B-C cycle total memory = 1600 (< D-E-F total)
|
topology.add_node(node_id_a)
|
||||||
topology.add_node(create_node(400, node_id_a))
|
topology.add_node(node_id_b)
|
||||||
topology.add_node(create_node(400, node_id_b))
|
topology.add_node(node_id_c)
|
||||||
topology.add_node(create_node(800, node_id_c))
|
topology.add_node(node_id_d)
|
||||||
|
|
||||||
# D-E-F cycle total memory = 1800 (> A-B-C total)
|
# Daisy chain topology
|
||||||
topology.add_node(create_node(600, node_id_d))
|
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||||
topology.add_node(create_node(600, node_id_e))
|
topology.add_connection(node_id_b, node_id_a, create_connection(1))
|
||||||
topology.add_node(create_node(600, node_id_f))
|
topology.add_connection(node_id_b, node_id_c, create_connection(1))
|
||||||
|
topology.add_connection(node_id_c, node_id_b, create_connection(1))
|
||||||
|
topology.add_connection(node_id_c, node_id_d, create_connection(1))
|
||||||
|
topology.add_connection(node_id_d, node_id_c, create_connection(1))
|
||||||
|
|
||||||
# Extra nodes with tiny memory so they can't form singleton placements
|
logger.info(list(topology.list_connections()))
|
||||||
topology.add_node(create_node(10, node_id_x))
|
|
||||||
topology.add_node(create_node(10, node_id_y))
|
|
||||||
topology.add_node(create_node(10, node_id_z))
|
|
||||||
|
|
||||||
# Build directed cycles
|
|
||||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
|
||||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
|
||||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
|
||||||
|
|
||||||
topology.add_connection(create_connection(node_id_d, node_id_e))
|
|
||||||
topology.add_connection(create_connection(node_id_e, node_id_f))
|
|
||||||
topology.add_connection(create_connection(node_id_f, node_id_d))
|
|
||||||
|
|
||||||
# Add extra outgoing edges from D/E/F so none of them are leaves
|
|
||||||
topology.add_connection(create_connection(node_id_d, node_id_x))
|
|
||||||
topology.add_connection(create_connection(node_id_e, node_id_y))
|
|
||||||
topology.add_connection(create_connection(node_id_f, node_id_z))
|
|
||||||
|
|
||||||
cic = place_instance_command(
|
cic = place_instance_command(
|
||||||
model_meta=model_meta,
|
model_meta=model_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
# act
|
||||||
placements = place_instance(cic, topology, {})
|
placements = place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
|
# assert
|
||||||
# D-E-F has more total memory.
|
|
||||||
assert len(placements) == 1
|
assert len(placements) == 1
|
||||||
instance_id = list(placements.keys())[0]
|
instance = list(placements.values())[0]
|
||||||
instance = placements[instance_id]
|
|
||||||
|
|
||||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||||
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
|
||||||
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
(node_id_c, node_id_d)
|
||||||
|
)
|
||||||
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
|
|
||||||
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_rdma_backend_connectivity_matrix(
|
def test_tensor_rdma_backend_connectivity_matrix(
|
||||||
topology: Topology,
|
|
||||||
model_meta: ModelMetadata,
|
model_meta: ModelMetadata,
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
):
|
||||||
|
topology = Topology()
|
||||||
model_meta.n_layers = 12
|
model_meta.n_layers = 12
|
||||||
model_meta.storage_size.in_bytes = 1500
|
model_meta.storage_size.in_bytes = 1500
|
||||||
|
|
||||||
node_id_a = NodeId()
|
node_a = NodeId()
|
||||||
node_id_b = NodeId()
|
node_b = NodeId()
|
||||||
node_id_c = NodeId()
|
node_c = NodeId()
|
||||||
|
|
||||||
node_a = create_node(500, node_id_a)
|
profiles = {
|
||||||
node_b = create_node(500, node_id_b)
|
node_a: create_node_profile(500),
|
||||||
node_c = create_node(500, node_id_c)
|
node_b: create_node_profile(500),
|
||||||
|
node_c: create_node_profile(500),
|
||||||
|
}
|
||||||
|
|
||||||
ethernet_interface = NetworkInterfaceInfo(
|
ethernet_interface = NetworkInterfaceInfo(
|
||||||
name="en0",
|
name="en0",
|
||||||
ip_address="192.168.1.100",
|
ip_address="192.168.1.100",
|
||||||
)
|
)
|
||||||
|
ethernet_conn = SocketConnection(
|
||||||
assert node_a.node_profile is not None
|
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
|
||||||
assert node_b.node_profile is not None
|
|
||||||
assert node_c.node_profile is not None
|
|
||||||
|
|
||||||
conn_a_b = create_connection(node_id_a, node_id_b)
|
|
||||||
conn_b_c = create_connection(node_id_b, node_id_c)
|
|
||||||
conn_c_a = create_connection(node_id_c, node_id_a)
|
|
||||||
|
|
||||||
conn_b_a = create_connection(node_id_b, node_id_a)
|
|
||||||
conn_c_b = create_connection(node_id_c, node_id_b)
|
|
||||||
conn_a_c = create_connection(node_id_a, node_id_c)
|
|
||||||
|
|
||||||
assert conn_a_b.send_back_multiaddr is not None
|
|
||||||
assert conn_b_c.send_back_multiaddr is not None
|
|
||||||
assert conn_c_a.send_back_multiaddr is not None
|
|
||||||
|
|
||||||
assert conn_b_a.send_back_multiaddr is not None
|
|
||||||
assert conn_c_b.send_back_multiaddr is not None
|
|
||||||
assert conn_a_c.send_back_multiaddr is not None
|
|
||||||
|
|
||||||
node_a.node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=node_a.node_profile.memory,
|
|
||||||
network_interfaces=[
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
ethernet_interface,
|
|
||||||
],
|
|
||||||
system=node_a.node_profile.system,
|
|
||||||
)
|
|
||||||
node_b.node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=node_b.node_profile.memory,
|
|
||||||
network_interfaces=[
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
ethernet_interface,
|
|
||||||
],
|
|
||||||
system=node_b.node_profile.system,
|
|
||||||
)
|
|
||||||
node_c.node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=node_c.node_profile.memory,
|
|
||||||
network_interfaces=[
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
ethernet_interface,
|
|
||||||
],
|
|
||||||
system=node_c.node_profile.system,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
profiles[node_a].network_interfaces = [ethernet_interface]
|
||||||
|
profiles[node_b].network_interfaces = [ethernet_interface]
|
||||||
|
profiles[node_c].network_interfaces = [ethernet_interface]
|
||||||
|
|
||||||
topology.add_node(node_a)
|
topology.add_node(node_a)
|
||||||
topology.add_node(node_b)
|
topology.add_node(node_b)
|
||||||
topology.add_node(node_c)
|
topology.add_node(node_c)
|
||||||
topology.add_connection(conn_a_b)
|
topology.add_connection(node_a, node_b, create_rdma_connection(3))
|
||||||
topology.add_connection(conn_b_c)
|
topology.add_connection(node_b, node_c, create_rdma_connection(4))
|
||||||
topology.add_connection(conn_c_a)
|
topology.add_connection(node_c, node_a, create_rdma_connection(5))
|
||||||
topology.add_connection(conn_b_a)
|
topology.add_connection(node_b, node_a, create_rdma_connection(3))
|
||||||
topology.add_connection(conn_c_b)
|
topology.add_connection(node_c, node_b, create_rdma_connection(4))
|
||||||
topology.add_connection(conn_a_c)
|
topology.add_connection(node_a, node_c, create_rdma_connection(5))
|
||||||
|
|
||||||
|
topology.add_connection(node_a, node_b, ethernet_conn)
|
||||||
|
topology.add_connection(node_b, node_c, ethernet_conn)
|
||||||
|
topology.add_connection(node_c, node_a, ethernet_conn)
|
||||||
|
topology.add_connection(node_a, node_c, ethernet_conn)
|
||||||
|
topology.add_connection(node_b, node_a, ethernet_conn)
|
||||||
|
topology.add_connection(node_c, node_b, ethernet_conn)
|
||||||
|
|
||||||
cic = PlaceInstance(
|
cic = PlaceInstance(
|
||||||
sharding=Sharding.Tensor,
|
sharding=Sharding.Tensor,
|
||||||
@@ -428,7 +339,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
|||||||
min_nodes=1,
|
min_nodes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
placements = place_instance(cic, topology, {})
|
placements = place_instance(cic, topology, {}, profiles)
|
||||||
|
|
||||||
assert len(placements) == 1
|
assert len(placements) == 1
|
||||||
instance_id = list(placements.keys())[0]
|
instance_id = list(placements.keys())[0]
|
||||||
@@ -436,10 +347,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
|||||||
|
|
||||||
assert isinstance(instance, MlxJacclInstance)
|
assert isinstance(instance, MlxJacclInstance)
|
||||||
|
|
||||||
assert instance.ibv_devices is not None
|
assert instance.jaccl_devices is not None
|
||||||
assert instance.jaccl_coordinators is not None
|
assert instance.jaccl_coordinators is not None
|
||||||
|
|
||||||
matrix = instance.ibv_devices
|
matrix = instance.jaccl_devices
|
||||||
assert len(matrix) == 3
|
assert len(matrix) == 3
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
@@ -448,15 +359,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
|||||||
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
||||||
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
||||||
|
|
||||||
idx_a = node_to_idx[node_id_a]
|
idx_a = node_to_idx[node_a]
|
||||||
idx_b = node_to_idx[node_id_b]
|
idx_b = node_to_idx[node_b]
|
||||||
idx_c = node_to_idx[node_id_c]
|
idx_c = node_to_idx[node_c]
|
||||||
|
|
||||||
logger.info(matrix)
|
logger.info(matrix)
|
||||||
|
|
||||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||||
|
|
||||||
# Verify coordinators are set for all nodes
|
# Verify coordinators are set for all nodes
|
||||||
assert len(instance.jaccl_coordinators) == 3
|
assert len(instance.jaccl_coordinators) == 3
|
||||||
|
|||||||
@@ -1,56 +1,48 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from exo.master.placement_utils import (
|
from exo.master.placement_utils import (
|
||||||
|
NodeWithProfile,
|
||||||
filter_cycles_by_memory,
|
filter_cycles_by_memory,
|
||||||
get_hosts_from_subgraph,
|
get_hosts_from_subgraph,
|
||||||
get_mlx_jaccl_coordinators,
|
get_mlx_jaccl_coordinators,
|
||||||
get_shard_assignments,
|
get_shard_assignments,
|
||||||
get_smallest_cycles,
|
get_smallest_cycles,
|
||||||
)
|
)
|
||||||
|
from exo.master.tests.conftest import create_connection, create_node_profile
|
||||||
from exo.shared.topology import Topology
|
from exo.shared.topology import Topology
|
||||||
from exo.shared.types.common import Host, NodeId
|
from exo.shared.types.common import Host, NodeId
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
from exo.shared.types.models import ModelId, ModelMetadata
|
from exo.shared.types.models import ModelId, ModelMetadata
|
||||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
|
||||||
from exo.shared.types.topology import Connection, NodeInfo
|
|
||||||
from exo.shared.types.worker.shards import Sharding
|
from exo.shared.types.worker.shards import Sharding
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_filter_cycles_by_memory():
|
||||||
def topology() -> Topology:
|
|
||||||
topology = Topology()
|
|
||||||
return topology
|
|
||||||
|
|
||||||
|
|
||||||
def test_filter_cycles_by_memory(
|
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node1_id = NodeId()
|
node1_id = NodeId()
|
||||||
node2_id = NodeId()
|
node2_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node1 = create_node(1000 * 1024, node1_id)
|
node1 = create_node_profile(1000 * 1024)
|
||||||
node2 = create_node(1000 * 1024, node2_id)
|
node2 = create_node_profile(1000 * 1024)
|
||||||
|
node_profiles = {node1_id: node1, node2_id: node2}
|
||||||
|
|
||||||
topology.add_node(node1)
|
topology.add_node(node1_id)
|
||||||
topology.add_node(node2)
|
topology.add_node(node2_id)
|
||||||
|
|
||||||
connection1 = create_connection(node1_id, node2_id)
|
connection1 = create_connection(1)
|
||||||
connection2 = create_connection(node2_id, node1_id)
|
connection2 = create_connection(2)
|
||||||
|
|
||||||
topology.add_connection(connection1)
|
topology.add_connection(node1_id, node2_id, connection1)
|
||||||
topology.add_connection(connection2)
|
topology.add_connection(node2_id, node1_id, connection2)
|
||||||
|
|
||||||
cycles = topology.get_cycles()
|
cycles = topology.get_cycles()
|
||||||
assert len(cycles) == 1
|
assert len(cycles) == 1
|
||||||
assert len(cycles[0]) == 2
|
assert len(cycles[0]) == 2
|
||||||
|
|
||||||
# act
|
# act
|
||||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
|
filtered_cycles = filter_cycles_by_memory(
|
||||||
|
cycles, node_profiles, Memory.from_bytes(1)
|
||||||
|
)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(filtered_cycles) == 1
|
assert len(filtered_cycles) == 1
|
||||||
@@ -58,64 +50,65 @@ def test_filter_cycles_by_memory(
|
|||||||
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
|
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
|
||||||
|
|
||||||
|
|
||||||
def test_filter_cycles_by_insufficient_memory(
|
def test_filter_cycles_by_insufficient_memory():
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node1_id = NodeId()
|
node1_id = NodeId()
|
||||||
node2_id = NodeId()
|
node2_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node1 = create_node(1000 * 1024, node1_id)
|
node1 = create_node_profile(1000 * 1024)
|
||||||
node2 = create_node(1000 * 1024, node2_id)
|
node2 = create_node_profile(1000 * 1024)
|
||||||
|
node_profiles = {node1_id: node1, node2_id: node2}
|
||||||
|
|
||||||
topology.add_node(node1)
|
topology.add_node(node1_id)
|
||||||
topology.add_node(node2)
|
topology.add_node(node2_id)
|
||||||
|
|
||||||
connection1 = create_connection(node1_id, node2_id)
|
connection1 = create_connection(1)
|
||||||
connection2 = create_connection(node2_id, node1_id)
|
connection2 = create_connection(2)
|
||||||
|
|
||||||
topology.add_connection(connection1)
|
topology.add_connection(node1_id, node2_id, connection1)
|
||||||
topology.add_connection(connection2)
|
topology.add_connection(node2_id, node1_id, connection2)
|
||||||
|
|
||||||
# act
|
# act
|
||||||
filtered_cycles = filter_cycles_by_memory(
|
filtered_cycles = filter_cycles_by_memory(
|
||||||
topology.get_cycles(), Memory.from_kb(2001)
|
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
|
||||||
)
|
)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(filtered_cycles) == 0
|
assert len(filtered_cycles) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_filter_multiple_cycles_by_memory(
|
def test_filter_multiple_cycles_by_memory():
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node_a_id = NodeId()
|
node_a_id = NodeId()
|
||||||
node_b_id = NodeId()
|
node_b_id = NodeId()
|
||||||
node_c_id = NodeId()
|
node_c_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node_a = create_node(500 * 1024, node_a_id)
|
node_a = create_node_profile(500 * 1024)
|
||||||
node_b = create_node(500 * 1024, node_b_id)
|
node_b = create_node_profile(500 * 1024)
|
||||||
node_c = create_node(1000 * 1024, node_c_id)
|
node_c = create_node_profile(1000 * 1024)
|
||||||
|
node_profiles = {
|
||||||
|
node_a_id: node_a,
|
||||||
|
node_b_id: node_b,
|
||||||
|
node_c_id: node_c,
|
||||||
|
}
|
||||||
|
|
||||||
topology.add_node(node_a)
|
topology.add_node(node_a_id)
|
||||||
topology.add_node(node_b)
|
topology.add_node(node_b_id)
|
||||||
topology.add_node(node_c)
|
topology.add_node(node_c_id)
|
||||||
|
|
||||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||||
|
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||||
topology.add_connection(create_connection(node_a_id, node_c_id))
|
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||||
topology.add_connection(create_connection(node_c_id, node_b_id))
|
|
||||||
|
|
||||||
cycles = topology.get_cycles()
|
cycles = topology.get_cycles()
|
||||||
|
|
||||||
# act
|
# act
|
||||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
|
filtered_cycles = filter_cycles_by_memory(
|
||||||
|
cycles, node_profiles, Memory.from_kb(1500)
|
||||||
|
)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(filtered_cycles) == 1
|
assert len(filtered_cycles) == 1
|
||||||
@@ -127,31 +120,38 @@ def test_filter_multiple_cycles_by_memory(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_get_smallest_cycles(
|
def test_get_smallest_cycles():
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node_a_id = NodeId()
|
node_a_id = NodeId()
|
||||||
node_b_id = NodeId()
|
node_b_id = NodeId()
|
||||||
node_c_id = NodeId()
|
node_c_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node_a = create_node(500 * 1024, node_a_id)
|
node_a = create_node_profile(500 * 1024)
|
||||||
node_b = create_node(500 * 1024, node_b_id)
|
node_b = create_node_profile(500 * 1024)
|
||||||
node_c = create_node(1000 * 1024, node_c_id)
|
node_c = create_node_profile(1000 * 1024)
|
||||||
|
node_profiles = {
|
||||||
|
node_a_id: node_a,
|
||||||
|
node_b_id: node_b,
|
||||||
|
node_c_id: node_c,
|
||||||
|
}
|
||||||
|
|
||||||
topology.add_node(node_a)
|
topology.add_node(node_a_id)
|
||||||
topology.add_node(node_b)
|
topology.add_node(node_b_id)
|
||||||
topology.add_node(node_c)
|
topology.add_node(node_c_id)
|
||||||
|
|
||||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||||
|
|
||||||
|
cycles = [
|
||||||
|
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||||
|
for cycle in topology.get_cycles()
|
||||||
|
]
|
||||||
|
|
||||||
# act
|
# act
|
||||||
smallest_cycles = get_smallest_cycles(topology.get_cycles())
|
smallest_cycles = get_smallest_cycles(cycles)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(smallest_cycles) == 1
|
assert len(smallest_cycles) == 1
|
||||||
@@ -168,9 +168,6 @@ def test_get_smallest_cycles(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_get_shard_assignments(
|
def test_get_shard_assignments(
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
|
||||||
available_memory: tuple[int, int, int],
|
available_memory: tuple[int, int, int],
|
||||||
total_layers: int,
|
total_layers: int,
|
||||||
expected_layers: tuple[int, int, int],
|
expected_layers: tuple[int, int, int],
|
||||||
@@ -179,19 +176,25 @@ def test_get_shard_assignments(
|
|||||||
node_a_id = NodeId()
|
node_a_id = NodeId()
|
||||||
node_b_id = NodeId()
|
node_b_id = NodeId()
|
||||||
node_c_id = NodeId()
|
node_c_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node_a = create_node(available_memory[0] * 1024, node_a_id)
|
node_a = create_node_profile(available_memory[0] * 1024)
|
||||||
node_b = create_node(available_memory[1] * 1024, node_b_id)
|
node_b = create_node_profile(available_memory[1] * 1024)
|
||||||
node_c = create_node(available_memory[2] * 1024, node_c_id)
|
node_c = create_node_profile(available_memory[2] * 1024)
|
||||||
|
node_profiles = {
|
||||||
|
node_a_id: node_a,
|
||||||
|
node_b_id: node_b,
|
||||||
|
node_c_id: node_c,
|
||||||
|
}
|
||||||
|
|
||||||
topology.add_node(node_a)
|
topology.add_node(node_a_id)
|
||||||
topology.add_node(node_b)
|
topology.add_node(node_b_id)
|
||||||
topology.add_node(node_c)
|
topology.add_node(node_c_id)
|
||||||
|
|
||||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
topology.add_connection(node_b_id, node_c_id, create_connection(2))
|
||||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
topology.add_connection(node_c_id, node_a_id, create_connection(3))
|
||||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
topology.add_connection(node_b_id, node_a_id, create_connection(4))
|
||||||
|
|
||||||
model_meta = ModelMetadata(
|
model_meta = ModelMetadata(
|
||||||
model_id=ModelId("test-model"),
|
model_id=ModelId("test-model"),
|
||||||
@@ -199,7 +202,11 @@ def test_get_shard_assignments(
|
|||||||
n_layers=total_layers,
|
n_layers=total_layers,
|
||||||
storage_size=Memory.from_kb(1000),
|
storage_size=Memory.from_kb(1000),
|
||||||
)
|
)
|
||||||
cycles = topology.get_cycles()
|
|
||||||
|
cycles = [
|
||||||
|
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||||
|
for cycle in topology.get_cycles()
|
||||||
|
]
|
||||||
selected_cycle = cycles[0]
|
selected_cycle = cycles[0]
|
||||||
|
|
||||||
# act
|
# act
|
||||||
@@ -228,28 +235,21 @@ def test_get_shard_assignments(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_hosts_from_subgraph(
|
def test_get_hosts_from_subgraph():
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node_a_id = NodeId()
|
node_a_id = NodeId()
|
||||||
node_b_id = NodeId()
|
node_b_id = NodeId()
|
||||||
node_c_id = NodeId()
|
node_c_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node_a = create_node(500, node_a_id)
|
topology.add_node(node_a_id)
|
||||||
node_b = create_node(500, node_b_id)
|
topology.add_node(node_b_id)
|
||||||
node_c = create_node(1000, node_c_id)
|
topology.add_node(node_c_id)
|
||||||
|
|
||||||
topology.add_node(node_a)
|
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||||
topology.add_node(node_b)
|
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||||
topology.add_node(node_c)
|
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||||
|
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||||
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
|
|
||||||
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
|
|
||||||
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
|
|
||||||
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
|
|
||||||
|
|
||||||
# act
|
# act
|
||||||
hosts = get_hosts_from_subgraph(topology)
|
hosts = get_hosts_from_subgraph(topology)
|
||||||
@@ -257,108 +257,47 @@ def test_get_hosts_from_subgraph(
|
|||||||
# assert
|
# assert
|
||||||
assert len(hosts) == 3
|
assert len(hosts) == 3
|
||||||
expected_hosts = [
|
expected_hosts = [
|
||||||
Host(ip=("169.254.0.2"), port=5001),
|
Host(ip=("169.254.0.2"), port=1234),
|
||||||
Host(ip=("169.254.0.3"), port=5002),
|
Host(ip=("169.254.0.3"), port=1234),
|
||||||
Host(ip=("169.254.0.4"), port=5003),
|
Host(ip=("169.254.0.4"), port=1234),
|
||||||
]
|
]
|
||||||
for expected_host in expected_hosts:
|
for expected_host in expected_hosts:
|
||||||
assert expected_host in hosts
|
assert expected_host in hosts
|
||||||
|
|
||||||
|
|
||||||
def test_get_mlx_jaccl_coordinators(
|
def test_get_mlx_jaccl_coordinators():
|
||||||
topology: Topology,
|
|
||||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
|
||||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
node_a_id = NodeId()
|
node_a_id = NodeId()
|
||||||
node_b_id = NodeId()
|
node_b_id = NodeId()
|
||||||
node_c_id = NodeId()
|
node_c_id = NodeId()
|
||||||
|
topology = Topology()
|
||||||
|
|
||||||
node_a = create_node(500 * 1024, node_a_id)
|
topology.add_node(node_a_id)
|
||||||
node_b = create_node(500 * 1024, node_b_id)
|
topology.add_node(node_b_id)
|
||||||
node_c = create_node(1000 * 1024, node_c_id)
|
topology.add_node(node_c_id)
|
||||||
|
|
||||||
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
|
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||||
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
|
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||||
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
|
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||||
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
|
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||||
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
|
|
||||||
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
|
|
||||||
|
|
||||||
# Update node profiles with network interfaces before adding to topology
|
conn_a_b = create_connection(1)
|
||||||
assert node_a.node_profile is not None
|
conn_b_a = create_connection(2)
|
||||||
assert node_b.node_profile is not None
|
conn_b_c = create_connection(3)
|
||||||
assert node_c.node_profile is not None
|
conn_c_b = create_connection(4)
|
||||||
|
conn_c_a = create_connection(5)
|
||||||
|
conn_a_c = create_connection(6)
|
||||||
|
|
||||||
node_a.node_profile = NodePerformanceProfile(
|
topology.add_connection(node_a_id, node_b_id, conn_a_b)
|
||||||
model_id="test",
|
topology.add_connection(node_b_id, node_a_id, conn_b_a)
|
||||||
chip_id="test",
|
topology.add_connection(node_b_id, node_c_id, conn_b_c)
|
||||||
friendly_name="test",
|
topology.add_connection(node_c_id, node_b_id, conn_c_b)
|
||||||
memory=node_a.node_profile.memory,
|
topology.add_connection(node_c_id, node_a_id, conn_c_a)
|
||||||
network_interfaces=[
|
topology.add_connection(node_a_id, node_c_id, conn_a_c)
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
system=node_a.node_profile.system,
|
|
||||||
)
|
|
||||||
node_b.node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=node_b.node_profile.memory,
|
|
||||||
network_interfaces=[
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
system=node_b.node_profile.system,
|
|
||||||
)
|
|
||||||
node_c.node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=node_c.node_profile.memory,
|
|
||||||
network_interfaces=[
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en3",
|
|
||||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
NetworkInterfaceInfo(
|
|
||||||
name="en4",
|
|
||||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
system=node_c.node_profile.system,
|
|
||||||
)
|
|
||||||
|
|
||||||
topology.add_node(node_a)
|
|
||||||
topology.add_node(node_b)
|
|
||||||
topology.add_node(node_c)
|
|
||||||
|
|
||||||
topology.add_connection(conn_a_b)
|
|
||||||
topology.add_connection(conn_b_a)
|
|
||||||
topology.add_connection(conn_b_c)
|
|
||||||
topology.add_connection(conn_c_b)
|
|
||||||
topology.add_connection(conn_c_a)
|
|
||||||
topology.add_connection(conn_a_c)
|
|
||||||
|
|
||||||
cycle = [node_a, node_b, node_c]
|
|
||||||
|
|
||||||
# act
|
# act
|
||||||
coordinators = get_mlx_jaccl_coordinators(
|
coordinators = get_mlx_jaccl_coordinators(
|
||||||
cycle, coordinator_port=5000, cycle_digraph=topology
|
node_a_id, coordinator_port=5000, cycle_digraph=topology
|
||||||
)
|
)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
@@ -387,11 +326,11 @@ def test_get_mlx_jaccl_coordinators(
|
|||||||
|
|
||||||
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
|
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
|
||||||
# node_b uses the IP from conn_b_a (node_b -> node_a)
|
# node_b uses the IP from conn_b_a (node_b -> node_a)
|
||||||
assert coordinators[node_b_id] == (
|
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
|
||||||
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
|
"node_b should use the IP from conn_b_a"
|
||||||
), "node_b should use the IP from conn_b_a"
|
)
|
||||||
|
|
||||||
# node_c uses the IP from conn_c_a (node_c -> node_a)
|
# node_c uses the IP from conn_c_a (node_c -> node_a)
|
||||||
assert coordinators[node_c_id] == (
|
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
|
||||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
"node_c should use the IP from conn_c_a"
|
||||||
), "node_c should use the IP from conn_c_a"
|
)
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from exo.shared.topology import Topology
|
from exo.shared.topology import Topology
|
||||||
|
from exo.shared.types.common import NodeId
|
||||||
from exo.shared.types.multiaddr import Multiaddr
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.profiling import (
|
from exo.shared.types.profiling import (
|
||||||
MemoryPerformanceProfile,
|
MemoryUsage,
|
||||||
NodePerformanceProfile,
|
NodePerformanceProfile,
|
||||||
SystemPerformanceProfile,
|
SystemPerformanceProfile,
|
||||||
)
|
)
|
||||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
|
from exo.shared.types.topology import SocketConnection
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,20 +17,15 @@ def topology() -> Topology:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def connection() -> Connection:
|
def connection() -> SocketConnection:
|
||||||
return Connection(
|
return SocketConnection(
|
||||||
local_node_id=NodeId(),
|
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||||
send_back_node_id=NodeId(),
|
|
||||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
|
||||||
connection_profile=ConnectionProfile(
|
|
||||||
throughput=1000, latency=1000, jitter=1000
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def node_profile() -> NodePerformanceProfile:
|
def node_profile() -> NodePerformanceProfile:
|
||||||
memory_profile = MemoryPerformanceProfile.from_bytes(
|
memory_profile = MemoryUsage.from_bytes(
|
||||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||||
)
|
)
|
||||||
system_profile = SystemPerformanceProfile()
|
system_profile = SystemPerformanceProfile()
|
||||||
@@ -43,162 +39,85 @@ def node_profile() -> NodePerformanceProfile:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_add_node(topology: Topology):
|
||||||
def connection_profile() -> ConnectionProfile:
|
|
||||||
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
|
||||||
# arrange
|
# arrange
|
||||||
node_id = NodeId()
|
node_id = NodeId()
|
||||||
|
|
||||||
# act
|
# act
|
||||||
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
|
topology.add_node(node_id)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
data = topology.get_node_profile(node_id)
|
assert topology.node_is_leaf(node_id)
|
||||||
assert data == node_profile
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_connection(
|
def test_add_connection(topology: Topology, connection: SocketConnection):
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
topology.add_node(
|
node_a = NodeId()
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
node_b = NodeId()
|
||||||
)
|
|
||||||
topology.add_node(
|
topology.add_node(node_a)
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
topology.add_node(node_b)
|
||||||
)
|
topology.add_connection(node_a, node_b, connection)
|
||||||
topology.add_connection(connection)
|
|
||||||
|
|
||||||
# act
|
# act
|
||||||
data = topology.get_connection_profile(connection)
|
data = list(conn for _, _, conn in topology.list_connections())
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert data == connection.connection_profile
|
assert data == [connection]
|
||||||
|
|
||||||
|
assert topology.node_is_leaf(node_a)
|
||||||
def test_update_node_profile(
|
assert topology.node_is_leaf(node_b)
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
|
||||||
):
|
|
||||||
# arrange
|
|
||||||
topology.add_node(
|
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
|
||||||
)
|
|
||||||
topology.add_node(
|
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
|
||||||
)
|
|
||||||
topology.add_connection(connection)
|
|
||||||
|
|
||||||
new_node_profile = NodePerformanceProfile(
|
|
||||||
model_id="test",
|
|
||||||
chip_id="test",
|
|
||||||
friendly_name="test",
|
|
||||||
memory=MemoryPerformanceProfile.from_bytes(
|
|
||||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
|
||||||
),
|
|
||||||
network_interfaces=[],
|
|
||||||
system=SystemPerformanceProfile(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# act
|
|
||||||
topology.update_node_profile(
|
|
||||||
connection.local_node_id, node_profile=new_node_profile
|
|
||||||
)
|
|
||||||
|
|
||||||
# assert
|
|
||||||
data = topology.get_node_profile(connection.local_node_id)
|
|
||||||
assert data == new_node_profile
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_connection_profile(
|
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
|
||||||
):
|
|
||||||
# arrange
|
|
||||||
topology.add_node(
|
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
|
||||||
)
|
|
||||||
topology.add_node(
|
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
|
||||||
)
|
|
||||||
topology.add_connection(connection)
|
|
||||||
|
|
||||||
new_connection_profile = ConnectionProfile(
|
|
||||||
throughput=2000, latency=2000, jitter=2000
|
|
||||||
)
|
|
||||||
connection = Connection(
|
|
||||||
local_node_id=connection.local_node_id,
|
|
||||||
send_back_node_id=connection.send_back_node_id,
|
|
||||||
send_back_multiaddr=connection.send_back_multiaddr,
|
|
||||||
connection_profile=new_connection_profile,
|
|
||||||
)
|
|
||||||
|
|
||||||
# act
|
|
||||||
topology.update_connection_profile(connection)
|
|
||||||
|
|
||||||
# assert
|
|
||||||
data = topology.get_connection_profile(connection)
|
|
||||||
assert data == new_connection_profile
|
|
||||||
|
|
||||||
|
|
||||||
def test_remove_connection_still_connected(
|
def test_remove_connection_still_connected(
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
topology: Topology, connection: SocketConnection
|
||||||
):
|
):
|
||||||
# arrange
|
# arrange
|
||||||
topology.add_node(
|
node_a = NodeId()
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
node_b = NodeId()
|
||||||
)
|
|
||||||
topology.add_node(
|
topology.add_node(node_a)
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
topology.add_node(node_b)
|
||||||
)
|
topology.add_connection(node_a, node_b, connection)
|
||||||
topology.add_connection(connection)
|
|
||||||
|
|
||||||
# act
|
# act
|
||||||
topology.remove_connection(connection)
|
topology.remove_connection(node_a, node_b, connection)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert topology.get_connection_profile(connection) is None
|
assert list(topology.get_all_connections_between(node_a, node_b)) == []
|
||||||
|
|
||||||
|
|
||||||
def test_remove_node_still_connected(
|
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
topology.add_node(
|
node_a = NodeId()
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
node_b = NodeId()
|
||||||
)
|
|
||||||
topology.add_node(
|
topology.add_node(node_a)
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
topology.add_node(node_b)
|
||||||
)
|
topology.add_connection(node_a, node_b, connection)
|
||||||
topology.add_connection(connection)
|
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||||
|
|
||||||
# act
|
# act
|
||||||
topology.remove_node(connection.local_node_id)
|
topology.remove_node(node_b)
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert topology.get_node_profile(connection.local_node_id) is None
|
assert list(topology.out_edges(node_a)) == []
|
||||||
|
|
||||||
|
|
||||||
def test_list_nodes(
|
def test_list_nodes(topology: Topology, connection: SocketConnection):
|
||||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
|
||||||
):
|
|
||||||
# arrange
|
# arrange
|
||||||
topology.add_node(
|
node_a = NodeId()
|
||||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
node_b = NodeId()
|
||||||
)
|
|
||||||
topology.add_node(
|
topology.add_node(node_a)
|
||||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
topology.add_node(node_b)
|
||||||
)
|
topology.add_connection(node_a, node_b, connection)
|
||||||
topology.add_connection(connection)
|
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||||
|
|
||||||
# act
|
# act
|
||||||
nodes = list(topology.list_nodes())
|
nodes = list(topology.list_nodes())
|
||||||
|
|
||||||
# assert
|
# assert
|
||||||
assert len(nodes) == 2
|
assert len(nodes) == 2
|
||||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
assert all(isinstance(node, NodeId) for node in nodes)
|
||||||
assert {node.node_id for node in nodes} == {
|
assert {node for node in nodes} == {node_a, node_b}
|
||||||
connection.local_node_id,
|
|
||||||
connection.send_back_node_id,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,10 +11,8 @@ from exo.shared.types.events import (
|
|||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
InstanceCreated,
|
InstanceCreated,
|
||||||
InstanceDeleted,
|
InstanceDeleted,
|
||||||
NodeCreated,
|
|
||||||
NodeDownloadProgress,
|
NodeDownloadProgress,
|
||||||
NodeMemoryMeasured,
|
NodeGatheredInfo,
|
||||||
NodePerformanceMeasured,
|
|
||||||
NodeTimedOut,
|
NodeTimedOut,
|
||||||
RunnerDeleted,
|
RunnerDeleted,
|
||||||
RunnerStatusUpdated,
|
RunnerStatusUpdated,
|
||||||
@@ -27,13 +25,23 @@ from exo.shared.types.events import (
|
|||||||
TopologyEdgeCreated,
|
TopologyEdgeCreated,
|
||||||
TopologyEdgeDeleted,
|
TopologyEdgeDeleted,
|
||||||
)
|
)
|
||||||
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
|
from exo.shared.types.profiling import NodePerformanceProfile
|
||||||
from exo.shared.types.state import State
|
from exo.shared.types.state import State
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||||
from exo.shared.types.topology import NodeInfo
|
from exo.shared.types.topology import RDMAConnection
|
||||||
from exo.shared.types.worker.downloads import DownloadProgress
|
from exo.shared.types.worker.downloads import DownloadProgress
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||||
|
from exo.utils.info_gatherer.info_gatherer import (
|
||||||
|
MacmonMetrics,
|
||||||
|
MacTBConnections,
|
||||||
|
MacTBIdentifiers,
|
||||||
|
MemoryUsage,
|
||||||
|
MiscData,
|
||||||
|
NodeConfig,
|
||||||
|
NodeNetworkInterfaces,
|
||||||
|
StaticNodeInformation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def event_apply(event: Event, state: State) -> State:
|
def event_apply(event: Event, state: State) -> State:
|
||||||
@@ -47,16 +55,12 @@ def event_apply(event: Event, state: State) -> State:
|
|||||||
return apply_instance_created(event, state)
|
return apply_instance_created(event, state)
|
||||||
case InstanceDeleted():
|
case InstanceDeleted():
|
||||||
return apply_instance_deleted(event, state)
|
return apply_instance_deleted(event, state)
|
||||||
case NodeCreated():
|
|
||||||
return apply_topology_node_created(event, state)
|
|
||||||
case NodeTimedOut():
|
case NodeTimedOut():
|
||||||
return apply_node_timed_out(event, state)
|
return apply_node_timed_out(event, state)
|
||||||
case NodePerformanceMeasured():
|
|
||||||
return apply_node_performance_measured(event, state)
|
|
||||||
case NodeDownloadProgress():
|
case NodeDownloadProgress():
|
||||||
return apply_node_download_progress(event, state)
|
return apply_node_download_progress(event, state)
|
||||||
case NodeMemoryMeasured():
|
case NodeGatheredInfo():
|
||||||
return apply_node_memory_measured(event, state)
|
return apply_node_gathered_info(event, state)
|
||||||
case RunnerDeleted():
|
case RunnerDeleted():
|
||||||
return apply_runner_deleted(event, state)
|
return apply_runner_deleted(event, state)
|
||||||
case RunnerStatusUpdated():
|
case RunnerStatusUpdated():
|
||||||
@@ -188,7 +192,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
|||||||
|
|
||||||
|
|
||||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||||
topology = copy.copy(state.topology)
|
topology = copy.deepcopy(state.topology)
|
||||||
state.topology.remove_node(event.node_id)
|
state.topology.remove_node(event.node_id)
|
||||||
node_profiles = {
|
node_profiles = {
|
||||||
key: value for key, value in state.node_profiles.items() if key != event.node_id
|
key: value for key, value in state.node_profiles.items() if key != event.node_id
|
||||||
@@ -196,8 +200,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
|||||||
last_seen = {
|
last_seen = {
|
||||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||||
}
|
}
|
||||||
|
downloads = {
|
||||||
|
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||||
|
}
|
||||||
return state.model_copy(
|
return state.model_copy(
|
||||||
update={
|
update={
|
||||||
|
"downloads": downloads,
|
||||||
"topology": topology,
|
"topology": topology,
|
||||||
"node_profiles": node_profiles,
|
"node_profiles": node_profiles,
|
||||||
"last_seen": last_seen,
|
"last_seen": last_seen,
|
||||||
@@ -205,103 +213,69 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_node_performance_measured(
|
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||||
event: NodePerformanceMeasured, state: State
|
topology = copy.deepcopy(state.topology)
|
||||||
) -> State:
|
topology.add_node(event.node_id)
|
||||||
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
info = event.info
|
||||||
**state.node_profiles,
|
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
|
||||||
event.node_id: event.node_profile,
|
# TODO: should be broken up into individual events instead of this monster
|
||||||
}
|
match info:
|
||||||
last_seen: Mapping[NodeId, datetime] = {
|
case MacmonMetrics():
|
||||||
**state.last_seen,
|
profile.system = info.system_profile
|
||||||
event.node_id: datetime.fromisoformat(event.when),
|
profile.memory = info.memory
|
||||||
}
|
case MemoryUsage():
|
||||||
state = state.model_copy(update={"node_profiles": new_profiles})
|
profile.memory = info
|
||||||
topology = copy.copy(state.topology)
|
case NodeConfig():
|
||||||
# TODO: NodeCreated
|
pass
|
||||||
if not topology.contains_node(event.node_id):
|
case MiscData():
|
||||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
profile.friendly_name = info.friendly_name
|
||||||
topology.update_node_profile(event.node_id, event.node_profile)
|
case StaticNodeInformation():
|
||||||
|
profile.model_id = info.model
|
||||||
|
profile.chip_id = info.chip
|
||||||
|
# TODO: makes me slightly sad
|
||||||
|
case NodeNetworkInterfaces():
|
||||||
|
profile.network_interfaces = info.ifaces
|
||||||
|
case MacTBIdentifiers():
|
||||||
|
profile.tb_interfaces = info.idents
|
||||||
|
case MacTBConnections():
|
||||||
|
conn_map = {
|
||||||
|
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
|
||||||
|
for nid in state.node_profiles
|
||||||
|
for tb_ident in state.node_profiles[nid].tb_interfaces
|
||||||
|
}
|
||||||
|
as_rdma_conns = [
|
||||||
|
(
|
||||||
|
conn_map[tb_conn.sink_uuid][0],
|
||||||
|
RDMAConnection(
|
||||||
|
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
|
||||||
|
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tb_conn in info.conns
|
||||||
|
if tb_conn.source_uuid in conn_map
|
||||||
|
if tb_conn.sink_uuid in conn_map
|
||||||
|
]
|
||||||
|
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
|
||||||
|
|
||||||
|
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
|
||||||
|
new_profiles = {**state.node_profiles, event.node_id: profile}
|
||||||
return state.model_copy(
|
return state.model_copy(
|
||||||
update={
|
update={
|
||||||
"node_profiles": new_profiles,
|
"node_profiles": new_profiles,
|
||||||
"topology": topology,
|
|
||||||
"last_seen": last_seen,
|
"last_seen": last_seen,
|
||||||
|
"topology": topology,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
|
|
||||||
existing = state.node_profiles.get(event.node_id)
|
|
||||||
topology = copy.copy(state.topology)
|
|
||||||
|
|
||||||
if existing is None:
|
|
||||||
created = NodePerformanceProfile(
|
|
||||||
model_id="unknown",
|
|
||||||
chip_id="unknown",
|
|
||||||
friendly_name="Unknown",
|
|
||||||
memory=event.memory,
|
|
||||||
network_interfaces=[],
|
|
||||||
system=SystemPerformanceProfile(
|
|
||||||
# TODO: flops_fp16=0.0,
|
|
||||||
gpu_usage=0.0,
|
|
||||||
temp=0.0,
|
|
||||||
sys_power=0.0,
|
|
||||||
pcpu_usage=0.0,
|
|
||||||
ecpu_usage=0.0,
|
|
||||||
ane_power=0.0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
|
||||||
**state.node_profiles,
|
|
||||||
event.node_id: created,
|
|
||||||
}
|
|
||||||
last_seen: Mapping[NodeId, datetime] = {
|
|
||||||
**state.last_seen,
|
|
||||||
event.node_id: datetime.fromisoformat(event.when),
|
|
||||||
}
|
|
||||||
if not topology.contains_node(event.node_id):
|
|
||||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
|
||||||
# TODO: NodeCreated
|
|
||||||
topology.update_node_profile(event.node_id, created)
|
|
||||||
return state.model_copy(
|
|
||||||
update={
|
|
||||||
"node_profiles": created_profiles,
|
|
||||||
"topology": topology,
|
|
||||||
"last_seen": last_seen,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
updated = existing.model_copy(update={"memory": event.memory})
|
|
||||||
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
|
||||||
**state.node_profiles,
|
|
||||||
event.node_id: updated,
|
|
||||||
}
|
|
||||||
# TODO: NodeCreated
|
|
||||||
if not topology.contains_node(event.node_id):
|
|
||||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
|
||||||
topology.update_node_profile(event.node_id, updated)
|
|
||||||
return state.model_copy(
|
|
||||||
update={"node_profiles": updated_profiles, "topology": topology}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
|
|
||||||
topology = copy.copy(state.topology)
|
|
||||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
|
||||||
return state.model_copy(update={"topology": topology})
|
|
||||||
|
|
||||||
|
|
||||||
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
|
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
|
||||||
topology = copy.copy(state.topology)
|
topology = copy.deepcopy(state.topology)
|
||||||
topology.add_connection(event.edge)
|
topology.add_connection(event.source, event.sink, event.edge)
|
||||||
return state.model_copy(update={"topology": topology})
|
return state.model_copy(update={"topology": topology})
|
||||||
|
|
||||||
|
|
||||||
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
||||||
topology = copy.copy(state.topology)
|
topology = copy.deepcopy(state.topology)
|
||||||
if not topology.contains_connection(event.edge):
|
topology.remove_connection(event.sink, event.source, event.edge)
|
||||||
return state
|
|
||||||
topology.remove_connection(event.edge)
|
|
||||||
# TODO: Clean up removing the reverse connection
|
# TODO: Clean up removing the reverse connection
|
||||||
return state.model_copy(update={"topology": topology})
|
return state.model_copy(update={"topology": topology})
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
|
|||||||
|
|
||||||
# Identity (config)
|
# Identity (config)
|
||||||
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
||||||
|
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
|
||||||
|
|
||||||
# libp2p topics for event forwarding
|
# libp2p topics for event forwarding
|
||||||
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ class _InterceptHandler(logging.Handler):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
level = record.levelno
|
level = record.levelno
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())
|
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
|
|||||||
NodeDownloadProgress(download_progress=event), state
|
NodeDownloadProgress(download_progress=event), state
|
||||||
)
|
)
|
||||||
|
|
||||||
assert new_state == State(downloads={NodeId("node-1"): [event]})
|
assert new_state.downloads == {NodeId("node-1"): [event]}
|
||||||
|
|
||||||
|
|
||||||
def test_apply_two_node_download_progress():
|
def test_apply_two_node_download_progress():
|
||||||
@@ -39,7 +39,4 @@ def test_apply_two_node_download_progress():
|
|||||||
NodeDownloadProgress(download_progress=event2), state
|
NodeDownloadProgress(download_progress=event2), state
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: This test is failing. We should support the following:
|
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
|
||||||
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
|
|
||||||
# 2. Downloading a model, it completes, then downloading a different model on the same node.
|
|
||||||
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from exo.shared.types.common import NodeId
|
from exo.shared.types.common import NodeId
|
||||||
from exo.shared.types.multiaddr import Multiaddr
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.state import State
|
from exo.shared.types.state import State
|
||||||
from exo.shared.types.topology import Connection
|
from exo.shared.types.topology import SocketConnection
|
||||||
|
|
||||||
|
|
||||||
def test_state_serialization_roundtrip() -> None:
|
def test_state_serialization_roundtrip() -> None:
|
||||||
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
|
|||||||
node_a = NodeId("node-a")
|
node_a = NodeId("node-a")
|
||||||
node_b = NodeId("node-b")
|
node_b = NodeId("node-b")
|
||||||
|
|
||||||
connection = Connection(
|
connection = SocketConnection(
|
||||||
local_node_id=node_a,
|
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||||
send_back_node_id=node_b,
|
|
||||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
state.topology.add_connection(connection)
|
state.topology.add_connection(node_a, node_b, connection)
|
||||||
|
|
||||||
json_repr = state.model_dump_json()
|
json_repr = state.model_dump_json()
|
||||||
restored_state = State.model_validate_json(json_repr)
|
restored_state = State.model_validate_json(json_repr)
|
||||||
|
|||||||
@@ -1,203 +1,219 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import rustworkx as rx
|
import rustworkx as rx
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from exo.shared.types.common import NodeId
|
from exo.shared.types.common import NodeId
|
||||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||||
from exo.shared.types.topology import Connection, NodeInfo
|
|
||||||
|
|
||||||
|
|
||||||
class TopologySnapshot(BaseModel):
|
class TopologySnapshot(BaseModel):
|
||||||
nodes: list[NodeInfo]
|
nodes: Sequence[NodeId]
|
||||||
connections: list[Connection]
|
connections: Mapping[
|
||||||
|
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
|
||||||
|
]
|
||||||
|
|
||||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
model_config = ConfigDict(frozen=True, extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Topology:
|
class Topology:
|
||||||
def __init__(self) -> None:
|
# the _graph can be used as a int -> NodeId map.
|
||||||
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
|
||||||
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
|
init=False, default_factory=rx.PyDiGraph
|
||||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
)
|
||||||
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
|
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
|
||||||
|
|
||||||
def to_snapshot(self) -> TopologySnapshot:
|
def to_snapshot(self) -> TopologySnapshot:
|
||||||
return TopologySnapshot(
|
return TopologySnapshot(
|
||||||
nodes=list(self.list_nodes()),
|
nodes=list(self.list_nodes()), connections=self.map_connections()
|
||||||
connections=list(self.list_connections()),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
|
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
|
||||||
topology = cls()
|
topology = cls()
|
||||||
|
|
||||||
for node in snapshot.nodes:
|
for node_id in snapshot.nodes:
|
||||||
with contextlib.suppress(ValueError):
|
with contextlib.suppress(ValueError):
|
||||||
topology.add_node(node)
|
topology.add_node(node_id)
|
||||||
|
|
||||||
for connection in snapshot.connections:
|
for source in snapshot.connections:
|
||||||
topology.add_connection(connection)
|
for sink in snapshot.connections[source]:
|
||||||
|
for conn in snapshot.connections[source][sink]:
|
||||||
|
topology.add_connection(source, sink, conn)
|
||||||
|
|
||||||
return topology
|
return topology
|
||||||
|
|
||||||
def add_node(self, node: NodeInfo) -> None:
|
def add_node(self, node_id: NodeId) -> None:
|
||||||
if node.node_id in self._node_id_to_rx_id_map:
|
if node_id in self._vertex_indices:
|
||||||
return
|
return
|
||||||
rx_id = self._graph.add_node(node)
|
rx_id = self._graph.add_node(node_id)
|
||||||
self._node_id_to_rx_id_map[node.node_id] = rx_id
|
self._vertex_indices[node_id] = rx_id
|
||||||
self._rx_id_to_node_id_map[rx_id] = node.node_id
|
|
||||||
|
|
||||||
def node_is_leaf(self, node_id: NodeId) -> bool:
|
def node_is_leaf(self, node_id: NodeId) -> bool:
|
||||||
return (
|
return (
|
||||||
node_id in self._node_id_to_rx_id_map
|
node_id in self._vertex_indices
|
||||||
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
|
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
|
||||||
)
|
)
|
||||||
|
|
||||||
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
||||||
return [
|
return [
|
||||||
self._rx_id_to_node_id_map[rx_id]
|
self._graph[rx_id]
|
||||||
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
|
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
|
||||||
]
|
]
|
||||||
|
|
||||||
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
|
def out_edges(
|
||||||
if node_id not in self._node_id_to_rx_id_map:
|
self, node_id: NodeId
|
||||||
|
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
|
||||||
|
if node_id not in self._vertex_indices:
|
||||||
return []
|
return []
|
||||||
return [
|
return (
|
||||||
(self._rx_id_to_node_id_map[nid], conn)
|
(self._graph[nid], conn)
|
||||||
for _, nid, conn in self._graph.out_edges(
|
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
|
||||||
self._node_id_to_rx_id_map[node_id]
|
)
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def contains_node(self, node_id: NodeId) -> bool:
|
def contains_node(self, node_id: NodeId) -> bool:
|
||||||
return node_id in self._node_id_to_rx_id_map
|
return node_id in self._vertex_indices
|
||||||
|
|
||||||
def contains_connection(self, connection: Connection) -> bool:
|
|
||||||
return connection in self._edge_id_to_rx_id_map
|
|
||||||
|
|
||||||
def add_connection(
|
def add_connection(
|
||||||
self,
|
self,
|
||||||
connection: Connection,
|
source: NodeId,
|
||||||
|
sink: NodeId,
|
||||||
|
connection: SocketConnection | RDMAConnection,
|
||||||
) -> None:
|
) -> None:
|
||||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
if connection in self.get_all_connections_between(source, sink):
|
||||||
self.add_node(NodeInfo(node_id=connection.local_node_id))
|
|
||||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
|
||||||
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
|
|
||||||
|
|
||||||
if connection in self._edge_id_to_rx_id_map:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
if source not in self._vertex_indices:
|
||||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
self.add_node(source)
|
||||||
|
if sink not in self._vertex_indices:
|
||||||
|
self.add_node(sink)
|
||||||
|
|
||||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
src_id = self._vertex_indices[source]
|
||||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
sink_id = self._vertex_indices[sink]
|
||||||
|
|
||||||
def list_nodes(self) -> Iterable[NodeInfo]:
|
_ = self._graph.add_edge(src_id, sink_id, connection)
|
||||||
return (self._graph[i] for i in self._graph.node_indices())
|
|
||||||
|
|
||||||
def list_connections(self) -> Iterable[Connection]:
|
def get_all_connections_between(
|
||||||
return (connection for _, _, connection in self._graph.weighted_edge_list())
|
self, source: NodeId, sink: NodeId
|
||||||
|
) -> Iterable[SocketConnection | RDMAConnection]:
|
||||||
|
if source not in self._vertex_indices:
|
||||||
|
return []
|
||||||
|
if sink not in self._vertex_indices:
|
||||||
|
return []
|
||||||
|
|
||||||
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
|
src_id = self._vertex_indices[source]
|
||||||
|
sink_id = self._vertex_indices[sink]
|
||||||
try:
|
try:
|
||||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
return self._graph.get_all_edge_data(src_id, sink_id)
|
||||||
return self._graph.get_node_data(rx_idx).node_profile
|
except rx.NoEdgeBetweenNodes:
|
||||||
except KeyError:
|
return []
|
||||||
return None
|
|
||||||
|
|
||||||
def update_node_profile(
|
def list_nodes(self) -> Iterable[NodeId]:
|
||||||
self, node_id: NodeId, node_profile: NodePerformanceProfile
|
return self._graph.nodes()
|
||||||
) -> None:
|
|
||||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
|
||||||
self._graph[rx_idx].node_profile = node_profile
|
|
||||||
|
|
||||||
def update_connection_profile(self, connection: Connection) -> None:
|
def map_connections(
|
||||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
self,
|
||||||
self._graph.update_edge_by_index(rx_idx, connection)
|
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
|
||||||
|
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
|
||||||
|
for src_id, sink_id, connection in self._graph.weighted_edge_list():
|
||||||
|
source = self._graph[src_id]
|
||||||
|
sink = self._graph[sink_id]
|
||||||
|
if source not in base:
|
||||||
|
base[source] = {}
|
||||||
|
if sink not in base[source]:
|
||||||
|
base[source][sink] = []
|
||||||
|
base[source][sink].append(connection)
|
||||||
|
return base
|
||||||
|
|
||||||
def get_connection_profile(
|
def list_connections(
|
||||||
self, connection: Connection
|
self,
|
||||||
) -> ConnectionProfile | None:
|
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
|
||||||
try:
|
return (
|
||||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
(
|
||||||
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
|
self._graph[src_id],
|
||||||
except KeyError:
|
self._graph[sink_id],
|
||||||
return None
|
connection,
|
||||||
|
)
|
||||||
|
for src_id, sink_id, connection in self._graph.weighted_edge_list()
|
||||||
|
)
|
||||||
|
|
||||||
def remove_node(self, node_id: NodeId) -> None:
|
def remove_node(self, node_id: NodeId) -> None:
|
||||||
if node_id not in self._node_id_to_rx_id_map:
|
if node_id not in self._vertex_indices:
|
||||||
return
|
return
|
||||||
|
|
||||||
for connection in self.list_connections():
|
rx_idx = self._vertex_indices[node_id]
|
||||||
if (
|
|
||||||
connection.local_node_id == node_id
|
|
||||||
or connection.send_back_node_id == node_id
|
|
||||||
):
|
|
||||||
self.remove_connection(connection)
|
|
||||||
|
|
||||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
|
||||||
self._graph.remove_node(rx_idx)
|
self._graph.remove_node(rx_idx)
|
||||||
|
|
||||||
del self._node_id_to_rx_id_map[node_id]
|
del self._vertex_indices[node_id]
|
||||||
del self._rx_id_to_node_id_map[rx_idx]
|
|
||||||
|
|
||||||
def remove_connection(self, connection: Connection) -> None:
|
def replace_all_out_tb_connections(
|
||||||
if connection not in self._edge_id_to_rx_id_map:
|
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
|
||||||
|
) -> None:
|
||||||
|
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
|
||||||
|
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
|
||||||
|
self._graph.remove_edge_from_index(conn_idx)
|
||||||
|
for sink, conn in new_connections:
|
||||||
|
self.add_connection(source, sink, conn)
|
||||||
|
|
||||||
|
def remove_connection(
|
||||||
|
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
|
||||||
|
) -> None:
|
||||||
|
if source not in self._vertex_indices or sink not in self._vertex_indices:
|
||||||
return
|
return
|
||||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
for conn_idx in self._graph.edge_indices_from_endpoints(
|
||||||
self._graph.remove_edge_from_index(rx_idx)
|
self._vertex_indices[source], self._vertex_indices[sink]
|
||||||
del self._edge_id_to_rx_id_map[connection]
|
):
|
||||||
|
if self._graph.get_edge_data_by_index(conn_idx) == edge:
|
||||||
|
self._graph.remove_edge_from_index(conn_idx)
|
||||||
|
|
||||||
def get_cycles(self) -> list[list[NodeInfo]]:
|
def get_cycles(self) -> list[list[NodeId]]:
|
||||||
cycle_idxs = rx.simple_cycles(self._graph)
|
cycle_idxs = rx.simple_cycles(self._graph)
|
||||||
cycles: list[list[NodeInfo]] = []
|
cycles: list[list[NodeId]] = []
|
||||||
for cycle_idx in cycle_idxs:
|
for cycle_idx in cycle_idxs:
|
||||||
cycle = [self._graph[idx] for idx in cycle_idx]
|
cycle = [self._graph[idx] for idx in cycle_idx]
|
||||||
cycles.append(cycle)
|
cycles.append(cycle)
|
||||||
|
|
||||||
return cycles
|
return cycles
|
||||||
|
|
||||||
def get_cycles_tb(self) -> list[list[NodeInfo]]:
|
def get_cycles_tb(self) -> list[list[NodeId]]:
|
||||||
tb_edges = [
|
tb_edges = [
|
||||||
(u, v, conn)
|
(u, v, conn)
|
||||||
for u, v, conn in self._graph.weighted_edge_list()
|
for u, v, conn in self._graph.weighted_edge_list()
|
||||||
if conn.is_thunderbolt()
|
if conn.is_thunderbolt()
|
||||||
]
|
]
|
||||||
|
|
||||||
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
|
||||||
tb_graph.add_nodes_from(self._graph.nodes())
|
tb_graph.add_nodes_from(self._graph.nodes())
|
||||||
|
|
||||||
for u, v, conn in tb_edges:
|
for u, v, conn in tb_edges:
|
||||||
tb_graph.add_edge(u, v, conn)
|
if isinstance(conn, SocketConnection):
|
||||||
|
tb_graph.add_edge(u, v, conn)
|
||||||
|
|
||||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||||
cycles: list[list[NodeInfo]] = []
|
cycles: list[list[NodeId]] = []
|
||||||
for cycle_idx in cycle_idxs:
|
for cycle_idx in cycle_idxs:
|
||||||
cycle = [tb_graph[idx] for idx in cycle_idx]
|
cycle = [tb_graph[idx] for idx in cycle_idx]
|
||||||
cycles.append(cycle)
|
cycles.append(cycle)
|
||||||
|
|
||||||
return cycles
|
return cycles
|
||||||
|
|
||||||
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
|
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
|
||||||
node_idxs = [node.node_id for node in nodes]
|
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
|
||||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
for rx_idx in rx_idxs:
|
for rx_idx in rx_idxs:
|
||||||
topology.add_node(self._graph[rx_idx])
|
topology.add_node(self._graph[rx_idx])
|
||||||
for connection in self.list_connections():
|
for source, sink, connection in self.list_connections():
|
||||||
if (
|
if source in node_ids and sink in node_ids:
|
||||||
connection.local_node_id in node_idxs
|
topology.add_connection(source, sink, connection)
|
||||||
and connection.send_back_node_id in node_idxs
|
|
||||||
):
|
|
||||||
topology.add_connection(connection)
|
|
||||||
return topology
|
return topology
|
||||||
|
|
||||||
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
|
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
|
||||||
node_idxs = [node.node_id for node in cycle]
|
node_idxs = [node for node in cycle]
|
||||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
|
||||||
for rid in rx_idxs:
|
for rid in rx_idxs:
|
||||||
for neighbor_rid in self._graph.neighbors(rid):
|
for neighbor_rid in self._graph.neighbors(rid):
|
||||||
if neighbor_rid not in rx_idxs:
|
if neighbor_rid not in rx_idxs:
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ from datetime import datetime
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
from exo.shared.topology import SocketConnection
|
||||||
from exo.shared.types.chunks import GenerationChunk
|
from exo.shared.types.chunks import GenerationChunk
|
||||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||||
from exo.shared.types.worker.downloads import DownloadProgress
|
from exo.shared.types.worker.downloads import DownloadProgress
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||||
|
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||||
|
|
||||||
|
|
||||||
@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
|
|||||||
runner_id: RunnerId
|
runner_id: RunnerId
|
||||||
|
|
||||||
|
|
||||||
# TODO
|
|
||||||
class NodeCreated(BaseEvent):
|
|
||||||
node_id: NodeId
|
|
||||||
|
|
||||||
|
|
||||||
class NodeTimedOut(BaseEvent):
|
class NodeTimedOut(BaseEvent):
|
||||||
node_id: NodeId
|
node_id: NodeId
|
||||||
|
|
||||||
|
|
||||||
class NodePerformanceMeasured(BaseEvent):
|
# TODO: bikeshed this naem
|
||||||
|
class NodeGatheredInfo(BaseEvent):
|
||||||
node_id: NodeId
|
node_id: NodeId
|
||||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||||
node_profile: NodePerformanceProfile
|
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
|
||||||
|
|
||||||
|
|
||||||
class NodeMemoryMeasured(BaseEvent):
|
|
||||||
node_id: NodeId
|
|
||||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
|
||||||
memory: MemoryPerformanceProfile
|
|
||||||
|
|
||||||
|
|
||||||
class NodeDownloadProgress(BaseEvent):
|
class NodeDownloadProgress(BaseEvent):
|
||||||
@@ -107,11 +97,15 @@ class ChunkGenerated(BaseEvent):
|
|||||||
|
|
||||||
|
|
||||||
class TopologyEdgeCreated(BaseEvent):
|
class TopologyEdgeCreated(BaseEvent):
|
||||||
edge: Connection
|
source: NodeId
|
||||||
|
sink: NodeId
|
||||||
|
edge: SocketConnection
|
||||||
|
|
||||||
|
|
||||||
class TopologyEdgeDeleted(BaseEvent):
|
class TopologyEdgeDeleted(BaseEvent):
|
||||||
edge: Connection
|
source: NodeId
|
||||||
|
sink: NodeId
|
||||||
|
edge: SocketConnection
|
||||||
|
|
||||||
|
|
||||||
Event = (
|
Event = (
|
||||||
@@ -125,10 +119,8 @@ Event = (
|
|||||||
| InstanceDeleted
|
| InstanceDeleted
|
||||||
| RunnerStatusUpdated
|
| RunnerStatusUpdated
|
||||||
| RunnerDeleted
|
| RunnerDeleted
|
||||||
| NodeCreated
|
|
||||||
| NodeTimedOut
|
| NodeTimedOut
|
||||||
| NodePerformanceMeasured
|
| NodeGatheredInfo
|
||||||
| NodeMemoryMeasured
|
|
||||||
| NodeDownloadProgress
|
| NodeDownloadProgress
|
||||||
| ChunkGenerated
|
| ChunkGenerated
|
||||||
| TopologyEdgeCreated
|
| TopologyEdgeCreated
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import re
|
import re
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
from pydantic import BaseModel, computed_field, field_validator
|
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class Multiaddr(BaseModel):
|
class Multiaddr(BaseModel):
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
address: str
|
address: str
|
||||||
|
|
||||||
PATTERNS: ClassVar[list[str]] = [
|
PATTERNS: ClassVar[list[str]] = [
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
|
from exo.shared.types.thunderbolt import TBIdentifier
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel
|
from exo.utils.pydantic_ext import CamelCaseModel
|
||||||
|
|
||||||
|
|
||||||
class MemoryPerformanceProfile(CamelCaseModel):
|
class MemoryUsage(CamelCaseModel):
|
||||||
ram_total: Memory
|
ram_total: Memory
|
||||||
ram_available: Memory
|
ram_available: Memory
|
||||||
swap_total: Memory
|
swap_total: Memory
|
||||||
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
|
|||||||
sys_power: float = 0.0
|
sys_power: float = 0.0
|
||||||
pcpu_usage: float = 0.0
|
pcpu_usage: float = 0.0
|
||||||
ecpu_usage: float = 0.0
|
ecpu_usage: float = 0.0
|
||||||
ane_power: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkInterfaceInfo(CamelCaseModel):
|
class NetworkInterfaceInfo(CamelCaseModel):
|
||||||
@@ -53,15 +54,16 @@ class NetworkInterfaceInfo(CamelCaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class NodePerformanceProfile(CamelCaseModel):
|
class NodePerformanceProfile(CamelCaseModel):
|
||||||
model_id: str
|
model_id: str = "Unknown"
|
||||||
chip_id: str
|
chip_id: str = "Unknown"
|
||||||
friendly_name: str
|
friendly_name: str = "Unknown"
|
||||||
memory: MemoryPerformanceProfile
|
memory: MemoryUsage = MemoryUsage.from_bytes(
|
||||||
network_interfaces: list[NetworkInterfaceInfo] = []
|
ram_total=0, ram_available=0, swap_total=0, swap_available=0
|
||||||
system: SystemPerformanceProfile
|
)
|
||||||
|
network_interfaces: Sequence[NetworkInterfaceInfo] = []
|
||||||
|
tb_interfaces: Sequence[TBIdentifier] = []
|
||||||
|
system: SystemPerformanceProfile = SystemPerformanceProfile()
|
||||||
|
|
||||||
|
|
||||||
class ConnectionProfile(CamelCaseModel):
|
class ConnectionProfile(CamelCaseModel):
|
||||||
throughput: float
|
pass
|
||||||
latency: float
|
|
||||||
jitter: float
|
|
||||||
|
|||||||
64
src/exo/shared/types/thunderbolt.py
Normal file
64
src/exo/shared/types/thunderbolt.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import anyio
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from exo.utils.pydantic_ext import CamelCaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TBConnection(CamelCaseModel):
|
||||||
|
source_uuid: str
|
||||||
|
sink_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
class TBIdentifier(CamelCaseModel):
|
||||||
|
rdma_interface: str
|
||||||
|
domain_uuid: str
|
||||||
|
|
||||||
|
|
||||||
|
# Intentionally minimal, only collecting data we care about - there's a lot more
|
||||||
|
|
||||||
|
|
||||||
|
class TBReceptacleTag(BaseModel, extra="ignore"):
|
||||||
|
receptacle_id_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class TBConnectivityItem(BaseModel, extra="ignore"):
|
||||||
|
domain_uuid_key: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class TBConnectivityData(BaseModel, extra="ignore"):
|
||||||
|
domain_uuid_key: str | None
|
||||||
|
device_name_key: str
|
||||||
|
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
|
||||||
|
receptacle_1_tag: TBReceptacleTag
|
||||||
|
|
||||||
|
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
|
||||||
|
if self.domain_uuid_key is None:
|
||||||
|
return
|
||||||
|
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
|
||||||
|
iface = f"rdma_{ifaces[tag]}"
|
||||||
|
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
|
||||||
|
|
||||||
|
def conn(self) -> TBConnection | None:
|
||||||
|
if self.domain_uuid_key is None or self.items is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
sink_key = next(
|
||||||
|
item.domain_uuid_key
|
||||||
|
for item in self.items
|
||||||
|
if item.domain_uuid_key is not None
|
||||||
|
)
|
||||||
|
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
|
||||||
|
|
||||||
|
|
||||||
|
class TBConnectivity(BaseModel):
|
||||||
|
SPThunderboltDataType: list[TBConnectivityData]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def gather(cls) -> list[TBConnectivityData] | None:
|
||||||
|
proc = await anyio.run_process(
|
||||||
|
["system_profiler", "SPThunderboltDataType", "-json"], check=False
|
||||||
|
)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
return None
|
||||||
|
# Saving you from PascalCase while avoiding too much pydantic
|
||||||
|
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType
|
||||||
@@ -1,37 +1,32 @@
|
|||||||
from exo.shared.types.common import NodeId
|
from enum import Enum
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from exo.shared.types.multiaddr import Multiaddr
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
from exo.utils.pydantic_ext import FrozenModel
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class NodeInfo(CamelCaseModel):
|
class RDMAConnection(FrozenModel):
|
||||||
node_id: NodeId
|
source_rdma_iface: str
|
||||||
node_profile: NodePerformanceProfile | None = None
|
sink_rdma_iface: str
|
||||||
|
|
||||||
|
|
||||||
class Connection(CamelCaseModel):
|
|
||||||
local_node_id: NodeId
|
|
||||||
send_back_node_id: NodeId
|
|
||||||
send_back_multiaddr: Multiaddr
|
|
||||||
connection_profile: ConnectionProfile | None = None
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash(
|
|
||||||
(
|
|
||||||
self.local_node_id,
|
|
||||||
self.send_back_node_id,
|
|
||||||
self.send_back_multiaddr.address,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, Connection):
|
|
||||||
raise ValueError("Cannot compare Connection with non-Connection")
|
|
||||||
return (
|
|
||||||
self.local_node_id == other.local_node_id
|
|
||||||
and self.send_back_node_id == other.send_back_node_id
|
|
||||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_thunderbolt(self) -> bool:
|
def is_thunderbolt(self) -> bool:
|
||||||
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
|
logger.warning("duh")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
class LinkType(str, Enum):
|
||||||
|
Thunderbolt = "Thunderbolt"
|
||||||
|
Ethernet = "Ethernet"
|
||||||
|
WiFi = "WiFi"
|
||||||
|
|
||||||
|
|
||||||
|
class SocketConnection(FrozenModel):
|
||||||
|
sink_multiaddr: Multiaddr
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.sink_multiaddr.ip_address)
|
||||||
|
|
||||||
|
def is_thunderbolt(self) -> bool:
|
||||||
|
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class MlxRingInstance(BaseInstance):
|
|||||||
|
|
||||||
|
|
||||||
class MlxJacclInstance(BaseInstance):
|
class MlxJacclInstance(BaseInstance):
|
||||||
ibv_devices: list[list[str | None]]
|
jaccl_devices: list[list[str | None]]
|
||||||
jaccl_coordinators: dict[NodeId, str]
|
jaccl_coordinators: dict[NodeId, str]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections.abc import Coroutine
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from exo.shared.types.profiling import (
|
|
||||||
MemoryPerformanceProfile,
|
|
||||||
SystemPerformanceProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceCollector(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
|
|
||||||
|
|
||||||
|
|
||||||
class SystemResourceCollector(ResourceCollector):
|
|
||||||
async def collect(self) -> SystemPerformanceProfile: ...
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryResourceCollector(ResourceCollector):
|
|
||||||
async def collect(self) -> MemoryPerformanceProfile: ...
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceMonitor:
|
|
||||||
data_collectors: list[ResourceCollector]
|
|
||||||
effect_handlers: set[
|
|
||||||
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _collect(
|
|
||||||
self,
|
|
||||||
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
|
|
||||||
tasks: list[
|
|
||||||
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
|
|
||||||
] = [collector.collect() for collector in self.data_collectors]
|
|
||||||
return await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
async def collect(self) -> None:
|
|
||||||
profiles = await self._collect()
|
|
||||||
for profile in profiles:
|
|
||||||
for effect_handler in self.effect_handlers:
|
|
||||||
effect_handler(profile)
|
|
||||||
231
src/exo/utils/info_gatherer/info_gatherer.py
Normal file
231
src/exo/utils/info_gatherer/info_gatherer.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tomllib
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from subprocess import CalledProcessError
|
||||||
|
from typing import Self, cast
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from anyio import create_task_group, open_process
|
||||||
|
from anyio.abc import TaskGroup
|
||||||
|
from anyio.streams.buffered import BufferedByteReceiveStream
|
||||||
|
from anyio.streams.text import TextReceiveStream
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from exo.shared.constants import EXO_CONFIG_FILE
|
||||||
|
from exo.shared.types.memory import Memory
|
||||||
|
from exo.shared.types.profiling import (
|
||||||
|
MemoryUsage,
|
||||||
|
NetworkInterfaceInfo,
|
||||||
|
)
|
||||||
|
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
|
||||||
|
from exo.utils.channels import Sender
|
||||||
|
from exo.utils.pydantic_ext import TaggedModel
|
||||||
|
|
||||||
|
from .macmon import MacmonMetrics
|
||||||
|
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
|
||||||
|
|
||||||
|
IS_DARWIN = sys.platform == "darwin"
|
||||||
|
|
||||||
|
|
||||||
|
class StaticNodeInformation(TaggedModel):
|
||||||
|
"""Node information that should NEVER change, to be gathered once at startup"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
chip: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def gather(cls) -> Self:
|
||||||
|
model, chip = await get_model_and_chip()
|
||||||
|
return cls(model=model, chip=chip)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeNetworkInterfaces(TaggedModel):
|
||||||
|
ifaces: Sequence[NetworkInterfaceInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class MacTBIdentifiers(TaggedModel):
|
||||||
|
idents: Sequence[TBIdentifier]
|
||||||
|
|
||||||
|
|
||||||
|
class MacTBConnections(TaggedModel):
|
||||||
|
conns: Sequence[TBConnection]
|
||||||
|
|
||||||
|
|
||||||
|
class NodeConfig(TaggedModel):
|
||||||
|
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
@classmethod
|
||||||
|
async def gather(cls) -> Self | None:
|
||||||
|
cfg_file = anyio.Path(EXO_CONFIG_FILE)
|
||||||
|
await cfg_file.touch(exist_ok=True)
|
||||||
|
async with await cfg_file.open("rb") as f:
|
||||||
|
try:
|
||||||
|
contents = (await f.read()).decode("utf-8")
|
||||||
|
data = tomllib.loads(contents)
|
||||||
|
return cls.model_validate(data)
|
||||||
|
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
|
||||||
|
logger.warning("Invalid config file, skipping...")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MiscData(TaggedModel):
|
||||||
|
"""Node information that may slowly change that doesn't fall into the other categories"""
|
||||||
|
|
||||||
|
friendly_name: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def gather(cls) -> Self:
|
||||||
|
return cls(friendly_name=await get_friendly_name())
|
||||||
|
|
||||||
|
|
||||||
|
async def _gather_iface_map() -> dict[str, str] | None:
|
||||||
|
proc = await anyio.run_process(
|
||||||
|
["networksetup", "-listallhardwareports"], check=False
|
||||||
|
)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ports: dict[str, str] = {}
|
||||||
|
port = ""
|
||||||
|
for line in proc.stdout.decode("utf-8").split("\n"):
|
||||||
|
if line.startswith("Hardware Port:"):
|
||||||
|
port = line.split(": ")[1]
|
||||||
|
elif line.startswith("Device:"):
|
||||||
|
ports[port] = line.split(": ")[1]
|
||||||
|
port = ""
|
||||||
|
if "" in ports:
|
||||||
|
del ports[""]
|
||||||
|
return ports
|
||||||
|
|
||||||
|
|
||||||
|
GatheredInfo = (
|
||||||
|
MacmonMetrics
|
||||||
|
| MemoryUsage
|
||||||
|
| NodeNetworkInterfaces
|
||||||
|
| MacTBIdentifiers
|
||||||
|
| MacTBConnections
|
||||||
|
| NodeConfig
|
||||||
|
| MiscData
|
||||||
|
| StaticNodeInformation
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InfoGatherer:
|
||||||
|
info_sender: Sender[GatheredInfo]
|
||||||
|
interface_watcher_interval: float | None = 10
|
||||||
|
misc_poll_interval: float | None = 60
|
||||||
|
system_profiler_interval: float | None = 5 if IS_DARWIN else None
|
||||||
|
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||||
|
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||||
|
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
async with self._tg as tg:
|
||||||
|
if (macmon_path := shutil.which("macmon")) is not None:
|
||||||
|
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||||
|
if IS_DARWIN:
|
||||||
|
tg.start_soon(self._monitor_system_profiler)
|
||||||
|
tg.start_soon(self._watch_system_info)
|
||||||
|
tg.start_soon(self._monitor_memory_usage)
|
||||||
|
tg.start_soon(self._monitor_misc)
|
||||||
|
|
||||||
|
nc = await NodeConfig.gather()
|
||||||
|
if nc is not None:
|
||||||
|
await self.info_sender.send(nc)
|
||||||
|
sni = await StaticNodeInformation.gather()
|
||||||
|
await self.info_sender.send(sni)
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self._tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
async def _monitor_misc(self):
|
||||||
|
if self.misc_poll_interval is None:
|
||||||
|
return
|
||||||
|
prev = await MiscData.gather()
|
||||||
|
while True:
|
||||||
|
curr = await MiscData.gather()
|
||||||
|
if prev != curr:
|
||||||
|
prev = curr
|
||||||
|
await self.info_sender.send(curr)
|
||||||
|
await anyio.sleep(self.misc_poll_interval)
|
||||||
|
|
||||||
|
async def _monitor_system_profiler(self):
|
||||||
|
if self.system_profiler_interval is None:
|
||||||
|
return
|
||||||
|
iface_map = await _gather_iface_map()
|
||||||
|
if iface_map is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
old_idents = []
|
||||||
|
while True:
|
||||||
|
data = await TBConnectivity.gather()
|
||||||
|
assert data is not None
|
||||||
|
|
||||||
|
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
|
||||||
|
if idents != old_idents:
|
||||||
|
await self.info_sender.send(MacTBIdentifiers(idents=idents))
|
||||||
|
old_idents = idents
|
||||||
|
|
||||||
|
conns = [it for i in data if (it := i.conn()) is not None]
|
||||||
|
await self.info_sender.send(MacTBConnections(conns=conns))
|
||||||
|
|
||||||
|
await anyio.sleep(self.system_profiler_interval)
|
||||||
|
|
||||||
|
async def _monitor_memory_usage(self):
|
||||||
|
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
|
||||||
|
override_memory: int | None = (
|
||||||
|
Memory.from_mb(int(override_memory_env)).in_bytes
|
||||||
|
if override_memory_env
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if self.memory_poll_rate is None:
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
await self.info_sender.send(
|
||||||
|
MemoryUsage.from_psutil(override_memory=override_memory)
|
||||||
|
)
|
||||||
|
await anyio.sleep(self.memory_poll_rate)
|
||||||
|
|
||||||
|
async def _watch_system_info(self):
|
||||||
|
if self.interface_watcher_interval is None:
|
||||||
|
return
|
||||||
|
old_nics = []
|
||||||
|
while True:
|
||||||
|
nics = get_network_interfaces()
|
||||||
|
if nics != old_nics:
|
||||||
|
old_nics = nics
|
||||||
|
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||||
|
await anyio.sleep(self.interface_watcher_interval)
|
||||||
|
|
||||||
|
async def _monitor_macmon(self, macmon_path: str):
|
||||||
|
if self.macmon_interval is None:
|
||||||
|
return
|
||||||
|
# macmon pipe --interval [interval in ms]
|
||||||
|
try:
|
||||||
|
async with await open_process(
|
||||||
|
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
|
||||||
|
) as p:
|
||||||
|
if not p.stdout:
|
||||||
|
logger.critical("MacMon closed stdout")
|
||||||
|
return
|
||||||
|
async for text in TextReceiveStream(
|
||||||
|
BufferedByteReceiveStream(p.stdout)
|
||||||
|
):
|
||||||
|
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
|
||||||
|
except CalledProcessError as e:
|
||||||
|
stderr_msg = "no stderr"
|
||||||
|
stderr_output = cast(bytes | str | None, e.stderr)
|
||||||
|
if stderr_output is not None:
|
||||||
|
stderr_msg = (
|
||||||
|
stderr_output.decode()
|
||||||
|
if isinstance(stderr_output, bytes)
|
||||||
|
else str(stderr_output)
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||||
|
)
|
||||||
70
src/exo/utils/info_gatherer/macmon.py
Normal file
70
src/exo/utils/info_gatherer/macmon.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
|
||||||
|
from exo.utils.pydantic_ext import TaggedModel
|
||||||
|
|
||||||
|
|
||||||
|
class _TempMetrics(BaseModel, extra="ignore"):
|
||||||
|
"""Temperature-related metrics returned by macmon."""
|
||||||
|
|
||||||
|
cpu_temp_avg: float
|
||||||
|
gpu_temp_avg: float
|
||||||
|
|
||||||
|
|
||||||
|
class _MemoryMetrics(BaseModel, extra="ignore"):
|
||||||
|
"""Memory-related metrics returned by macmon."""
|
||||||
|
|
||||||
|
ram_total: int
|
||||||
|
ram_usage: int
|
||||||
|
swap_total: int
|
||||||
|
swap_usage: int
|
||||||
|
|
||||||
|
|
||||||
|
class RawMacmonMetrics(BaseModel, extra="ignore"):
|
||||||
|
"""Complete set of metrics returned by macmon.
|
||||||
|
|
||||||
|
Unknown fields are ignored for forward-compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
timestamp: str # ignored
|
||||||
|
temp: _TempMetrics
|
||||||
|
memory: _MemoryMetrics
|
||||||
|
ecpu_usage: tuple[int, float] # freq mhz, usage %
|
||||||
|
pcpu_usage: tuple[int, float] # freq mhz, usage %
|
||||||
|
gpu_usage: tuple[int, float] # freq mhz, usage %
|
||||||
|
all_power: float
|
||||||
|
ane_power: float
|
||||||
|
cpu_power: float
|
||||||
|
gpu_power: float
|
||||||
|
gpu_ram_power: float
|
||||||
|
ram_power: float
|
||||||
|
sys_power: float
|
||||||
|
|
||||||
|
|
||||||
|
class MacmonMetrics(TaggedModel):
|
||||||
|
system_profile: SystemPerformanceProfile
|
||||||
|
memory: MemoryUsage
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
|
||||||
|
return cls(
|
||||||
|
system_profile=SystemPerformanceProfile(
|
||||||
|
gpu_usage=raw.gpu_usage[1],
|
||||||
|
temp=raw.temp.gpu_temp_avg,
|
||||||
|
sys_power=raw.sys_power,
|
||||||
|
pcpu_usage=raw.pcpu_usage[1],
|
||||||
|
ecpu_usage=raw.ecpu_usage[1],
|
||||||
|
),
|
||||||
|
memory=MemoryUsage.from_bytes(
|
||||||
|
ram_total=raw.memory.ram_total,
|
||||||
|
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
|
||||||
|
swap_total=raw.memory.swap_total,
|
||||||
|
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_raw_json(cls, json: str) -> Self:
|
||||||
|
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
|
||||||
56
src/exo/utils/info_gatherer/net_profile.py
Normal file
56
src/exo/utils/info_gatherer/net_profile.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import socket
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from ipaddress import ip_address
|
||||||
|
|
||||||
|
from anyio import create_task_group, to_thread
|
||||||
|
|
||||||
|
from exo.shared.topology import Topology
|
||||||
|
from exo.shared.types.common import NodeId
|
||||||
|
from exo.shared.types.profiling import NodePerformanceProfile
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: ref. api port
|
||||||
|
async def check_reachability(
|
||||||
|
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
|
||||||
|
) -> None:
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(1) # 1 second timeout
|
||||||
|
try:
|
||||||
|
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
|
||||||
|
except socket.gaierror:
|
||||||
|
# seems to throw on ipv6 loopback. oh well
|
||||||
|
# logger.warning(f"invalid {target_ip=}")
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
if target_node_id not in out:
|
||||||
|
out[target_node_id] = set()
|
||||||
|
out[target_node_id].add(target_ip)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_reachable(
|
||||||
|
our_node_id: NodeId,
|
||||||
|
topology: Topology,
|
||||||
|
profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||||
|
) -> Mapping[NodeId, set[str]]:
|
||||||
|
reachable: dict[NodeId, set[str]] = {}
|
||||||
|
our_profile = profiles.get(our_node_id, None)
|
||||||
|
if our_profile is None:
|
||||||
|
return {}
|
||||||
|
our_interfaces = our_profile.network_interfaces
|
||||||
|
async with create_task_group() as tg:
|
||||||
|
for node_id in topology.list_nodes():
|
||||||
|
if node_id not in profiles or node_id == our_node_id:
|
||||||
|
continue
|
||||||
|
for iface in profiles[node_id].network_interfaces:
|
||||||
|
if ip_address(iface.ip_address).is_loopback:
|
||||||
|
# Definitely a loopback address
|
||||||
|
continue
|
||||||
|
if iface in our_interfaces:
|
||||||
|
# Skip duplicates with our own interfaces
|
||||||
|
continue
|
||||||
|
tg.start_soon(check_reachability, iface.ip_address, node_id, reachable)
|
||||||
|
|
||||||
|
return reachable
|
||||||
@@ -19,11 +19,20 @@ class CamelCaseModel(BaseModel):
|
|||||||
alias_generator=to_camel,
|
alias_generator=to_camel,
|
||||||
validate_by_name=True,
|
validate_by_name=True,
|
||||||
extra="forbid",
|
extra="forbid",
|
||||||
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
|
|
||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenModel(BaseModel):
|
||||||
|
model_config = ConfigDict(
|
||||||
|
alias_generator=to_camel,
|
||||||
|
validate_by_name=True,
|
||||||
|
extra="forbid",
|
||||||
|
strict=True,
|
||||||
|
frozen=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TaggedModel(CamelCaseModel):
|
class TaggedModel(CamelCaseModel):
|
||||||
@model_serializer(mode="wrap")
|
@model_serializer(mode="wrap")
|
||||||
def _serialize(self, handler: SerializerFunctionWrapHandler):
|
def _serialize(self, handler: SerializerFunctionWrapHandler):
|
||||||
|
|||||||
@@ -28,9 +28,8 @@ def bar(send: MpSender[str]):
|
|||||||
send.close()
|
send.close()
|
||||||
|
|
||||||
|
|
||||||
# not async, just want the fail_after
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_channel_setup():
|
async def test_channel_ipc():
|
||||||
with fail_after(0.5):
|
with fail_after(0.5):
|
||||||
s, r = mp_channel[str]()
|
s, r = mp_channel[str]()
|
||||||
p1 = mp.Process(target=foo, args=(r,))
|
p1 = mp.Process(target=foo, args=(r,))
|
||||||
@@ -101,13 +101,7 @@ def mlx_distributed_init(
|
|||||||
bound_instance: BoundInstance,
|
bound_instance: BoundInstance,
|
||||||
) -> mx.distributed.Group:
|
) -> mx.distributed.Group:
|
||||||
"""
|
"""
|
||||||
Initialize the MLX distributed (runs in thread pool).
|
Initialize the MLX distributed
|
||||||
|
|
||||||
Either hosts or mlx_ibv_devices must be provided:
|
|
||||||
- hosts: traditional host-based connectivity using MLX_HOSTFILE
|
|
||||||
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
|
|
||||||
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
|
|
||||||
- strict: if True, raise an error if the distributed backend is not available
|
|
||||||
"""
|
"""
|
||||||
rank = bound_instance.bound_shard.device_rank
|
rank = bound_instance.bound_shard.device_rank
|
||||||
logger.info(f"Starting initialization for rank {rank}")
|
logger.info(f"Starting initialization for rank {rank}")
|
||||||
@@ -129,20 +123,20 @@ def mlx_distributed_init(
|
|||||||
group = mx.distributed.init(backend="ring", strict=True)
|
group = mx.distributed.init(backend="ring", strict=True)
|
||||||
|
|
||||||
case MlxJacclInstance(
|
case MlxJacclInstance(
|
||||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
|
||||||
):
|
):
|
||||||
# Use RDMA connectivity matrix
|
# Use RDMA connectivity matrix
|
||||||
devices_file = f"./hosts_{rank}.json"
|
devices_file = f"./hosts_{rank}.json"
|
||||||
ibv_devices_json = json.dumps(ibv_devices)
|
jaccl_devices_json = json.dumps(jaccl_devices)
|
||||||
|
|
||||||
with open(devices_file, "w") as f:
|
with open(devices_file, "w") as f:
|
||||||
_ = f.write(ibv_devices_json)
|
_ = f.write(jaccl_devices_json)
|
||||||
|
|
||||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||||
|
|
||||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
logger.info(f"rank {rank} MLX_JACCL_DEVICES: {jaccl_devices_json}")
|
||||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
os.environ["MLX_JACCL_DEVICES"] = devices_file
|
||||||
os.environ["MLX_RANK"] = str(rank)
|
os.environ["MLX_RANK"] = str(rank)
|
||||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||||
|
|||||||
@@ -16,15 +16,13 @@ from exo.shared.types.events import (
|
|||||||
ForwarderEvent,
|
ForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
NodeDownloadProgress,
|
NodeDownloadProgress,
|
||||||
NodeMemoryMeasured,
|
NodeGatheredInfo,
|
||||||
NodePerformanceMeasured,
|
|
||||||
TaskCreated,
|
TaskCreated,
|
||||||
TaskStatusUpdated,
|
TaskStatusUpdated,
|
||||||
TopologyEdgeCreated,
|
TopologyEdgeCreated,
|
||||||
TopologyEdgeDeleted,
|
TopologyEdgeDeleted,
|
||||||
)
|
)
|
||||||
from exo.shared.types.multiaddr import Multiaddr
|
from exo.shared.types.multiaddr import Multiaddr
|
||||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
|
||||||
from exo.shared.types.state import State
|
from exo.shared.types.state import State
|
||||||
from exo.shared.types.tasks import (
|
from exo.shared.types.tasks import (
|
||||||
CreateRunner,
|
CreateRunner,
|
||||||
@@ -33,7 +31,7 @@ from exo.shared.types.tasks import (
|
|||||||
Task,
|
Task,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
)
|
)
|
||||||
from exo.shared.types.topology import Connection
|
from exo.shared.types.topology import SocketConnection
|
||||||
from exo.shared.types.worker.downloads import (
|
from exo.shared.types.worker.downloads import (
|
||||||
DownloadCompleted,
|
DownloadCompleted,
|
||||||
DownloadOngoing,
|
DownloadOngoing,
|
||||||
@@ -44,14 +42,14 @@ from exo.shared.types.worker.runners import RunnerId
|
|||||||
from exo.shared.types.worker.shards import ShardMetadata
|
from exo.shared.types.worker.shards import ShardMetadata
|
||||||
from exo.utils.channels import Receiver, Sender, channel
|
from exo.utils.channels import Receiver, Sender, channel
|
||||||
from exo.utils.event_buffer import OrderedBuffer
|
from exo.utils.event_buffer import OrderedBuffer
|
||||||
|
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||||
|
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||||
from exo.worker.download.download_utils import (
|
from exo.worker.download.download_utils import (
|
||||||
map_repo_download_progress_to_download_progress_data,
|
map_repo_download_progress_to_download_progress_data,
|
||||||
)
|
)
|
||||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||||
from exo.worker.plan import plan
|
from exo.worker.plan import plan
|
||||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||||
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
|
|
||||||
from exo.worker.utils.net_profile import check_reachable
|
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
@@ -85,7 +83,7 @@ class Worker:
|
|||||||
self.state: State = State()
|
self.state: State = State()
|
||||||
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||||
self._tg: TaskGroup | None = None
|
self._tg: TaskGroup = create_task_group()
|
||||||
|
|
||||||
self._nack_cancel_scope: CancelScope | None = None
|
self._nack_cancel_scope: CancelScope | None = None
|
||||||
self._nack_attempts: int = 0
|
self._nack_attempts: int = 0
|
||||||
@@ -97,37 +95,13 @@ class Worker:
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
logger.info("Starting Worker")
|
logger.info("Starting Worker")
|
||||||
|
|
||||||
# TODO: CLEANUP HEADER
|
info_send, info_recv = channel[GatheredInfo]()
|
||||||
async def resource_monitor_callback(
|
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||||
node_performance_profile: NodePerformanceProfile,
|
|
||||||
) -> None:
|
|
||||||
await self.event_sender.send(
|
|
||||||
NodePerformanceMeasured(
|
|
||||||
node_id=self.node_id,
|
|
||||||
node_profile=node_performance_profile,
|
|
||||||
when=str(datetime.now(tz=timezone.utc)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def memory_monitor_callback(
|
async with self._tg as tg:
|
||||||
memory_profile: MemoryPerformanceProfile,
|
tg.start_soon(info_gatherer.run)
|
||||||
) -> None:
|
tg.start_soon(self._forward_info, info_recv)
|
||||||
await self.event_sender.send(
|
|
||||||
NodeMemoryMeasured(
|
|
||||||
node_id=self.node_id,
|
|
||||||
memory=memory_profile,
|
|
||||||
when=str(datetime.now(tz=timezone.utc)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# END CLEANUP
|
|
||||||
|
|
||||||
async with create_task_group() as tg:
|
|
||||||
self._tg = tg
|
|
||||||
tg.start_soon(self.plan_step)
|
tg.start_soon(self.plan_step)
|
||||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
|
||||||
|
|
||||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
|
||||||
tg.start_soon(self._connection_message_event_writer)
|
tg.start_soon(self._connection_message_event_writer)
|
||||||
tg.start_soon(self._resend_out_for_delivery)
|
tg.start_soon(self._resend_out_for_delivery)
|
||||||
tg.start_soon(self._event_applier)
|
tg.start_soon(self._event_applier)
|
||||||
@@ -140,6 +114,17 @@ class Worker:
|
|||||||
for runner in self.runners.values():
|
for runner in self.runners.values():
|
||||||
runner.shutdown()
|
runner.shutdown()
|
||||||
|
|
||||||
|
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||||
|
with recv as info_stream:
|
||||||
|
async for info in info_stream:
|
||||||
|
await self.event_sender.send(
|
||||||
|
NodeGatheredInfo(
|
||||||
|
node_id=self.node_id,
|
||||||
|
when=str(datetime.now(tz=timezone.utc)),
|
||||||
|
info=info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def _event_applier(self):
|
async def _event_applier(self):
|
||||||
with self.global_event_receiver as events:
|
with self.global_event_receiver as events:
|
||||||
async for f_event in events:
|
async for f_event in events:
|
||||||
@@ -159,7 +144,6 @@ class Worker:
|
|||||||
self._nack_cancel_scope is None
|
self._nack_cancel_scope is None
|
||||||
or self._nack_cancel_scope.cancel_called
|
or self._nack_cancel_scope.cancel_called
|
||||||
):
|
):
|
||||||
assert self._tg
|
|
||||||
# Request the next index.
|
# Request the next index.
|
||||||
self._tg.start_soon(
|
self._tg.start_soon(
|
||||||
self._nack_request, self.state.last_event_applied_idx + 1
|
self._nack_request, self.state.last_event_applied_idx + 1
|
||||||
@@ -248,8 +232,7 @@ class Worker:
|
|||||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if self._tg:
|
self._tg.cancel_scope.cancel()
|
||||||
self._tg.cancel_scope.cancel()
|
|
||||||
|
|
||||||
def _task_to_runner_id(self, task: Task):
|
def _task_to_runner_id(self, task: Task):
|
||||||
instance = self.state.instances[task.instance_id]
|
instance = self.state.instances[task.instance_id]
|
||||||
@@ -266,24 +249,24 @@ class Worker:
|
|||||||
match msg.connection_type:
|
match msg.connection_type:
|
||||||
case ConnectionMessageType.Connected:
|
case ConnectionMessageType.Connected:
|
||||||
return TopologyEdgeCreated(
|
return TopologyEdgeCreated(
|
||||||
edge=Connection(
|
source=self.node_id,
|
||||||
local_node_id=self.node_id,
|
sink=msg.node_id,
|
||||||
send_back_node_id=msg.node_id,
|
edge=SocketConnection(
|
||||||
send_back_multiaddr=Multiaddr(
|
sink_multiaddr=Multiaddr(
|
||||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
case ConnectionMessageType.Disconnected:
|
case ConnectionMessageType.Disconnected:
|
||||||
return TopologyEdgeDeleted(
|
return TopologyEdgeDeleted(
|
||||||
edge=Connection(
|
source=self.node_id,
|
||||||
local_node_id=self.node_id,
|
sink=msg.node_id,
|
||||||
send_back_node_id=msg.node_id,
|
edge=SocketConnection(
|
||||||
send_back_multiaddr=Multiaddr(
|
sink_multiaddr=Multiaddr(
|
||||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _nack_request(self, since_idx: int) -> None:
|
async def _nack_request(self, since_idx: int) -> None:
|
||||||
@@ -332,7 +315,6 @@ class Worker:
|
|||||||
event_sender=self.event_sender.clone(),
|
event_sender=self.event_sender.clone(),
|
||||||
)
|
)
|
||||||
self.runners[task.bound_instance.bound_runner_id] = runner
|
self.runners[task.bound_instance.bound_runner_id] = runner
|
||||||
assert self._tg
|
|
||||||
self._tg.start_soon(runner.run)
|
self._tg.start_soon(runner.run)
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
@@ -391,7 +373,6 @@ class Worker:
|
|||||||
last_progress_time = current_time()
|
last_progress_time = current_time()
|
||||||
|
|
||||||
self.shard_downloader.on_progress(download_progress_callback)
|
self.shard_downloader.on_progress(download_progress_callback)
|
||||||
assert self._tg
|
|
||||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||||
|
|
||||||
async def _forward_events(self) -> None:
|
async def _forward_events(self) -> None:
|
||||||
@@ -414,28 +395,35 @@ class Worker:
|
|||||||
while True:
|
while True:
|
||||||
# TODO: EdgeDeleted
|
# TODO: EdgeDeleted
|
||||||
edges = set(self.state.topology.list_connections())
|
edges = set(self.state.topology.list_connections())
|
||||||
conns = await check_reachable(self.state.topology)
|
conns = await check_reachable(
|
||||||
|
self.node_id, self.state.topology, self.state.node_profiles
|
||||||
|
)
|
||||||
for nid in conns:
|
for nid in conns:
|
||||||
for ip in conns[nid]:
|
for ip in conns[nid]:
|
||||||
edge = Connection(
|
edge = SocketConnection(
|
||||||
local_node_id=self.node_id,
|
|
||||||
send_back_node_id=nid,
|
|
||||||
# nonsense multiaddr
|
# nonsense multiaddr
|
||||||
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||||
if "." in ip
|
if "." in ip
|
||||||
# nonsense multiaddr
|
# nonsense multiaddr
|
||||||
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
|
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
|
||||||
)
|
)
|
||||||
if edge not in edges:
|
if edge not in edges:
|
||||||
logger.debug(f"ping discovered {edge=}")
|
logger.debug(f"ping discovered {edge=}")
|
||||||
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
|
await self.event_sender.send(
|
||||||
|
TopologyEdgeCreated(
|
||||||
|
source=self.node_id, sink=nid, edge=edge
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for nid, conn in self.state.topology.out_edges(self.node_id):
|
for nid, conn in self.state.topology.out_edges(self.node_id):
|
||||||
if (
|
if not isinstance(conn, SocketConnection):
|
||||||
nid not in conns
|
continue
|
||||||
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
|
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
|
||||||
|
nid, set()
|
||||||
):
|
):
|
||||||
logger.debug(f"ping failed to discover {conn=}")
|
logger.debug(f"ping failed to discover {conn=}")
|
||||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
await self.event_sender.send(
|
||||||
|
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
|
||||||
|
)
|
||||||
|
|
||||||
await anyio.sleep(10)
|
await anyio.sleep(10)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def entrypoint(
|
|||||||
) -> None:
|
) -> None:
|
||||||
if (
|
if (
|
||||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||||
and len(bound_instance.instance.ibv_devices) >= 2
|
and len(bound_instance.instance.jaccl_devices) >= 2
|
||||||
):
|
):
|
||||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
from .profile import start_polling_memory_metrics, start_polling_node_metrics
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"start_polling_node_metrics",
|
|
||||||
"start_polling_memory_metrics",
|
|
||||||
]
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
import socket
|
|
||||||
|
|
||||||
from anyio import create_task_group, to_thread
|
|
||||||
|
|
||||||
from exo.shared.topology import Topology
|
|
||||||
from exo.shared.types.common import NodeId
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: ref. api port
|
|
||||||
async def check_reachability(
|
|
||||||
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
|
|
||||||
) -> None:
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
sock.settimeout(1) # 1 second timeout
|
|
||||||
try:
|
|
||||||
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
|
|
||||||
except socket.gaierror:
|
|
||||||
# seems to throw on ipv6 loopback. oh well
|
|
||||||
# logger.warning(f"invalid {target_ip=}")
|
|
||||||
return
|
|
||||||
finally:
|
|
||||||
sock.close()
|
|
||||||
|
|
||||||
if result == 0:
|
|
||||||
if target_node_id not in out:
|
|
||||||
out[target_node_id] = set()
|
|
||||||
out[target_node_id].add(target_ip)
|
|
||||||
|
|
||||||
|
|
||||||
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
|
|
||||||
reachable: dict[NodeId, set[str]] = {}
|
|
||||||
async with create_task_group() as tg:
|
|
||||||
for node in topology.list_nodes():
|
|
||||||
if not node.node_profile:
|
|
||||||
continue
|
|
||||||
for iface in node.node_profile.network_interfaces:
|
|
||||||
tg.start_soon(
|
|
||||||
check_reachability, iface.ip_address, node.node_id, reachable
|
|
||||||
)
|
|
||||||
|
|
||||||
return reachable
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
from typing import Any, Callable, Coroutine
|
|
||||||
|
|
||||||
import anyio
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from exo.shared.types.memory import Memory
|
|
||||||
from exo.shared.types.profiling import (
|
|
||||||
MemoryPerformanceProfile,
|
|
||||||
NodePerformanceProfile,
|
|
||||||
SystemPerformanceProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .macmon import (
|
|
||||||
MacMonError,
|
|
||||||
Metrics,
|
|
||||||
)
|
|
||||||
from .macmon import (
|
|
||||||
get_metrics_async as macmon_get_metrics_async,
|
|
||||||
)
|
|
||||||
from .system_info import (
|
|
||||||
get_friendly_name,
|
|
||||||
get_model_and_chip,
|
|
||||||
get_network_interfaces,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_metrics_async() -> Metrics | None:
|
|
||||||
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
|
|
||||||
|
|
||||||
if platform.system().lower() == "darwin":
|
|
||||||
return await macmon_get_metrics_async()
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_profile() -> MemoryPerformanceProfile:
|
|
||||||
"""Construct a MemoryPerformanceProfile using psutil"""
|
|
||||||
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
|
|
||||||
override_memory: int | None = (
|
|
||||||
Memory.from_mb(int(override_memory_env)).in_bytes
|
|
||||||
if override_memory_env
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_polling_memory_metrics(
|
|
||||||
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
|
|
||||||
*,
|
|
||||||
poll_interval_s: float = 0.5,
|
|
||||||
) -> None:
|
|
||||||
"""Continuously poll and emit memory-only metrics at a faster cadence.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
|
|
||||||
- poll_interval_s: interval between polls
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
mem = get_memory_profile()
|
|
||||||
await callback(mem)
|
|
||||||
except MacMonError as e:
|
|
||||||
logger.opt(exception=e).error("Memory Monitor encountered error")
|
|
||||||
finally:
|
|
||||||
await anyio.sleep(poll_interval_s)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_polling_node_metrics(
|
|
||||||
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
|
|
||||||
):
|
|
||||||
poll_interval_s = 1.0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
metrics = await get_metrics_async()
|
|
||||||
if metrics is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
network_interfaces = get_network_interfaces()
|
|
||||||
# these awaits could be joined but realistically they should be cached
|
|
||||||
model_id, chip_id = await get_model_and_chip()
|
|
||||||
friendly_name = await get_friendly_name()
|
|
||||||
|
|
||||||
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
|
|
||||||
memory_profile = get_memory_profile()
|
|
||||||
|
|
||||||
await callback(
|
|
||||||
NodePerformanceProfile(
|
|
||||||
model_id=model_id,
|
|
||||||
chip_id=chip_id,
|
|
||||||
friendly_name=friendly_name,
|
|
||||||
network_interfaces=network_interfaces,
|
|
||||||
memory=memory_profile,
|
|
||||||
system=SystemPerformanceProfile(
|
|
||||||
gpu_usage=metrics.gpu_usage[1],
|
|
||||||
temp=metrics.temp.gpu_temp_avg,
|
|
||||||
sys_power=metrics.sys_power,
|
|
||||||
pcpu_usage=metrics.pcpu_usage[1],
|
|
||||||
ecpu_usage=metrics.ecpu_usage[1],
|
|
||||||
ane_power=metrics.ane_power,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
|
|
||||||
)
|
|
||||||
except MacMonError as e:
|
|
||||||
logger.opt(exception=e).error("Resource Monitor encountered error")
|
|
||||||
return
|
|
||||||
finally:
|
|
||||||
await anyio.sleep(poll_interval_s)
|
|
||||||
Reference in New Issue
Block a user