mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-19 11:28:51 -05:00
Compare commits
1 Commits
leo/fix-pi
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d0549c3046 |
@@ -490,17 +490,17 @@ def main() -> int:
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
# if (
|
||||
# pp * n_nodes > 2048
|
||||
# and "ring" in instance_meta.lower()
|
||||
# and "tensor" in sharding.lower()
|
||||
# ):
|
||||
# model_card = MODEL_CARDS[short_id]
|
||||
# if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
# logger.info(
|
||||
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
# )
|
||||
# continue
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
|
||||
@@ -71,35 +71,36 @@ export interface Instance {
|
||||
};
|
||||
}
|
||||
|
||||
interface RawNodeProfile {
|
||||
modelId?: string;
|
||||
chipId?: string;
|
||||
friendlyName?: string;
|
||||
networkInterfaces?: Array<{
|
||||
name?: string;
|
||||
ipAddress?: string;
|
||||
addresses?: Array<{ address?: string } | string>;
|
||||
ipv4?: string;
|
||||
ipv6?: string;
|
||||
ipAddresses?: string[];
|
||||
ips?: string[];
|
||||
}>;
|
||||
memory?: {
|
||||
ramTotal?: { inBytes: number };
|
||||
ramAvailable?: { inBytes: number };
|
||||
swapTotal?: { inBytes: number };
|
||||
swapAvailable?: { inBytes: number };
|
||||
};
|
||||
system?: {
|
||||
gpuUsage?: number;
|
||||
temp?: number;
|
||||
sysPower?: number;
|
||||
};
|
||||
// Split state interfaces
|
||||
interface RawNodeIdentity {
|
||||
modelId: string;
|
||||
chipId: string;
|
||||
friendlyName: string;
|
||||
}
|
||||
|
||||
interface RawNodeMemory {
|
||||
ramTotal: { inBytes: number };
|
||||
ramAvailable: { inBytes: number };
|
||||
swapTotal: { inBytes: number };
|
||||
swapAvailable: { inBytes: number };
|
||||
}
|
||||
|
||||
interface RawNodeSystem {
|
||||
gpuUsage?: number;
|
||||
temp?: number;
|
||||
sysPower?: number;
|
||||
pcpuUsage?: number;
|
||||
ecpuUsage?: number;
|
||||
anePower?: number;
|
||||
}
|
||||
|
||||
interface RawNetworkInterface {
|
||||
name: string;
|
||||
ipAddress: string;
|
||||
}
|
||||
|
||||
interface RawTopologyNode {
|
||||
nodeId: string;
|
||||
nodeProfile: RawNodeProfile;
|
||||
}
|
||||
|
||||
interface RawTopologyConnection {
|
||||
@@ -115,8 +116,6 @@ interface RawTopology {
|
||||
connections?: RawTopologyConnection[];
|
||||
}
|
||||
|
||||
type RawNodeProfiles = Record<string, RawNodeProfile>;
|
||||
|
||||
export interface DownloadProgress {
|
||||
totalBytes: number;
|
||||
downloadedBytes: number;
|
||||
@@ -171,7 +170,11 @@ interface RawStateResponse {
|
||||
>;
|
||||
runners?: Record<string, unknown>;
|
||||
downloads?: Record<string, unknown[]>;
|
||||
nodeProfiles?: RawNodeProfiles;
|
||||
// Split state fields
|
||||
nodeIdentities?: Record<string, RawNodeIdentity>;
|
||||
nodeMemories?: Record<string, RawNodeMemory>;
|
||||
nodeSystems?: Record<string, RawNodeSystem>;
|
||||
nodeNetworks?: Record<string, RawNetworkInterface[]>;
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -208,66 +211,41 @@ const STORAGE_KEY = "exo-conversations";
|
||||
|
||||
function transformTopology(
|
||||
raw: RawTopology,
|
||||
profiles?: RawNodeProfiles,
|
||||
identities?: Record<string, RawNodeIdentity>,
|
||||
memories?: Record<string, RawNodeMemory>,
|
||||
systems?: Record<string, RawNodeSystem>,
|
||||
networks?: Record<string, RawNetworkInterface[]>,
|
||||
): TopologyData {
|
||||
const nodes: Record<string, NodeInfo> = {};
|
||||
const edges: TopologyEdge[] = [];
|
||||
|
||||
for (const node of raw.nodes || []) {
|
||||
const mergedProfile = profiles?.[node.nodeId];
|
||||
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
|
||||
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
|
||||
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
|
||||
// Get split state fields (may be undefined if events haven't arrived yet)
|
||||
const identity = identities?.[node.nodeId];
|
||||
const memory = memories?.[node.nodeId];
|
||||
const system = systems?.[node.nodeId];
|
||||
const network = networks?.[node.nodeId];
|
||||
|
||||
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
|
||||
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
|
||||
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
|
||||
|
||||
const networkInterfaces = (profile?.networkInterfaces || []).map(
|
||||
(iface) => {
|
||||
const addresses: string[] = [];
|
||||
if (iface.ipAddress && typeof iface.ipAddress === "string") {
|
||||
addresses.push(iface.ipAddress);
|
||||
}
|
||||
if (Array.isArray(iface.addresses)) {
|
||||
for (const addr of iface.addresses) {
|
||||
if (typeof addr === "string") addresses.push(addr);
|
||||
else if (addr && typeof addr === "object" && addr.address)
|
||||
addresses.push(addr.address);
|
||||
}
|
||||
}
|
||||
if (Array.isArray(iface.ipAddresses)) {
|
||||
addresses.push(
|
||||
...iface.ipAddresses.filter(
|
||||
(a): a is string => typeof a === "string",
|
||||
),
|
||||
);
|
||||
}
|
||||
if (Array.isArray(iface.ips)) {
|
||||
addresses.push(
|
||||
...iface.ips.filter((a): a is string => typeof a === "string"),
|
||||
);
|
||||
}
|
||||
if (iface.ipv4 && typeof iface.ipv4 === "string")
|
||||
addresses.push(iface.ipv4);
|
||||
if (iface.ipv6 && typeof iface.ipv6 === "string")
|
||||
addresses.push(iface.ipv6);
|
||||
|
||||
return {
|
||||
name: iface.name,
|
||||
addresses: Array.from(new Set(addresses)),
|
||||
};
|
||||
},
|
||||
);
|
||||
const networkInterfaces = (network ?? []).map((iface) => ({
|
||||
name: iface.name,
|
||||
addresses: [iface.ipAddress],
|
||||
}));
|
||||
|
||||
const ipToInterface: Record<string, string> = {};
|
||||
for (const iface of networkInterfaces) {
|
||||
for (const addr of iface.addresses || []) {
|
||||
ipToInterface[addr] = iface.name ?? "";
|
||||
for (const addr of iface.addresses) {
|
||||
ipToInterface[addr] = iface.name;
|
||||
}
|
||||
}
|
||||
|
||||
nodes[node.nodeId] = {
|
||||
system_info: {
|
||||
model_id: profile?.modelId ?? "Unknown",
|
||||
chip: profile?.chipId,
|
||||
model_id: identity?.modelId ?? "Unknown",
|
||||
chip: identity?.chipId,
|
||||
memory: ramTotal,
|
||||
},
|
||||
network_interfaces: networkInterfaces,
|
||||
@@ -278,17 +256,15 @@ function transformTopology(
|
||||
ram_total: ramTotal,
|
||||
},
|
||||
temp:
|
||||
profile?.system?.temp !== undefined
|
||||
? { gpu_temp_avg: profile.system.temp }
|
||||
system?.temp !== undefined
|
||||
? { gpu_temp_avg: system.temp }
|
||||
: undefined,
|
||||
gpu_usage:
|
||||
profile?.system?.gpuUsage !== undefined
|
||||
? [0, profile.system.gpuUsage]
|
||||
: undefined,
|
||||
sys_power: profile?.system?.sysPower,
|
||||
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
|
||||
sys_power: system?.sysPower,
|
||||
},
|
||||
last_macmon_update: Date.now() / 1000,
|
||||
friendly_name: profile?.friendlyName,
|
||||
friendly_name: identity?.friendlyName,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -868,7 +844,13 @@ class AppStore {
|
||||
const data: RawStateResponse = await response.json();
|
||||
|
||||
if (data.topology) {
|
||||
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
|
||||
this.topologyData = transformTopology(
|
||||
data.topology,
|
||||
data.nodeIdentities,
|
||||
data.nodeMemories,
|
||||
data.nodeSystems,
|
||||
data.nodeNetworks,
|
||||
);
|
||||
}
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
|
||||
@@ -600,9 +600,8 @@ class API:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
|
||||
for node in self.state.topology.list_nodes():
|
||||
if node.node_profile is not None:
|
||||
total_available += node.node_profile.memory.ram_available
|
||||
for memory in self.state.node_memories.values():
|
||||
total_available += memory.ram_available
|
||||
|
||||
return total_available
|
||||
|
||||
|
||||
@@ -113,6 +113,7 @@ def place_instance(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile is not None
|
||||
and node.node_profile.memory is not None
|
||||
),
|
||||
start=Memory(),
|
||||
),
|
||||
|
||||
@@ -25,7 +25,10 @@ class NodeWithProfile(BaseModel):
|
||||
|
||||
|
||||
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
|
||||
return all(node.node_profile is not None for node in nodes)
|
||||
return all(
|
||||
node.node_profile is not None and node.node_profile.memory is not None
|
||||
for node in nodes
|
||||
)
|
||||
|
||||
|
||||
def filter_cycles_by_memory(
|
||||
@@ -36,8 +39,14 @@ def filter_cycles_by_memory(
|
||||
if not narrow_all_nodes(cycle):
|
||||
continue
|
||||
|
||||
# narrow_all_nodes guarantees memory is not None
|
||||
total_mem = sum(
|
||||
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile.memory is not None
|
||||
),
|
||||
start=Memory(),
|
||||
)
|
||||
if total_mem >= required_memory:
|
||||
filtered_cycles.append(cast(list[NodeInfo], cycle))
|
||||
@@ -88,7 +97,11 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
raise ValueError("Cannot create shard assignments for empty node cycle")
|
||||
|
||||
cycle_memory = sum(
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in selected_cycle
|
||||
if node.node_profile.memory is not None
|
||||
),
|
||||
start=Memory(),
|
||||
)
|
||||
|
||||
@@ -105,6 +118,7 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
memory_fractions=[
|
||||
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
|
||||
for node in selected_cycle
|
||||
if node.node_profile.memory is not None
|
||||
],
|
||||
)
|
||||
|
||||
@@ -113,6 +127,7 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
for i, (node, node_layers) in enumerate(
|
||||
zip(selected_cycle, layer_allocations, strict=True)
|
||||
):
|
||||
assert node.node_profile.memory is not None
|
||||
required_memory = node_layers * memory_per_layer
|
||||
available_memory = node.node_profile.memory.ram_available.in_bytes
|
||||
if required_memory > available_memory:
|
||||
|
||||
@@ -19,16 +19,13 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodePerformanceMeasured,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
@@ -75,29 +72,39 @@ async def test_master():
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodePerformanceProfile event
|
||||
logger.info("inject a NodePerformanceProfile event")
|
||||
# inject NodeIdentityMeasured and NodeMemoryMeasured events
|
||||
logger.info("inject NodeIdentityMeasured event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodePerformanceMeasured(
|
||||
NodeIdentityMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info("inject NodeMemoryMeasured event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=1,
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeMemoryMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -108,7 +115,7 @@ async def test_master():
|
||||
logger.info("wait for initial topology event")
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
while len(master.state.node_profiles) == 0:
|
||||
while len(master.state.node_identities) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
|
||||
logger.info("inject a CreateInstance Command")
|
||||
@@ -155,17 +162,19 @@ async def test_master():
|
||||
),
|
||||
)
|
||||
)
|
||||
while len(_get_events()) < 3:
|
||||
while len(_get_events()) < 4:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
events = _get_events()
|
||||
assert len(events) == 3
|
||||
assert len(events) == 4
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
created_instance = events[1].event.instance
|
||||
assert events[3].idx == 3
|
||||
assert isinstance(events[0].event, NodeIdentityMeasured)
|
||||
assert isinstance(events[1].event, NodeMemoryMeasured)
|
||||
assert isinstance(events[2].event, InstanceCreated)
|
||||
created_instance = events[2].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
|
||||
# Validate the shard assignments
|
||||
@@ -197,10 +206,10 @@ async def test_master():
|
||||
assert len(created_instance.hosts_by_node[node_id]) == 1
|
||||
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
assert events[2].event.task.task_params == ChatCompletionTaskParams(
|
||||
assert isinstance(events[3].event, TaskCreated)
|
||||
assert events[3].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[3].event.task, ChatCompletionTask)
|
||||
assert events[3].event.task.task_params == ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, how are you?")
|
||||
|
||||
@@ -13,8 +13,10 @@ from exo.shared.types.events import (
|
||||
InstanceDeleted,
|
||||
NodeCreated,
|
||||
NodeDownloadProgress,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
NodeNetworkMeasured,
|
||||
NodeSystemMeasured,
|
||||
NodeTimedOut,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
@@ -27,7 +29,13 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
NodeIdentity,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
@@ -51,8 +59,12 @@ def event_apply(event: Event, state: State) -> State:
|
||||
return apply_topology_node_created(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodePerformanceMeasured():
|
||||
return apply_node_performance_measured(event, state)
|
||||
case NodeIdentityMeasured():
|
||||
return apply_node_identity_measured(event, state)
|
||||
case NodeSystemMeasured():
|
||||
return apply_node_system_measured(event, state)
|
||||
case NodeNetworkMeasured():
|
||||
return apply_node_network_measured(event, state)
|
||||
case NodeDownloadProgress():
|
||||
return apply_node_download_progress(event, state)
|
||||
case NodeMemoryMeasured():
|
||||
@@ -190,8 +202,19 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
state.topology.remove_node(event.node_id)
|
||||
node_profiles = {
|
||||
key: value for key, value in state.node_profiles.items() if key != event.node_id
|
||||
node_identities = {
|
||||
key: value
|
||||
for key, value in state.node_identities.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_memories = {
|
||||
key: value for key, value in state.node_memories.items() if key != event.node_id
|
||||
}
|
||||
node_systems = {
|
||||
key: value for key, value in state.node_systems.items() if key != event.node_id
|
||||
}
|
||||
node_networks = {
|
||||
key: value for key, value in state.node_networks.items() if key != event.node_id
|
||||
}
|
||||
last_seen = {
|
||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||
@@ -199,32 +222,120 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
return state.model_copy(
|
||||
update={
|
||||
"topology": topology,
|
||||
"node_profiles": node_profiles,
|
||||
"node_identities": node_identities,
|
||||
"node_memories": node_memories,
|
||||
"node_systems": node_systems,
|
||||
"node_networks": node_networks,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_performance_measured(
|
||||
event: NodePerformanceMeasured, state: State
|
||||
) -> State:
|
||||
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: event.node_profile,
|
||||
def _reconstruct_profile(
|
||||
node_id: NodeId,
|
||||
state: State,
|
||||
*,
|
||||
identity: NodeIdentity | None = None,
|
||||
memory: MemoryPerformanceProfile | None = None,
|
||||
system: SystemPerformanceProfile | None = None,
|
||||
network_interfaces: list[NetworkInterfaceInfo] | None = None,
|
||||
) -> NodePerformanceProfile:
|
||||
"""Reconstruct a NodePerformanceProfile from split state storage.
|
||||
|
||||
Uses provided overrides, falling back to state values.
|
||||
"""
|
||||
ident = identity or state.node_identities.get(node_id)
|
||||
mem = memory or state.node_memories.get(node_id)
|
||||
sys = system or state.node_systems.get(node_id)
|
||||
nets = (
|
||||
network_interfaces
|
||||
if network_interfaces is not None
|
||||
else state.node_networks.get(node_id, [])
|
||||
)
|
||||
|
||||
return NodePerformanceProfile(
|
||||
model_id=ident.model_id if ident else None,
|
||||
chip_id=ident.chip_id if ident else None,
|
||||
friendly_name=ident.friendly_name if ident else None,
|
||||
memory=mem,
|
||||
network_interfaces=nets,
|
||||
system=sys,
|
||||
)
|
||||
|
||||
|
||||
def apply_node_identity_measured(event: NodeIdentityMeasured, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
identity = NodeIdentity(
|
||||
model_id=event.model_id,
|
||||
chip_id=event.chip_id,
|
||||
friendly_name=event.friendly_name,
|
||||
)
|
||||
new_identities: Mapping[NodeId, NodeIdentity] = {
|
||||
**state.node_identities,
|
||||
event.node_id: identity,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
state = state.model_copy(update={"node_profiles": new_profiles})
|
||||
topology = copy.copy(state.topology)
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
topology.update_node_profile(event.node_id, event.node_profile)
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, identity=identity)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_profiles": new_profiles,
|
||||
"node_identities": new_identities,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_system_measured(event: NodeSystemMeasured, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
new_systems: Mapping[NodeId, SystemPerformanceProfile] = {
|
||||
**state.node_systems,
|
||||
event.node_id: event.system,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, system=event.system)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_systems": new_systems,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
new_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {
|
||||
**state.node_networks,
|
||||
event.node_id: event.network_interfaces,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
reconstructed = _reconstruct_profile(
|
||||
event.node_id, state, network_interfaces=event.network_interfaces
|
||||
)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_networks": new_networks,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
@@ -232,57 +343,26 @@ def apply_node_performance_measured(
|
||||
|
||||
|
||||
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
|
||||
existing = state.node_profiles.get(event.node_id)
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
if existing is None:
|
||||
created = NodePerformanceProfile(
|
||||
model_id="unknown",
|
||||
chip_id="unknown",
|
||||
friendly_name="Unknown",
|
||||
memory=event.memory,
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(
|
||||
# TODO: flops_fp16=0.0,
|
||||
gpu_usage=0.0,
|
||||
temp=0.0,
|
||||
sys_power=0.0,
|
||||
pcpu_usage=0.0,
|
||||
ecpu_usage=0.0,
|
||||
ane_power=0.0,
|
||||
),
|
||||
)
|
||||
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: created,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
# TODO: NodeCreated
|
||||
topology.update_node_profile(event.node_id, created)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_profiles": created_profiles,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
updated = existing.model_copy(update={"memory": event.memory})
|
||||
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: updated,
|
||||
new_memories: Mapping[NodeId, MemoryPerformanceProfile] = {
|
||||
**state.node_memories,
|
||||
event.node_id: event.memory,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
topology.update_node_profile(event.node_id, updated)
|
||||
reconstructed = _reconstruct_profile(event.node_id, state, memory=event.memory)
|
||||
topology.update_node_profile(event.node_id, reconstructed)
|
||||
return state.model_copy(
|
||||
update={"node_profiles": updated_profiles, "topology": topology}
|
||||
update={
|
||||
"node_memories": new_memories,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,14 @@ from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -85,13 +89,35 @@ class NodeTimedOut(BaseEvent):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
class NodePerformanceMeasured(BaseEvent):
|
||||
class NodeIdentityMeasured(BaseEvent):
|
||||
"""Static identity info - emitted once at startup."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
node_profile: NodePerformanceProfile
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
|
||||
|
||||
class NodeSystemMeasured(BaseEvent):
|
||||
"""Dynamic system metrics (GPU, temp, power) - emitted at 1s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
system: SystemPerformanceProfile
|
||||
|
||||
|
||||
class NodeNetworkMeasured(BaseEvent):
|
||||
"""Semi-static network interface info - emitted at 30s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
network_interfaces: list[NetworkInterfaceInfo]
|
||||
|
||||
|
||||
class NodeMemoryMeasured(BaseEvent):
|
||||
"""Dynamic memory metrics - emitted at 0.5s intervals."""
|
||||
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
memory: MemoryPerformanceProfile
|
||||
@@ -127,7 +153,9 @@ Event = (
|
||||
| RunnerDeleted
|
||||
| NodeCreated
|
||||
| NodeTimedOut
|
||||
| NodePerformanceMeasured
|
||||
| NodeIdentityMeasured
|
||||
| NodeSystemMeasured
|
||||
| NodeNetworkMeasured
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
|
||||
@@ -52,13 +52,21 @@ class NetworkInterfaceInfo(CamelCaseModel):
|
||||
ip_address: str
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
class NodeIdentity(CamelCaseModel):
|
||||
"""Static identity info for a node."""
|
||||
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
memory: MemoryPerformanceProfile
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
model_id: str | None = None
|
||||
chip_id: str | None = None
|
||||
friendly_name: str | None = None
|
||||
memory: MemoryPerformanceProfile | None = None
|
||||
network_interfaces: list[NetworkInterfaceInfo] = []
|
||||
system: SystemPerformanceProfile
|
||||
system: SystemPerformanceProfile | None = None
|
||||
|
||||
|
||||
class ConnectionProfile(CamelCaseModel):
|
||||
|
||||
@@ -7,7 +7,12 @@ from pydantic.alias_generators import to_camel
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
NodeIdentity,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -35,7 +40,10 @@ class State(CamelCaseModel):
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
|
||||
tasks: Mapping[TaskId, Task] = {}
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
|
||||
node_identities: Mapping[NodeId, NodeIdentity] = {}
|
||||
node_memories: Mapping[NodeId, MemoryPerformanceProfile] = {}
|
||||
node_systems: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {}
|
||||
last_seen: Mapping[NodeId, datetime] = {}
|
||||
topology: Topology = Field(default_factory=Topology)
|
||||
last_event_applied_idx: int = Field(default=-1, ge=-1)
|
||||
|
||||
@@ -41,7 +41,7 @@ class _LayerCallable(Protocol):
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
|
||||
|
||||
|
||||
class CustomMlxModule(nn.Module):
|
||||
class CustomMlxLayer(nn.Module):
|
||||
"""Base class for replacing an MLX layer with a custom implementation."""
|
||||
|
||||
def __init__(self, original_layer: _LayerCallable):
|
||||
@@ -63,7 +63,7 @@ class CustomMlxModule(nn.Module):
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxModule):
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
@@ -80,7 +80,7 @@ class PipelineFirstLayer(CustomMlxModule):
|
||||
return self.original_layer(x, *args, **kwargs)
|
||||
|
||||
|
||||
class PipelineLastLayer(CustomMlxModule):
|
||||
class PipelineLastLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
original_layer: _LayerCallable,
|
||||
@@ -193,32 +193,7 @@ def pipeline_auto_parallel(
|
||||
"Expected a list of layers after auto-parallel initialisation"
|
||||
)
|
||||
|
||||
return PipelineParallelModel(model, group)
|
||||
|
||||
|
||||
class PipelineParallelModel(CustomMlxModule):
|
||||
def __init__(self, model: nn.Module, group: mx.distributed.Group):
|
||||
super().__init__(model)
|
||||
self.original_call_signature = signature(self.original_layer.__call__)
|
||||
self.group = group
|
||||
dict.__setitem__(self, "original_layer", model)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: object,
|
||||
**kwargs: object,
|
||||
) -> mx.array:
|
||||
logits: mx.array = self.original_layer(*args, **kwargs) # type: ignore
|
||||
cache = self.original_call_signature.bind_partial(
|
||||
*args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
if cache is not None:
|
||||
for c in cache: # type: ignore
|
||||
if hasattr(c, "state") and c.state is not None: # type: ignore
|
||||
c.state = mx.depends(c.state, logits) # type: ignore
|
||||
|
||||
return logits
|
||||
return model
|
||||
|
||||
|
||||
def tensor_auto_parallel(
|
||||
@@ -426,7 +401,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxModule):
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
@@ -501,7 +476,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
class ShardedQwenMoE(CustomMlxModule):
|
||||
class ShardedQwenMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
@@ -548,7 +523,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxModule):
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
@@ -16,8 +16,10 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
NodeDownloadProgress,
|
||||
NodeIdentityMeasured,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
NodeNetworkMeasured,
|
||||
NodeSystemMeasured,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
TopologyEdgeCreated,
|
||||
@@ -25,7 +27,11 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
@@ -51,7 +57,13 @@ from exo.worker.download.download_utils import (
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
|
||||
from exo.worker.utils import (
|
||||
IdentityMetrics,
|
||||
start_polling_identity_metrics,
|
||||
start_polling_memory_metrics,
|
||||
start_polling_network_metrics,
|
||||
start_polling_system_metrics,
|
||||
)
|
||||
from exo.worker.utils.net_profile import check_reachable
|
||||
|
||||
|
||||
@@ -98,37 +110,51 @@ class Worker:
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
# TODO: CLEANUP HEADER
|
||||
async def resource_monitor_callback(
|
||||
node_performance_profile: NodePerformanceProfile,
|
||||
) -> None:
|
||||
async def identity_callback(identity: IdentityMetrics) -> None:
|
||||
await self.event_sender.send(
|
||||
NodePerformanceMeasured(
|
||||
NodeIdentityMeasured(
|
||||
node_id=self.node_id,
|
||||
node_profile=node_performance_profile,
|
||||
model_id=identity.model_id,
|
||||
chip_id=identity.chip_id,
|
||||
friendly_name=identity.friendly_name,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def memory_monitor_callback(
|
||||
memory_profile: MemoryPerformanceProfile,
|
||||
) -> None:
|
||||
async def system_callback(system: SystemPerformanceProfile) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeSystemMeasured(
|
||||
node_id=self.node_id,
|
||||
system=system,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def network_callback(interfaces: list[NetworkInterfaceInfo]) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeNetworkMeasured(
|
||||
node_id=self.node_id,
|
||||
network_interfaces=interfaces,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
),
|
||||
)
|
||||
|
||||
async def memory_callback(memory: MemoryPerformanceProfile) -> None:
|
||||
await self.event_sender.send(
|
||||
NodeMemoryMeasured(
|
||||
node_id=self.node_id,
|
||||
memory=memory_profile,
|
||||
memory=memory,
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
)
|
||||
)
|
||||
|
||||
# END CLEANUP
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(start_polling_identity_metrics, identity_callback)
|
||||
tg.start_soon(start_polling_system_metrics, system_callback)
|
||||
tg.start_soon(start_polling_network_metrics, network_callback)
|
||||
tg.start_soon(start_polling_memory_metrics, memory_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# type: ignore
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -50,13 +50,11 @@ DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_GPT_OSS_MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
|
||||
|
||||
def run_gpt_oss_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
@@ -70,39 +68,17 @@ def run_gpt_oss_pipeline_device(
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model, tokenizer = shard_and_load(shard_meta, group)
|
||||
model = cast(Model, model)
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
# Generate a prompt of exact token length
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
@@ -117,21 +93,45 @@ def run_gpt_oss_pipeline_device(
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model = pipeline_auto_parallel(model, group, shard_meta)
|
||||
|
||||
# Barrier before generation
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
@@ -143,6 +143,7 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
model_path: Path,
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
@@ -155,38 +156,14 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
# For tensor parallelism, all devices run all layers
|
||||
shard_meta = TensorShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=0,
|
||||
end_layer=24,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model, tokenizer = shard_and_load(shard_meta, group)
|
||||
model = cast(Model, model)
|
||||
model, tokenizer = load(str(model_path))
|
||||
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
@@ -198,21 +175,26 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
formatted_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt_text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
model = tensor_auto_parallel(model, group)
|
||||
|
||||
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
|
||||
mlx_core.eval(barrier)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
for response in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=formatted_prompt,
|
||||
max_tokens=max_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
@@ -5,10 +5,9 @@ import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxModule,
|
||||
CustomMlxLayer,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
PipelineParallelModel,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
|
||||
@@ -37,18 +36,6 @@ def run_pipeline_device(
|
||||
) -> mlx_core.array:
|
||||
return x * 2
|
||||
|
||||
class MockModel(mlx_nn.Module):
|
||||
def __init__(self, layers: list[mlx_nn.Module]) -> None:
|
||||
super().__init__()
|
||||
self.layers = layers
|
||||
|
||||
def __call__(
|
||||
self, x: mlx_core.array, *args: object, **kwargs: object
|
||||
) -> mlx_core.array:
|
||||
for layer in self.layers:
|
||||
x = layer(x, *args, **kwargs) # pyright: ignore[reportUnknownVariableType]
|
||||
return x # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
@@ -56,12 +43,8 @@ def run_pipeline_device(
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
|
||||
inner_model = MockModel([composed])
|
||||
model = PipelineParallelModel(inner_model, group)
|
||||
|
||||
x = mlx_core.ones((1, 4))
|
||||
result = model(x)
|
||||
result = composed(x)
|
||||
mlx_core.eval(result)
|
||||
|
||||
success = result.shape == x.shape
|
||||
@@ -72,7 +55,7 @@ def run_pipeline_device(
|
||||
|
||||
def test_single_wrapper_delegates_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxModule(mock)
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert wrapped.use_sliding is True # type: ignore[attr-defined]
|
||||
@@ -91,7 +74,7 @@ def test_composed_wrappers_delegate_attributes() -> None:
|
||||
|
||||
def test_missing_attribute_raises() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxModule(mock)
|
||||
wrapped = CustomMlxLayer(mock)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
||||
DEFAULT_GPT_OSS_CONFIG,
|
||||
create_hostfile,
|
||||
run_gpt_oss_pipeline_device,
|
||||
run_gpt_oss_tensor_parallel_device,
|
||||
)
|
||||
|
||||
|
||||
def _check_model_exists() -> bool:
|
||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
||||
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedTestResult:
|
||||
timed_out: bool
|
||||
world_size: int
|
||||
results: dict[int, tuple[bool, str]]
|
||||
|
||||
@property
|
||||
def all_success(self) -> bool:
|
||||
if len(self.results) != self.world_size:
|
||||
return False
|
||||
return all(r[0] for r in self.results.values())
|
||||
|
||||
|
||||
def run_distributed_test(
|
||||
world_size: int,
|
||||
port_offset: int,
|
||||
process_timeout: int,
|
||||
target: Callable[..., None],
|
||||
make_args: Callable[[int], tuple[Any, ...]],
|
||||
) -> DistributedTestResult:
|
||||
ctx = mp.get_context("spawn")
|
||||
hostfile_path, _ = create_hostfile(
|
||||
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
|
||||
)
|
||||
|
||||
try:
|
||||
result_queue: Any = ctx.Queue()
|
||||
processes: list[Any] = []
|
||||
|
||||
for rank in range(world_size):
|
||||
args = make_args(rank)
|
||||
p = ctx.Process(
|
||||
target=target,
|
||||
args=(rank, world_size, hostfile_path, *args, result_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
|
||||
|
||||
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
if p.is_alive(): # pyright: ignore[reportAny]
|
||||
p.terminate() # pyright: ignore[reportAny]
|
||||
p.join(timeout=5) # pyright: ignore[reportAny]
|
||||
|
||||
results: dict[int, tuple[bool, str]] = {}
|
||||
while not result_queue.empty(): # pyright: ignore[reportAny]
|
||||
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
|
||||
results[rank] = (success, value)
|
||||
|
||||
return DistributedTestResult(
|
||||
timed_out=timed_out, world_size=world_size, results=results
|
||||
)
|
||||
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
|
||||
|
||||
def run_pipeline_test(
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
port_offset: int = 0,
|
||||
process_timeout: int = 60,
|
||||
) -> DistributedTestResult:
|
||||
def make_args(rank: int) -> tuple[Any, ...]:
|
||||
return (
|
||||
layer_splits,
|
||||
prompt_tokens,
|
||||
prefill_step_size,
|
||||
)
|
||||
|
||||
return run_distributed_test(
|
||||
world_size=len(layer_splits),
|
||||
port_offset=port_offset,
|
||||
process_timeout=process_timeout,
|
||||
target=run_gpt_oss_pipeline_device,
|
||||
make_args=make_args,
|
||||
)
|
||||
|
||||
|
||||
def run_tensor_test(
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
port_offset: int = 0,
|
||||
process_timeout: int = 60,
|
||||
) -> DistributedTestResult:
|
||||
def make_args(rank: int) -> tuple[Any, ...]:
|
||||
return (
|
||||
prompt_tokens,
|
||||
prefill_step_size,
|
||||
)
|
||||
|
||||
return run_distributed_test(
|
||||
world_size=2,
|
||||
port_offset=port_offset,
|
||||
process_timeout=process_timeout,
|
||||
target=run_gpt_oss_tensor_parallel_device,
|
||||
make_args=make_args,
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineParallelFix:
|
||||
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
|
||||
|
||||
def test_pipeline_single_layer_first_device(self) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=self.BUG_TRIGGER_SPLITS,
|
||||
prompt_tokens=100,
|
||||
prefill_step_size=64,
|
||||
process_timeout=60,
|
||||
)
|
||||
assert not result.timed_out, "Unexpected timeout - fix may not be working"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestPipelineSplitConfigurations:
|
||||
@pytest.mark.parametrize(
|
||||
"layer_splits",
|
||||
[
|
||||
[(0, 1), (1, 24)],
|
||||
[(0, 6), (6, 24)],
|
||||
[(0, 12), (12, 24)],
|
||||
],
|
||||
ids=["1_23", "6_18", "12_12"],
|
||||
)
|
||||
def test_pipeline_splits(
|
||||
self,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=layer_splits,
|
||||
prompt_tokens=600,
|
||||
prefill_step_size=512,
|
||||
port_offset=100,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout with {layer_splits}"
|
||||
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
|
||||
|
||||
|
||||
class TestPrefillStepSizeBoundaries:
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_step_size,prompt_tokens",
|
||||
[
|
||||
(512, 511),
|
||||
(512, 512),
|
||||
(512, 513),
|
||||
(512, 1024),
|
||||
],
|
||||
ids=["under", "exact", "over", "double"],
|
||||
)
|
||||
def test_boundary_conditions(
|
||||
self,
|
||||
prefill_step_size: int,
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=[(0, 12), (12, 24)],
|
||||
prompt_tokens=prompt_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
port_offset=200,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestTensorParallelFix:
|
||||
def test_tensor_parallel(self) -> None:
|
||||
result = run_tensor_test(
|
||||
prompt_tokens=100,
|
||||
prefill_step_size=64,
|
||||
port_offset=400,
|
||||
)
|
||||
assert not result.timed_out, "Unexpected timeout"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestTensorParallelBoundaries:
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_step_size,prompt_tokens",
|
||||
[
|
||||
(512, 511),
|
||||
(512, 512),
|
||||
(512, 513),
|
||||
(512, 1024),
|
||||
],
|
||||
ids=["under", "exact", "over", "double"],
|
||||
)
|
||||
def test_tensor_parallel_boundaries(
|
||||
self,
|
||||
prefill_step_size: int,
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
result = run_tensor_test(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
port_offset=500,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
@@ -1,6 +1,15 @@
|
||||
from .profile import start_polling_memory_metrics, start_polling_node_metrics
|
||||
from .profile import (
|
||||
IdentityMetrics,
|
||||
start_polling_identity_metrics,
|
||||
start_polling_memory_metrics,
|
||||
start_polling_network_metrics,
|
||||
start_polling_system_metrics,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"start_polling_node_metrics",
|
||||
"IdentityMetrics",
|
||||
"start_polling_identity_metrics",
|
||||
"start_polling_memory_metrics",
|
||||
"start_polling_network_metrics",
|
||||
"start_polling_system_metrics",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import anyio
|
||||
@@ -9,7 +10,7 @@ from loguru import logger
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
NetworkInterfaceInfo,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
|
||||
@@ -27,6 +28,13 @@ from .system_info import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IdentityMetrics:
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
|
||||
|
||||
async def get_metrics_async() -> Metrics | None:
|
||||
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
|
||||
|
||||
@@ -67,48 +75,73 @@ async def start_polling_memory_metrics(
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_node_metrics(
|
||||
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
|
||||
):
|
||||
poll_interval_s = 1.0
|
||||
async def start_polling_identity_metrics(
|
||||
callback: Callable[[IdentityMetrics], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 30.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit identity metrics at 30s intervals."""
|
||||
while True:
|
||||
try:
|
||||
model_id, chip_id = await get_model_and_chip()
|
||||
friendly_name = await get_friendly_name()
|
||||
await callback(
|
||||
IdentityMetrics(
|
||||
model_id=model_id,
|
||||
chip_id=chip_id,
|
||||
friendly_name=friendly_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Failed to emit identity metrics")
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_system_metrics(
|
||||
callback: Callable[[SystemPerformanceProfile], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 1.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit system metrics (GPU, temp, power) at 1s intervals."""
|
||||
while True:
|
||||
try:
|
||||
metrics = await get_metrics_async()
|
||||
if metrics is None:
|
||||
return
|
||||
|
||||
network_interfaces = get_network_interfaces()
|
||||
# these awaits could be joined but realistically they should be cached
|
||||
model_id, chip_id = await get_model_and_chip()
|
||||
friendly_name = await get_friendly_name()
|
||||
|
||||
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
|
||||
memory_profile = get_memory_profile()
|
||||
|
||||
await callback(
|
||||
NodePerformanceProfile(
|
||||
model_id=model_id,
|
||||
chip_id=chip_id,
|
||||
friendly_name=friendly_name,
|
||||
network_interfaces=network_interfaces,
|
||||
memory=memory_profile,
|
||||
system=SystemPerformanceProfile(
|
||||
gpu_usage=metrics.gpu_usage[1],
|
||||
temp=metrics.temp.gpu_temp_avg,
|
||||
sys_power=metrics.sys_power,
|
||||
pcpu_usage=metrics.pcpu_usage[1],
|
||||
ecpu_usage=metrics.ecpu_usage[1],
|
||||
ane_power=metrics.ane_power,
|
||||
),
|
||||
SystemPerformanceProfile(
|
||||
gpu_usage=metrics.gpu_usage[1],
|
||||
temp=metrics.temp.gpu_temp_avg,
|
||||
sys_power=metrics.sys_power,
|
||||
pcpu_usage=metrics.pcpu_usage[1],
|
||||
ecpu_usage=metrics.ecpu_usage[1],
|
||||
ane_power=metrics.ane_power,
|
||||
)
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
|
||||
"[system_monitor] Operation timed out after 30s, skipping this cycle."
|
||||
)
|
||||
except MacMonError as e:
|
||||
logger.opt(exception=e).error("Resource Monitor encountered error")
|
||||
logger.opt(exception=e).error("System Monitor encountered error")
|
||||
return
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
|
||||
async def start_polling_network_metrics(
|
||||
callback: Callable[[list[NetworkInterfaceInfo]], Coroutine[Any, Any, None]],
|
||||
*,
|
||||
poll_interval_s: float = 30.0,
|
||||
) -> None:
|
||||
"""Continuously poll and emit network interface info at 30s intervals."""
|
||||
while True:
|
||||
try:
|
||||
network_interfaces = get_network_interfaces()
|
||||
await callback(network_interfaces)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Network Monitor encountered error")
|
||||
finally:
|
||||
await anyio.sleep(poll_interval_s)
|
||||
|
||||
Reference in New Issue
Block a user