Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-14 17:07:56 -05:00)

Compare commits: 1 commit on branch `gather-lin`

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | 7e3030221e | |
```diff
@@ -276,24 +276,23 @@ class BatchGenerator:
     logprobs: mx.array
     finish_reason: Optional[str]

     unprocessed_prompts: List[Any]

     def __init__(
         self,
-        model,
+        model: nn.Module,
         max_tokens: int = ...,
-        stop_tokens: Optional[set] = ...,
+        stop_tokens: Optional[set[int]] = ...,
         sampler: Optional[Callable[[mx.array], mx.array]] = ...,
         completion_batch_size: int = ...,
         prefill_batch_size: int = ...,
         prefill_step_size: int = ...,
     ) -> None: ...
     def insert(
-        self, prompts, max_tokens: Union[List[int], int, None] = ...
-    ): # -> list[Any]:
-        ...
-    def stats(self): # -> BatchStats:
-        ...
-    def next(self): # -> list[Any]:
-        ...
+        self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
+    ) -> List[int]: ...
+    def stats(self) -> BatchStats: ...
+    def next(self) -> List[Response]: ...

 def batch_generate(
     model,
```
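For orientation, a minimal usage sketch of the typed stub above. The names (`BatchGenerator`, `insert`, `stats`, `next`, `Response`, `BatchStats`) come from the stub itself; how the model, tokenizer, and import are obtained is assumed for illustration and may not match the real mlx-lm API.

```python
# Sketch only: assumes `model` (nn.Module) and `tokenizer` were loaded elsewhere
# and that BatchGenerator is imported from the stubbed module.
gen = BatchGenerator(
    model,                                 # now typed as nn.Module
    max_tokens=64,
    stop_tokens={tokenizer.eos_token_id},  # Optional[set[int]] per the new stub
)

# insert() takes pre-tokenized prompts (List[List[int]]) and returns one id per prompt.
prompt_ids = gen.insert([[1, 2, 3], [4, 5, 6, 7]])

# next() drains completed steps as Response objects until the batch is exhausted.
while batch := gen.next():
    for response in batch:
        print(response.finish_reason)

print(gen.stats())  # BatchStats
```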
```diff
@@ -39,12 +39,18 @@ class StreamingDetokenizer:
     """

     __slots__ = ...
-    def reset(self): ...
-    def add_token(self, token): ...
-    def finalize(self): ...
+    tokens: list[int]
+    def reset(self) -> None: ...
+    def add_token(self, token: int) -> None: ...
+    def finalize(self) -> None: ...
     @property
-    def last_segment(self):
+    def text(self) -> str:
+        """The full text decoded so far."""
+        ...
+    @property
+    def last_segment(self) -> str:
         """Return the last segment of readable text since last time this property was accessed."""
         ...

 class NaiveStreamingDetokenizer(StreamingDetokenizer):
     """NaiveStreamingDetokenizer relies on the underlying tokenizer

@@ -108,6 +114,7 @@ class TokenizerWrapper:
     _tokenizer: PreTrainedTokenizerFast
     eos_token_id: int | None
     eos_token: str | None
+    eos_token_ids: list[int] | None
     bos_token_id: int | None
     bos_token: str | None
     vocab_size: int
```
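Per its docstring, `last_segment` is stateful: it returns only the text that became safely decodable since the previous access, which is what makes it suitable for streaming output. A consumption sketch, with the detokenizer construction and token source assumed:

```python
# Sketch only: `detokenizer` is assumed to come from the tokenizer wrapper,
# and `token_stream` is any iterable of generated token ids.
detokenizer.reset()
for token in token_stream:
    detokenizer.add_token(token)
    # Emits only newly readable text; safe across tokens that span characters.
    print(detokenizer.last_segment, end="", flush=True)

detokenizer.finalize()
print(detokenizer.last_segment)  # flush whatever remained buffered
full_output = detokenizer.text   # the complete decoded string
```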
AGENTS.md (+39)

````diff
@@ -91,6 +91,45 @@ From .cursorrules:
 - Catch exceptions only where you can handle them meaningfully
 - Use `@final` and immutability wherever applicable

+## Model Storage
+
+Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
+
+## Creating Model Instances via API
+
+When testing with the API, you must first create a model instance before sending chat completions:
+
+```bash
+# 1. Get instance previews for a model
+curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
+
+# 2. Create an instance from the first valid preview
+INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
+curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
+
+# 3. Wait for the runner to become ready (check logs for "runner ready")
+
+# 4. Send chat completions using the full model ID
+curl -X POST http://localhost:52415/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
+```
+
+## Logs
+
+Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
+
+## Testing
+
+Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
+
+### Distributed Testing
+
+When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
+
+```bash
+# On each machine in the test cluster, use the same unique namespace
+EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
+```
+
+This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.
````
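The same instance-creation flow as the curl example, as a stdlib-only Python sketch. The endpoints and payload shapes are taken from the documentation above; the helper names are illustrative and the response is assumed to follow the usual OpenAI-style `choices[0].message.content` shape.

```python
# Python equivalent of the curl flow above, using only the standard library.
import json
import urllib.request

BASE = "http://localhost:52415"

def get_json(url: str) -> dict:
    with urllib.request.urlopen(url) as resp:
        return json.load(resp)

def post_json(url: str, payload: dict) -> dict:
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# 1. Pick the first valid instance preview for the model
previews = get_json(f"{BASE}/instance/previews?model_id=llama-3.2-1b")["previews"]
instance = next(p["instance"] for p in previews if p.get("error") is None)

# 2. Create the instance, then wait for "runner ready" in ~/.exo/exo.log
post_json(f"{BASE}/instance", {"instance": instance})

# 3. Send a chat completion using the full model ID
completion = post_json(f"{BASE}/v1/chat/completions", {
    "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 50,
})
print(completion["choices"][0]["message"]["content"])
```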
TODO.md (1 change)

```diff
@@ -19,7 +19,6 @@
 25. Rethink retry logic
 26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
 27. Log cleanup - per-module log filters and default to DEBUG log levels
-28. Validate RDMA connections with ibv_devinfo in the info gatherer

 Potential refactors:
```
```diff
@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
     private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
     private static let daemonLabel = "io.exo.networksetup"
     private static let scriptDestination =
-        "/Library/Application Support/EXO/disable_bridge.sh"
+        "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
     private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
     private static let requiredStartInterval: Int = 1791

@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
     # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
     /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

+    networksetup -listlocations | grep -q exo || {
+        networksetup -createlocation exo
+    }
+
+    networksetup -switchtolocation exo
+    networksetup -listallhardwareports \\
+        | awk -F': ' '/Hardware Port: / {print $2}' \\
+        | while IFS=":" read -r name; do
+        case "$name" in
+            "Ethernet Adapter"*)
+                ;;
+            "Thunderbolt Bridge")
+                ;;
+            "Thunderbolt "*)
+                networksetup -listallnetworkservices \\
+                    | grep -q "EXO $name" \\
+                    || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
+                    || continue
+                networksetup -setdhcp "EXO $name"
+                ;;
+            *)
+                networksetup -listallnetworkservices \\
+                    | grep -q "$name" \\
+                    || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
+                    || continue
+                ;;
+        esac
+    done
+
     networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
         networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
     } || true
```
```diff
@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
 // Uses API preview data when available, falls back to local estimation
 const placementPreview = $derived(() => {
     const nodeArray = nodeList();
-    if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
+    if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };

     const numNodes = nodeArray.length;
     const iconSize = numNodes === 1 ? 50 : 36;
```
```diff
@@ -1,7 +1,7 @@
 <script lang="ts">
 import { onMount, onDestroy } from 'svelte';
 import * as d3 from 'd3';
-import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
+import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';

 interface Props {
     class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {

 function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
     if (!ip) return { label: '?', missing: true };

     // Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
     const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

     // Helper to check a node's interfaces
-    function checkNode(node: NodeInfo | undefined): string | null {
+    function checkNode(node: typeof data.nodes[string]): string | null {
         if (!node) return null;

         const matchFromInterfaces = node.network_interfaces?.find((iface) =>
             (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
         );
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
             return matchFromInterfaces.name;
         }

-        if (node.ip_to_interface) {
-            const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
-            if (mapped && mapped.trim().length > 0) {
-                return mapped;
-            }
-        }
+        const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
+        if (mapped && mapped.trim().length > 0) {
+            return mapped;
+        }
         return null;
     }

     // Try specified node first
     const result = checkNode(data?.nodes?.[nodeId]);
     if (result) return { label: result, missing: false };

     // Fallback: search all nodes for this IP
     for (const [, otherNode] of Object.entries(data?.nodes || {})) {
         const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
     const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
     const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');

-    type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
-    type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
-    type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
-    const pairMap = new Map<string, PairEntry>();
-    const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
+    const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
+    let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
     edges.forEach(edge => {
         if (!edge.source || !edge.target || edge.source === edge.target) return;
         if (!positionById[edge.source] || !positionById[edge.target]) return;

         const a = edge.source < edge.target ? edge.source : edge.target;
         const b = edge.source < edge.target ? edge.target : edge.source;
         const key = `${a}|${b}`;
         const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };

         if (edge.source === a) entry.aToB = true;
         else entry.bToA = true;

-        const ip = edge.sendBackIp || '?';
+        const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
         const ifaceInfo = getInterfaceLabel(edge.source, ip);
         entry.connections.push({
             from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
         // Determine which side of viewport based on edge midpoint
         const isLeft = mx < centerX;
         const isTop = my < safeCenterY;

         // Store for batch rendering after all edges processed
+        if (!debugEdgeLabels) debugEdgeLabels = [];
         debugEdgeLabels.push({
             connections: entry.connections,
             isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
     }

     // Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
-    const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
+    const quadrants: Record<string, typeof debugEdgeLabels> = {
         topLeft: [],
         topRight: [],
         bottomLeft: [],
         bottomRight: []
     };

     debugEdgeLabels.forEach(edge => {
         const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
         quadrants[key].push(edge);
     });

     // Render each quadrant
-    Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
-        if (quadrantEdges.length === 0) return;
+    Object.entries(quadrants).forEach(([quadrant, edges]) => {
+        if (edges.length === 0) return;

         const isLeft = quadrant.includes('Left');
         const isTop = quadrant.includes('top');

         let baseX = isLeft ? padding : width - padding;
         let baseY = isTop ? padding : height - padding;
         const textAnchor = isLeft ? 'start' : 'end';

         let currentY = baseY;

-        quadrantEdges.forEach(edge => {
+        edges.forEach(edge => {
             edge.connections.forEach(conn => {
                 const arrow = getArrow(conn.from, conn.to);
                 const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
```
```diff
@@ -99,7 +99,7 @@ interface RawNodeProfile {

 interface RawTopologyNode {
     nodeId: string;
-    nodeProfile?: RawNodeProfile;
+    nodeProfile: RawNodeProfile;
 }

 interface RawTopologyConnection {
@@ -110,19 +110,9 @@ interface RawTopologyConnection {
         | string;
 }

-// Connection can be an object or a tuple [source, target, metadata]
-type RawConnectionItem =
-    | RawTopologyConnection
-    | [
-        string,
-        string,
-        { sinkMultiaddr?: { ip_address?: string; address?: string } }?,
-    ];
-
 interface RawTopology {
-    // nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
-    nodes: (string | RawTopologyNode)[];
-    connections?: RawConnectionItem[];
+    nodes: RawTopologyNode[];
+    connections?: RawTopologyConnection[];
 }

 type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -223,18 +213,9 @@ function transformTopology(
     const nodes: Record<string, NodeInfo> = {};
     const edges: TopologyEdge[] = [];

-    // Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
     for (const node of raw.nodes || []) {
-        // Determine the node ID - could be a string or an object with nodeId property
-        const nodeId = typeof node === "string" ? node : node.nodeId;
-        if (!nodeId) continue;
-
-        // Get the profile - from the separate profiles map or from the node object itself
-        const profileFromMap = profiles?.[nodeId];
-        const profileFromNode =
-            typeof node === "object" ? node.nodeProfile : undefined;
-        const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
-
+        const mergedProfile = profiles?.[node.nodeId];
+        const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
         const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
         const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
         const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -283,7 +264,7 @@ function transformTopology(
         }
     }

-        nodes[nodeId] = {
+        nodes[node.nodeId] = {
             system_info: {
                 model_id: profile?.modelId ?? "Unknown",
                 chip: profile?.chipId,
@@ -311,39 +292,14 @@ function transformTopology(
         };
     }

-    // Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
     for (const conn of raw.connections || []) {
-        let localNodeId: string | undefined;
-        let sendBackNodeId: string | undefined;
-        let sendBackMultiaddr:
-            | { multiaddr?: string; address?: string; ip_address?: string }
-            | string
-            | undefined;
-
-        // Check if it's a tuple format [source, target, metadata]
-        if (Array.isArray(conn)) {
-            localNodeId = conn[0] as string;
-            sendBackNodeId = conn[1] as string;
-            const metadata = conn[2] as
-                | { sinkMultiaddr?: { ip_address?: string; address?: string } }
-                | undefined;
-            if (metadata?.sinkMultiaddr) {
-                sendBackMultiaddr = metadata.sinkMultiaddr;
-            }
-        } else {
-            // Object format with localNodeId/sendBackNodeId
-            localNodeId = conn.localNodeId;
-            sendBackNodeId = conn.sendBackNodeId;
-            sendBackMultiaddr = conn.sendBackMultiaddr;
-        }
-
-        if (!localNodeId || !sendBackNodeId) continue;
-        if (localNodeId === sendBackNodeId) continue;
-        if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
+        if (!conn.localNodeId || !conn.sendBackNodeId) continue;
+        if (conn.localNodeId === conn.sendBackNodeId) continue;
+        if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;

         let sendBackIp: string | undefined;
-        if (sendBackMultiaddr) {
-            const multi = sendBackMultiaddr;
+        if (conn.sendBackMultiaddr) {
+            const multi = conn.sendBackMultiaddr;
             if (typeof multi === "string") {
                 sendBackIp = extractIpFromMultiaddr(multi);
             } else {
@@ -355,8 +311,8 @@ function transformTopology(
         }

         edges.push({
-            source: localNodeId,
-            target: sendBackNodeId,
+            source: conn.localNodeId,
+            target: conn.sendBackNodeId,
             sendBackIp,
         });
     }

@@ -895,7 +895,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
     const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
         const [tag, shard] = getTagged(shardWrapped);
         const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
-        const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
+        const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
         return { runnerId, tag, deviceRank };
     });
```
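The transform above pulls an IP out of a multiaddr string such as `/ip4/169.254.0.2/tcp/1235` (the shape used throughout this diff). A Python analogue of that parsing, for illustration only; the real helper is `extractIpFromMultiaddr` in the TypeScript above, and this function name is hypothetical:

```python
# Illustrative Python analogue of extractIpFromMultiaddr; handles the
# /ip4/<addr>/tcp/<port> (or /ip6/...) layout used elsewhere in this diff.
def extract_ip_from_multiaddr(multiaddr: str) -> str | None:
    parts = multiaddr.strip("/").split("/")
    # Expected layout: ["ip4", "169.254.0.2", "tcp", "1235"]
    if len(parts) >= 2 and parts[0] in ("ip4", "ip6"):
        return parts[1]
    return None

assert extract_ip_from_multiaddr("/ip4/169.254.0.2/tcp/1235") == "169.254.0.2"
```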
flake.lock (37 changes, generated)

```diff
@@ -8,11 +8,11 @@
       "rust-analyzer-src": "rust-analyzer-src"
     },
     "locked": {
-      "lastModified": 1768287139,
-      "narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
+      "lastModified": 1761893049,
+      "narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
       "owner": "nix-community",
       "repo": "fenix",
-      "rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
+      "rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
       "type": "github"
     },
     "original": {
@@ -42,22 +42,6 @@
       }
     },
-    "nixpkgs": {
-      "locked": {
-        "lastModified": 1768127708,
-        "narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
-        "owner": "NixOS",
-        "repo": "nixpkgs",
-        "rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
-        "type": "github"
-      },
-      "original": {
-        "owner": "NixOS",
-        "ref": "nixos-unstable",
-        "repo": "nixpkgs",
-        "type": "github"
-      }
-    },
     "nixpkgs-swift": {
       "locked": {
         "lastModified": 1761672384,
         "narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -68,8 +52,8 @@
       },
       "original": {
         "owner": "NixOS",
-        "ref": "nixos-unstable",
         "repo": "nixpkgs",
+        "rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
         "type": "github"
       }
     },
@@ -78,18 +62,17 @@
         "fenix": "fenix",
         "flake-parts": "flake-parts",
-        "nixpkgs": "nixpkgs",
         "nixpkgs-swift": "nixpkgs-swift",
         "treefmt-nix": "treefmt-nix"
       }
     },
     "rust-analyzer-src": {
       "flake": false,
       "locked": {
-        "lastModified": 1768224240,
-        "narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
+        "lastModified": 1761849405,
+        "narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
         "owner": "rust-lang",
         "repo": "rust-analyzer",
-        "rev": "725349602e525df37f377701e001fe8aab807878",
+        "rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
         "type": "github"
       },
       "original": {
@@ -106,11 +89,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1768158989,
-      "narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
+      "lastModified": 1762938485,
+      "narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
       "owner": "numtide",
       "repo": "treefmt-nix",
-      "rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
+      "rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
       "type": "github"
     },
     "original": {
```
flake.nix (15 changes)

```diff
@@ -18,9 +18,6 @@
       url = "github:numtide/treefmt-nix";
       inputs.nixpkgs.follows = "nixpkgs";
     };
-
-    # Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
-    nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
   };

   nixConfig = {
@@ -42,11 +39,9 @@
       ];

       perSystem =
-        { config, inputs', pkgs, lib, system, ... }:
+        { config, inputs', pkgs, lib, ... }:
         let
           fenixToolchain = inputs'.fenix.packages.complete;
-          # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
-          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
         in
         {
           treefmt = {
@@ -65,10 +60,7 @@
             enable = true;
             includes = [ "*.ts" ];
           };
-          swift-format = {
-            enable = true;
-            package = pkgsSwift.swiftPackages.swift-format;
-          };
+          swift-format.enable = true;
         };
       };
@@ -131,6 +123,9 @@
           export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
           export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
         ''}
+        echo
+        echo "🍎🍎 Run 'just <recipe>' to get started"
+        just --list
       '';
     };
   };
```
```diff
@@ -236,7 +236,6 @@ class API:
                 instance_meta=instance_meta,
                 min_nodes=min_nodes,
             ),
-            node_profiles=self.state.node_profiles,
             topology=self.state.topology,
             current_instances=self.state.instances,
         )
@@ -292,7 +291,6 @@ class API:
                 instance_meta=instance_meta,
                 min_nodes=min_nodes,
             ),
-            node_profiles=self.state.node_profiles,
             topology=self.state.topology,
             current_instances=self.state.instances,
         )
@@ -601,8 +599,9 @@ class API:
         """Calculate total available memory across all nodes in bytes."""
         total_available = Memory()

-        for profile in self.state.node_profiles.values():
-            total_available += profile.memory.ram_available
+        for node in self.state.topology.list_nodes():
+            if node.node_profile is not None:
+                total_available += node.node_profile.memory.ram_available

         return total_available
```
```diff
@@ -158,7 +158,6 @@ class Master:
             command,
             self.state.topology,
             self.state.instances,
-            self.state.node_profiles,
         )
         transition_events = get_transition_events(
             self.state.instances, placement
@@ -201,7 +200,9 @@ class Master:
     async def _plan(self) -> None:
         while True:
             # kill broken instances
-            connected_node_ids = set([x for x in self.state.topology.list_nodes()])
+            connected_node_ids = set(
+                [x.node_id for x in self.state.topology.list_nodes()]
+            )
             for instance_id, instance in self.state.instances.items():
                 for node_id in instance.shard_assignments.node_to_runner:
                     if node_id not in connected_node_ids:
```
```diff
@@ -6,10 +6,9 @@ from typing import Sequence
 from loguru import logger

 from exo.master.placement_utils import (
-    NodeWithProfile,
     filter_cycles_by_memory,
+    get_mlx_ibv_devices_matrix,
     get_mlx_jaccl_coordinators,
-    get_mlx_jaccl_devices_matrix,
     get_mlx_ring_hosts_by_node,
     get_shard_assignments,
     get_smallest_cycles,
@@ -20,11 +19,10 @@ from exo.shared.types.commands import (
     DeleteInstance,
     PlaceInstance,
 )
 from exo.shared.types.common import NodeId
 from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId
-from exo.shared.types.profiling import NodePerformanceProfile
+from exo.shared.types.topology import NodeInfo
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -54,16 +52,19 @@ def place_instance(
     command: PlaceInstance,
     topology: Topology,
     current_instances: Mapping[InstanceId, Instance],
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> dict[InstanceId, Instance]:
     all_nodes = list(topology.list_nodes())

-    cycles = topology.get_cycles() + [[node] for node in all_nodes]
-    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
-    cycles_with_sufficient_memory = filter_cycles_by_memory(
-        candidate_cycles, node_profiles, command.model_meta.storage_size
-    )
-    if len(cycles_with_sufficient_memory) == 0:
+    logger.info("finding cycles:")
+    cycles = topology.get_cycles()
+    singleton_cycles = [[node] for node in all_nodes]
+    candidate_cycles = list(
+        filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
+    )
+    cycles_with_sufficient_memory = filter_cycles_by_memory(
+        candidate_cycles, command.model_meta.storage_size
+    )
+    if not cycles_with_sufficient_memory:
         raise ValueError("No cycles found with sufficient memory")

     if command.sharding == Sharding.Tensor:
@@ -93,15 +94,13 @@ def place_instance(
         smallest_tb_cycles = [
             cycle
             for cycle in smallest_cycles
-            if topology.get_subgraph_from_nodes(
-                [node.node_id for node in cycle]
-            ).is_thunderbolt_cycle([node.node_id for node in cycle])
+            if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
         ]

         if smallest_tb_cycles != []:
             smallest_cycles = smallest_tb_cycles

-    cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
+    cycles_with_leaf_nodes: list[list[NodeInfo]] = [
         cycle
         for cycle in smallest_cycles
         if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -110,7 +109,11 @@ def place_instance(
     selected_cycle = max(
         cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
         key=lambda cycle: sum(
-            (node.node_profile.memory.ram_available for node in cycle),
+            (
+                node.node_profile.memory.ram_available
+                for node in cycle
+                if node.node_profile is not None
+            ),
             start=Memory(),
         ),
     )
@@ -119,16 +122,14 @@ def place_instance(
         command.model_meta, selected_cycle, command.sharding
     )

-    cycle_digraph: Topology = topology.get_subgraph_from_nodes(
-        [node.node_id for node in selected_cycle]
-    )
+    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)

     instance_id = InstanceId()
     target_instances = dict(deepcopy(current_instances))

     if len(selected_cycle) == 1:
         logger.warning(
-            "You have likely selected jaccl for a single node instance; falling back to MlxRing"
+            "You have likely selected ibv for a single node instance; falling back to MlxRing"
         )

         command.instance_meta = InstanceMeta.MlxRing
@@ -136,19 +137,19 @@ def place_instance(
     # TODO: Single node instances
     match command.instance_meta:
         case InstanceMeta.MlxJaccl:
-            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
-                [node.node_id for node in selected_cycle],
+            mlx_ibv_devices = get_mlx_ibv_devices_matrix(
+                selected_cycle,
                 cycle_digraph,
             )
             mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
-                coordinator=selected_cycle[0].node_id,
+                selected_cycle,
                 coordinator_port=random_ephemeral_port(),
                 cycle_digraph=cycle_digraph,
             )
             target_instances[instance_id] = MlxJacclInstance(
                 instance_id=instance_id,
                 shard_assignments=shard_assignments,
-                jaccl_devices=mlx_jaccl_devices,
+                ibv_devices=mlx_ibv_devices,
                 jaccl_coordinators=mlx_jaccl_coordinators,
             )
         case InstanceMeta.MlxRing:
```
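Read end to end, the selection logic in `place_instance` is: gather cycles plus single-node "cycles", drop those below `min_nodes`, keep those whose summed available RAM covers the model, take the smallest, prefer Thunderbolt-only cycles, prefer cycles containing leaf nodes, then pick the candidate with the most free memory. A condensed restatement under those assumptions; types are simplified and the helpers are the ones named in the diff:

```python
# Condensed sketch of the selection pipeline in place_instance above.
def select_cycle(topology, all_nodes, command):
    # Candidates: real cycles plus each node on its own, filtered by size.
    candidates = [
        c for c in topology.get_cycles() + [[n] for n in all_nodes]
        if len(c) >= command.min_nodes
    ]
    fitting = filter_cycles_by_memory(candidates, command.model_meta.storage_size)
    if not fitting:
        raise ValueError("No cycles found with sufficient memory")

    smallest = get_smallest_cycles(fitting)

    # Prefer cycles whose links are all Thunderbolt.
    tb_only = [
        c for c in smallest
        if topology.get_subgraph_from_nodes(c).is_thunderbolt_cycle(c)
    ]
    if tb_only:
        smallest = tb_only

    # Prefer cycles that include a leaf node, then break ties on total free RAM.
    with_leaf = [
        c for c in smallest
        if any(topology.node_is_leaf(n.node_id) for n in c)
    ]
    pool = with_leaf or smallest
    return max(
        pool,
        key=lambda c: sum(
            (
                n.node_profile.memory.ram_available
                for n in c
                if n.node_profile is not None
            ),
            start=Memory(),
        ),
    )
```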
```diff
@@ -1,4 +1,5 @@
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
+from typing import TypeGuard, cast

 from loguru import logger
 from pydantic import BaseModel
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelMetadata
 from exo.shared.types.profiling import NodePerformanceProfile
-from exo.shared.types.topology import RDMAConnection, SocketConnection
+from exo.shared.types.topology import NodeInfo
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments
 from exo.shared.types.worker.shards import (
     PipelineShardMetadata,
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
     node_profile: NodePerformanceProfile


+def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
+    return all(node.node_profile is not None for node in nodes)
+
+
 def filter_cycles_by_memory(
-    cycles: list[list[NodeId]],
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
-    required_memory: Memory,
-) -> list[list[NodeWithProfile]]:
-    filtered_cycles: list[list[NodeWithProfile]] = []
+    cycles: list[list[NodeInfo]], required_memory: Memory
+) -> list[list[NodeInfo]]:
+    filtered_cycles: list[list[NodeInfo]] = []
     for cycle in cycles:
-        if not all(node in node_profiles for node in cycle):
+        if not narrow_all_nodes(cycle):
             continue

         total_mem = sum(
-            (node_profiles[node].memory.ram_available for node in cycle), start=Memory()
+            (node.node_profile.memory.ram_available for node in cycle), start=Memory()
         )
         if total_mem >= required_memory:
-            filtered_cycles.append(
-                [
-                    NodeWithProfile(node_id=node, node_profile=node_profiles[node])
-                    for node in cycle
-                ]
-            )
+            filtered_cycles.append(cast(list[NodeInfo], cycle))
     return filtered_cycles


-def get_smallest_cycles(
-    cycles: list[list[NodeWithProfile]],
-) -> list[list[NodeWithProfile]]:
+def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
     min_nodes = min(len(cycle) for cycle in cycles)
     return [cycle for cycle in cycles if len(cycle) == min_nodes]

@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(

 def get_shard_assignments(
     model_meta: ModelMetadata,
-    selected_cycle: list[NodeWithProfile],
+    selected_cycle: list[NodeInfo],
     sharding: Sharding,
 ) -> ShardAssignments:
+    if not narrow_all_nodes(selected_cycle):
+        raise ValueError("All nodes must have profiles to create shard assignments")
     match sharding:
         case Sharding.Pipeline:
             return get_shard_assignments_for_pipeline_parallel(
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
         current_node = cycle[i]
         next_node = cycle[(i + 1) % len(cycle)]

-        for src, sink, connection in cycle_digraph.list_connections():
-            if not isinstance(connection, SocketConnection):
-                continue
-
-            if src == current_node and sink == next_node:
+        for connection in cycle_digraph.list_connections():
+            if (
+                connection.local_node_id == current_node.node_id
+                and connection.send_back_node_id == next_node.node_id
+            ):
                 if get_thunderbolt and not connection.is_thunderbolt():
                     continue
+                assert connection.send_back_multiaddr is not None
                 host = Host(
-                    ip=connection.sink_multiaddr.ip_address,
-                    port=connection.sink_multiaddr.port,
+                    ip=connection.send_back_multiaddr.ip_address,
+                    port=connection.send_back_multiaddr.port,
                 )
                 hosts.append(host)
                 break
@@ -195,8 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
     return hosts


-def get_mlx_jaccl_devices_matrix(
-    selected_cycle: list[NodeId],
+def get_mlx_ibv_devices_matrix(
+    selected_cycle: list[NodeInfo],
     cycle_digraph: Topology,
 ) -> list[list[str | None]]:
     """Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -215,38 +214,71 @@ def get_mlx_ibv_devices_matrix(
             if i == j:
                 continue

-            for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
-                if isinstance(conn, RDMAConnection):
-                    matrix[i][j] = conn.source_rdma_iface
+            # Find the IP J uses to talk to I
+            for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
+                # This is a local IP on I, which is attached to an interface: find that interface
+                if interface_name := _find_rdma_interface_name_for_ip(
+                    connection_ip, node_i
+                ):
+                    matrix[i][j] = interface_name
+                    logger.info(
+                        f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
+                    )
                     break
             else:
                 logger.warning(
-                    f"Failed to find interface name between {node_i} and {node_j}"
+                    f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
                 )
                 raise ValueError(
-                    "Current jaccl backend requires all-to-all RDMA connections"
+                    "Current ibv backend requires all-to-all rdma connections"
                 )

     return matrix


 def _find_connection_ip(
-    node_i: NodeId,
-    node_j: NodeId,
+    node_i: NodeInfo,
+    node_j: NodeInfo,
     cycle_digraph: Topology,
 ) -> Generator[tuple[str, bool]]:
-    """Find all IP addresses that connect node i to node j."""
-    # TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
-    for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
-        if isinstance(connection, SocketConnection):
-            yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
+    """Find all IP addresses that connect node i to node j, with thunderbolt flag."""
+    for connection in cycle_digraph.list_connections():
+        if (
+            connection.local_node_id == node_i.node_id
+            and connection.send_back_node_id == node_j.node_id
+        ):
+            yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()


+def _find_rdma_interface_name_for_ip(
+    ip_address: str,
+    node_info: NodeInfo,
+) -> str | None:
+    if node_info.node_profile is None:
+        return None
+
+    logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
+    for interface in node_info.node_profile.network_interfaces:
+        if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
+            continue
+        logger.info(f"  | {interface.name}: {interface.ip_address}")
+        if interface.ip_address != ip_address:
+            continue
+
+        logger.info("Found")
+        return f"rdma_{interface.name}"
+
+    return None
+
+
 def _find_interface_name_for_ip(
     ip_address: str,
-    node_info: NodeWithProfile,
+    node_info: NodeInfo,
 ) -> str | None:
     """Find the interface name for an IP address on a node (any interface)."""
+    if node_info.node_profile is None:
+        return None
+
     for interface in node_info.node_profile.network_interfaces:
         if interface.ip_address == ip_address:
             return interface.name
@@ -255,7 +287,7 @@ def _find_interface_name_for_ip(


 def _find_ip_prioritised(
-    node: NodeWithProfile, other_node: NodeWithProfile, cycle_digraph: Topology
+    node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
 ) -> str | None:
     # TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
     """Find an IP address between nodes with prioritization.
@@ -266,7 +298,7 @@ def _find_ip_prioritised(
     3. Non-Thunderbolt connections
     4. Any other IP address
     """
-    ips = list(_find_connection_ip(node.node_id, other_node.node_id, cycle_digraph))
+    ips = list(_find_connection_ip(node, other_node, cycle_digraph))
     # We expect a unique iface -> ip mapping
     iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}

@@ -292,7 +324,7 @@ def _find_ip_prioritised(


 def get_mlx_ring_hosts_by_node(
-    selected_cycle: list[NodeWithProfile],
+    selected_cycle: list[NodeInfo],
     cycle_digraph: Topology,
     ephemeral_port: int,
 ) -> dict[NodeId, list[Host]]:
@@ -329,7 +361,7 @@ def get_mlx_ring_hosts_by_node(
         connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
         if connection_ip is None:
             logger.warning(
-                f"Failed to find prioritised connection IP between {node_id} and {other_node}"
+                f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
             )
             raise ValueError(
                 "MLX ring backend requires connectivity between neighbouring nodes"
@@ -343,32 +375,31 @@ def get_mlx_ring_hosts_by_node(


 def get_mlx_jaccl_coordinators(
-    coordinator: NodeId,
+    selected_cycle: list[NodeInfo],
     coordinator_port: int,
     cycle_digraph: Topology,
 ) -> dict[NodeId, str]:
-    """Get the coordinator addresses for MLX JACCL (rank 0 device).
+    """Get the coordinator addresses for MLX Jaccl (rank 0 device).

     Select an IP address that each node can reach for the rank 0 node. Returns
     address in format "X.X.X.X:PORT" per node.
     """
-    logger.info(f"Selecting coordinator: {coordinator}")
+    rank_0_node = selected_cycle[0]
+    logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")

-    def get_ip_for_node(n: NodeId) -> str:
-        if n == coordinator:
+    def get_ip_for_node(n: NodeInfo) -> str:
+        if n.node_id == rank_0_node.node_id:
             return "0.0.0.0"

-        for ip, _ in _find_connection_ip(n, coordinator, cycle_digraph):
+        ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
+        if ip:
             return ip

         logger.warning(
-            f"Failed to find directly connected ip between {n} and {coordinator}"
-        )
-        raise ValueError(
-            "Current jaccl backend requires all participating devices to be able to communicate"
+            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
         )
+        raise ValueError("Current ibv backend requires all-to-all rdma connections")

     return {
-        n: f"{get_ip_for_node(n)}:{coordinator_port}"
-        for n in cycle_digraph.list_nodes()
+        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
     }
```
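For a 3-node all-to-all setup like the test fixture further below, the ibv device matrix comes out with `None` on the diagonal and, at `matrix[i][j]`, the RDMA interface on node i that node j dials in on. An illustration only; the interface names mirror the en3/en4 assignments in that fixture, not any real hardware:

```python
# Illustration only: in the test fixture, en3 carries the connection arriving
# from one neighbour and en4 the one from the other, so each row names the
# local interfaces that the other two nodes land on.
matrix: list[list[str | None]] = [
    # to: a          b            c
    [None,        "rdma_en4",  "rdma_en3"],  # interfaces on node a
    ["rdma_en4",  None,        "rdma_en3"],  # interfaces on node b
    ["rdma_en3",  "rdma_en4",  None],        # interfaces on node c
]
assert all(matrix[i][i] is None for i in range(3))
```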
```diff
@@ -1,36 +1,67 @@
+from typing import Callable
+
+import pytest
+
 from exo.shared.types.common import NodeId
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import (
-    MemoryUsage,
+    MemoryPerformanceProfile,
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
-from exo.shared.types.topology import RDMAConnection, SocketConnection
+from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo


-def create_node_profile(memory: int) -> NodePerformanceProfile:
-    return NodePerformanceProfile(
-        model_id="test",
-        chip_id="test",
-        friendly_name="test",
-        memory=MemoryUsage.from_bytes(
-            ram_total=1000,
-            ram_available=memory,
-            swap_total=1000,
-            swap_available=1000,
-        ),
-        network_interfaces=[],
-        system=SystemPerformanceProfile(),
-    )
+@pytest.fixture
+def create_node():
+    def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
+        if node_id is None:
+            node_id = NodeId()
+        return NodeInfo(
+            node_id=node_id,
+            node_profile=NodePerformanceProfile(
+                model_id="test",
+                chip_id="test",
+                friendly_name="test",
+                memory=MemoryPerformanceProfile.from_bytes(
+                    ram_total=1000,
+                    ram_available=memory,
+                    swap_total=1000,
+                    swap_available=1000,
+                ),
+                network_interfaces=[],
+                system=SystemPerformanceProfile(),
+            ),
+        )
+
+    return _create_node


-# TODO: this is a hack to get the port for the send_back_multiaddr
-def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
-    return SocketConnection(
-        sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
-    )
-
-
-def create_rdma_connection(iface: int) -> RDMAConnection:
-    return RDMAConnection(
-        source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
-    )
+@pytest.fixture
+def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
+    port_counter = 1235
+    ip_counter = 1
+
+    def _create_connection(
+        source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
+    ) -> Connection:
+        nonlocal port_counter
+        nonlocal ip_counter
+        # assign unique ips
+        ip_counter += 1
+        if send_back_port is None:
+            send_back_port = port_counter
+            port_counter += 1
+        return Connection(
+            local_node_id=source_node_id,
+            send_back_node_id=sink_node_id,
+            send_back_multiaddr=Multiaddr(
+                address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
+            ),
+            connection_profile=ConnectionProfile(
+                throughput=1000, latency=1000, jitter=1000
+            ),
+        )
+
+    return _create_connection
```
@@ -19,13 +19,15 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodeGatheredInfo,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
@@ -81,14 +83,21 @@ async def test_master():
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeGatheredInfo(
|
||||
NodePerformanceMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
info=MemoryUsage(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -154,7 +163,7 @@ async def test_master():
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodeGatheredInfo)
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
created_instance = events[1].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
|
||||
```diff
@@ -1,3 +1,5 @@
+from typing import Callable
+
 import pytest
 from loguru import logger

@@ -5,20 +7,14 @@ from exo.master.placement import (
     get_transition_events,
     place_instance,
 )
-from exo.master.tests.conftest import (
-    create_connection,
-    create_node_profile,
-    create_rdma_connection,
-)
 from exo.shared.topology import Topology
 from exo.shared.types.commands import PlaceInstance
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.events import InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
-from exo.shared.types.multiaddr import Multiaddr
-from exo.shared.types.profiling import NetworkInterfaceInfo
-from exo.shared.types.topology import SocketConnection
+from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
+from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
 from exo.shared.types.worker.shards import Sharding


+@pytest.fixture
+def topology() -> Topology:
+    return Topology()
+
+
 @pytest.fixture
 def instance() -> Instance:
     return MlxRingInstance(
@@ -76,36 +77,34 @@ def test_get_instance_placements_create_instance(
     available_memory: tuple[int, int, int],
     total_layers: int,
     expected_layers: tuple[int, int, int],
+    topology: Topology,
     model_meta: ModelMetadata,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     # arrange
     model_meta.n_layers = total_layers
     model_meta.storage_size.in_bytes = sum(
         available_memory
     )  # make it exactly fit across all nodes
-    topology = Topology()

     cic = place_instance_command(model_meta)
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
-    profiles = {
-        node_id_a: create_node_profile(available_memory[0]),
-        node_id_b: create_node_profile(available_memory[1]),
-        node_id_c: create_node_profile(available_memory[2]),
-    }
-    topology.add_node(node_id_a)
-    topology.add_node(node_id_b)
-    topology.add_node(node_id_c)
-    topology.add_connection(node_id_a, node_id_b, create_connection(1))
-    topology.add_connection(node_id_b, node_id_c, create_connection(2))
-    topology.add_connection(node_id_c, node_id_a, create_connection(3))
-    topology.add_connection(node_id_c, node_id_b, create_connection(4))
-    topology.add_connection(node_id_a, node_id_c, create_connection(5))
-    topology.add_connection(node_id_b, node_id_a, create_connection(6))
+    topology.add_node(create_node(available_memory[0], node_id_a))
+    topology.add_node(create_node(available_memory[1], node_id_b))
+    topology.add_node(create_node(available_memory[2], node_id_c))
+    # Add bidirectional connections for ring topology
+    topology.add_connection(create_connection(node_id_a, node_id_b))
+    topology.add_connection(create_connection(node_id_b, node_id_a))
+    topology.add_connection(create_connection(node_id_b, node_id_c))
+    topology.add_connection(create_connection(node_id_c, node_id_b))
+    topology.add_connection(create_connection(node_id_c, node_id_a))
+    topology.add_connection(create_connection(node_id_a, node_id_c))

     # act
-    placements = place_instance(cic, topology, {}, profiles)
+    placements = place_instance(cic, topology, {})

     # assert
     assert len(placements) == 1
@@ -131,11 +130,12 @@ def test_get_instance_placements_create_instance(
     assert shards_sorted[-1].end_layer == total_layers


-def test_get_instance_placements_one_node_exact_fit() -> None:
+def test_get_instance_placements_one_node_exact_fit(
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+) -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(node_id)
-    profiles = {node_id: create_node_profile(1000 * 1024)}
+    topology.add_node(create_node(1000 * 1024, node_id))
     cic = place_instance_command(
         ModelMetadata(
             model_id=ModelId("test-model"),
@@ -146,7 +146,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
             supports_tensor=True,
         ),
     )
-    placements = place_instance(cic, topology, {}, profiles)
+    placements = place_instance(cic, topology, {})

     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -157,11 +157,12 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
     assert len(instance.shard_assignments.runner_to_shard) == 1


-def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
+def test_get_instance_placements_one_node_fits_with_extra_memory(
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+) -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(node_id)
-    profiles = {node_id: create_node_profile(1001 * 1024)}
+    topology.add_node(create_node(1001 * 1024, node_id))
     cic = place_instance_command(
         ModelMetadata(
             model_id=ModelId("test-model"),
@@ -172,7 +173,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
             supports_tensor=True,
         ),
     )
-    placements = place_instance(cic, topology, {}, profiles)
+    placements = place_instance(cic, topology, {})

     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -183,11 +184,12 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
     assert len(instance.shard_assignments.runner_to_shard) == 1


-def test_get_instance_placements_one_node_not_fit() -> None:
+def test_get_instance_placements_one_node_not_fit(
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+) -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(node_id)
-    profiles = {node_id: create_node_profile(1000 * 1024)}
+    topology.add_node(create_node(1000 * 1024, node_id))
     cic = place_instance_command(
         model_meta=ModelMetadata(
             model_id=ModelId("test-model"),
@@ -200,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
     )

     with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
-        place_instance(cic, topology, {}, profiles)
+        place_instance(cic, topology, {})


 def test_get_transition_events_no_change(instance: Instance):
@@ -245,103 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
     assert events[0].instance_id == instance_id


-def test_placement_selects_leaf_nodes(
+def test_placement_selects_cycle_with_most_memory(
+    topology: Topology,
     model_meta: ModelMetadata,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
-    # arrange
-    topology = Topology()
+    # Arrange two 3-node cycles with different total memory.
+    # With bidirectional connections for ring topology, both cycles have non-leaf nodes.
+    # The algorithm should select the cycle with the most available memory.

-    model_meta.storage_size = Memory.from_bytes(1000)
+    # Model requires more than any single node but fits within a 3-node cycle
+    model_meta.storage_size.in_bytes = 1500
     model_meta.n_layers = 12

+    # Create node ids
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
     node_id_d = NodeId()
+    node_id_e = NodeId()
+    node_id_f = NodeId()

-    profiles = {
-        node_id_a: create_node_profile(500),
-        node_id_b: create_node_profile(600),
-        node_id_c: create_node_profile(600),
-        node_id_d: create_node_profile(500),
-    }
+    # A-B-C cycle total memory = 1600 (< D-E-F total)
+    topology.add_node(create_node(400, node_id_a))
+    topology.add_node(create_node(400, node_id_b))
+    topology.add_node(create_node(800, node_id_c))

-    topology.add_node(node_id_a)
-    topology.add_node(node_id_b)
-    topology.add_node(node_id_c)
-    topology.add_node(node_id_d)
+    # D-E-F cycle total memory = 1800 (> A-B-C total)
+    topology.add_node(create_node(600, node_id_d))
+    topology.add_node(create_node(600, node_id_e))
+    topology.add_node(create_node(600, node_id_f))

-    # Daisy chain topology
-    topology.add_connection(node_id_a, node_id_b, create_connection(1))
-    topology.add_connection(node_id_b, node_id_a, create_connection(1))
-    topology.add_connection(node_id_b, node_id_c, create_connection(1))
-    topology.add_connection(node_id_c, node_id_b, create_connection(1))
-    topology.add_connection(node_id_c, node_id_d, create_connection(1))
-    topology.add_connection(node_id_d, node_id_c, create_connection(1))
+    # Build bidirectional cycles for ring topology
+    topology.add_connection(create_connection(node_id_a, node_id_b))
+    topology.add_connection(create_connection(node_id_b, node_id_a))
+    topology.add_connection(create_connection(node_id_b, node_id_c))
+    topology.add_connection(create_connection(node_id_c, node_id_b))
+    topology.add_connection(create_connection(node_id_c, node_id_a))
+    topology.add_connection(create_connection(node_id_a, node_id_c))

-    logger.info(list(topology.list_connections()))
+    topology.add_connection(create_connection(node_id_d, node_id_e))
+    topology.add_connection(create_connection(node_id_e, node_id_d))
+    topology.add_connection(create_connection(node_id_e, node_id_f))
+    topology.add_connection(create_connection(node_id_f, node_id_e))
+    topology.add_connection(create_connection(node_id_f, node_id_d))
+    topology.add_connection(create_connection(node_id_d, node_id_f))

     cic = place_instance_command(
         model_meta=model_meta,
     )

-    # act
-    placements = place_instance(cic, topology, {}, profiles)
+    # Act
+    placements = place_instance(cic, topology, {})

-    # assert
+    # Assert: D-E-F cycle should be selected as it has more total memory
     assert len(placements) == 1
-    instance = list(placements.values())[0]
+    instance_id = list(placements.keys())[0]
+    instance = placements[instance_id]

     assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
-    assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
-        (node_id_c, node_id_d)
-    )
+    less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
+    more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
+
+    assert more_memory_cycle_nodes.issubset(assigned_nodes)
+    assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)


 def test_tensor_rdma_backend_connectivity_matrix(
+    topology: Topology,
     model_meta: ModelMetadata,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
-    topology = Topology()
     model_meta.n_layers = 12
     model_meta.storage_size.in_bytes = 1500

-    node_a = NodeId()
-    node_b = NodeId()
-    node_c = NodeId()
+    node_id_a = NodeId()
+    node_id_b = NodeId()
+    node_id_c = NodeId()

-    profiles = {
-        node_a: create_node_profile(500),
-        node_b: create_node_profile(500),
-        node_c: create_node_profile(500),
-    }
+    node_a = create_node(500, node_id_a)
+    node_b = create_node(500, node_id_b)
+    node_c = create_node(500, node_id_c)

     ethernet_interface = NetworkInterfaceInfo(
         name="en0",
         ip_address="192.168.1.100",
     )
-    ethernet_conn = SocketConnection(
-        sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
-    )

-    profiles[node_a].network_interfaces = [ethernet_interface]
-    profiles[node_b].network_interfaces = [ethernet_interface]
-    profiles[node_c].network_interfaces = [ethernet_interface]
+    assert node_a.node_profile is not None
+    assert node_b.node_profile is not None
+    assert node_c.node_profile is not None
+
+    conn_a_b = create_connection(node_id_a, node_id_b)
+    conn_b_c = create_connection(node_id_b, node_id_c)
+    conn_c_a = create_connection(node_id_c, node_id_a)
+
+    conn_b_a = create_connection(node_id_b, node_id_a)
+    conn_c_b = create_connection(node_id_c, node_id_b)
+    conn_a_c = create_connection(node_id_a, node_id_c)
+
+    assert conn_a_b.send_back_multiaddr is not None
+    assert conn_b_c.send_back_multiaddr is not None
+    assert conn_c_a.send_back_multiaddr is not None
+
+    assert conn_b_a.send_back_multiaddr is not None
+    assert conn_c_b.send_back_multiaddr is not None
+    assert conn_a_c.send_back_multiaddr is not None
+
+    node_a.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_a.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
+            ethernet_interface,
+        ],
+        system=node_a.node_profile.system,
+    )
+    node_b.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_b.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+            ),
+            ethernet_interface,
+        ],
+        system=node_b.node_profile.system,
+    )
+    node_c.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_c.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
+            ethernet_interface,
+        ],
+        system=node_c.node_profile.system,
+    )

     topology.add_node(node_a)
     topology.add_node(node_b)
     topology.add_node(node_c)
-    topology.add_connection(node_a, node_b, create_rdma_connection(3))
```
|
||||
topology.add_connection(node_b, node_c, create_rdma_connection(4))
|
||||
topology.add_connection(node_c, node_a, create_rdma_connection(5))
|
||||
topology.add_connection(node_b, node_a, create_rdma_connection(3))
|
||||
topology.add_connection(node_c, node_b, create_rdma_connection(4))
|
||||
topology.add_connection(node_a, node_c, create_rdma_connection(5))
|
||||
|
||||
topology.add_connection(node_a, node_b, ethernet_conn)
|
||||
topology.add_connection(node_b, node_c, ethernet_conn)
|
||||
topology.add_connection(node_c, node_a, ethernet_conn)
|
||||
topology.add_connection(node_a, node_c, ethernet_conn)
|
||||
topology.add_connection(node_b, node_a, ethernet_conn)
|
||||
topology.add_connection(node_c, node_b, ethernet_conn)
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_a_c)
|
||||
|
||||
cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
@@ -351,7 +429,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -359,10 +437,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.jaccl_devices is not None
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.jaccl_coordinators is not None
|
||||
|
||||
matrix = instance.jaccl_devices
|
||||
matrix = instance.ibv_devices
|
||||
assert len(matrix) == 3
|
||||
|
||||
for i in range(3):
|
||||
@@ -371,15 +449,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
||||
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
||||
|
||||
idx_a = node_to_idx[node_a]
|
||||
idx_b = node_to_idx[node_b]
|
||||
idx_c = node_to_idx[node_c]
|
||||
idx_a = node_to_idx[node_id_a]
|
||||
idx_b = node_to_idx[node_id_b]
|
||||
idx_c = node_to_idx[node_id_c]
|
||||
|
||||
logger.info(matrix)
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.jaccl_coordinators) == 3
|
||||
|
||||
@@ -1,48 +1,56 @@
from typing import Callable

import pytest

from exo.master.placement_utils import (
    NodeWithProfile,
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
    get_mlx_jaccl_coordinators,
    get_shard_assignments,
    get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding


def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
    topology = Topology()
    return topology


def test_filter_cycles_by_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node1_id = NodeId()
    node2_id = NodeId()
    topology = Topology()

    node1 = create_node_profile(1000 * 1024)
    node2 = create_node_profile(1000 * 1024)
    node_profiles = {node1_id: node1, node2_id: node2}
    node1 = create_node(1000 * 1024, node1_id)
    node2 = create_node(1000 * 1024, node2_id)

    topology.add_node(node1_id)
    topology.add_node(node2_id)
    topology.add_node(node1)
    topology.add_node(node2)

    connection1 = create_connection(1)
    connection2 = create_connection(2)
    connection1 = create_connection(node1_id, node2_id)
    connection2 = create_connection(node2_id, node1_id)

    topology.add_connection(node1_id, node2_id, connection1)
    topology.add_connection(node2_id, node1_id, connection2)
    topology.add_connection(connection1)
    topology.add_connection(connection2)

    cycles = topology.get_cycles()
    assert len(cycles) == 1
    assert len(cycles[0]) == 2

    # act
    filtered_cycles = filter_cycles_by_memory(
        cycles, node_profiles, Memory.from_bytes(1)
    )
    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))

    # assert
    assert len(filtered_cycles) == 1
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
    assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}


def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node1_id = NodeId()
    node2_id = NodeId()
    topology = Topology()

    node1 = create_node_profile(1000 * 1024)
    node2 = create_node_profile(1000 * 1024)
    node_profiles = {node1_id: node1, node2_id: node2}
    node1 = create_node(1000 * 1024, node1_id)
    node2 = create_node(1000 * 1024, node2_id)

    topology.add_node(node1_id)
    topology.add_node(node2_id)
    topology.add_node(node1)
    topology.add_node(node2)

    connection1 = create_connection(1)
    connection2 = create_connection(2)
    connection1 = create_connection(node1_id, node2_id)
    connection2 = create_connection(node2_id, node1_id)

    topology.add_connection(node1_id, node2_id, connection1)
    topology.add_connection(node2_id, node1_id, connection2)
    topology.add_connection(connection1)
    topology.add_connection(connection2)

    # act
    filtered_cycles = filter_cycles_by_memory(
        topology.get_cycles(), node_profiles, Memory.from_kb(2001)
        topology.get_cycles(), Memory.from_kb(2001)
    )

    # assert
    assert len(filtered_cycles) == 0


def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    node_a = create_node_profile(500 * 1024)
    node_b = create_node_profile(500 * 1024)
    node_c = create_node_profile(1000 * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(node_a_id, node_b_id, create_connection(1))
    topology.add_connection(node_b_id, node_a_id, create_connection(2))
    topology.add_connection(node_a_id, node_c_id, create_connection(3))
    topology.add_connection(node_c_id, node_b_id, create_connection(4))
    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    topology.add_connection(create_connection(node_a_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_b_id))

    cycles = topology.get_cycles()

    # act
    filtered_cycles = filter_cycles_by_memory(
        cycles, node_profiles, Memory.from_kb(1500)
    )
    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))

    # assert
    assert len(filtered_cycles) == 1
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
    }


def test_get_smallest_cycles():
def test_get_smallest_cycles(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    node_a = create_node_profile(500 * 1024)
    node_b = create_node_profile(500 * 1024)
    node_c = create_node_profile(1000 * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(node_a_id, node_b_id, create_connection(1))
    topology.add_connection(node_b_id, node_a_id, create_connection(2))
    topology.add_connection(node_a_id, node_c_id, create_connection(3))
    topology.add_connection(node_c_id, node_b_id, create_connection(4))

    cycles = [
        [NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
        for cycle in topology.get_cycles()
    ]
    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    # act
    smallest_cycles = get_smallest_cycles(cycles)
    smallest_cycles = get_smallest_cycles(topology.get_cycles())

    # assert
    assert len(smallest_cycles) == 1
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
    ],
)
def test_get_shard_assignments(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
    available_memory: tuple[int, int, int],
    total_layers: int,
    expected_layers: tuple[int, int, int],
@@ -176,25 +179,19 @@ def test_get_shard_assignments(
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    node_a = create_node_profile(available_memory[0] * 1024)
    node_b = create_node_profile(available_memory[1] * 1024)
    node_c = create_node_profile(available_memory[2] * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }
    node_a = create_node(available_memory[0] * 1024, node_a_id)
    node_b = create_node(available_memory[1] * 1024, node_b_id)
    node_c = create_node(available_memory[2] * 1024, node_c_id)

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(node_a_id, node_b_id, create_connection(1))
    topology.add_connection(node_b_id, node_c_id, create_connection(2))
    topology.add_connection(node_c_id, node_a_id, create_connection(3))
    topology.add_connection(node_b_id, node_a_id, create_connection(4))
    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    model_meta = ModelMetadata(
        model_id=ModelId("test-model"),
@@ -204,11 +201,7 @@ def test_get_shard_assignments(
        hidden_size=1000,
        supports_tensor=True,
    )

    cycles = [
        [NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
        for cycle in topology.get_cycles()
    ]
    cycles = topology.get_cycles()
    selected_cycle = cycles[0]

    # act
@@ -237,21 +230,28 @@ def test_get_shard_assignments(
    )


def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    node_a = create_node(500, node_a_id)
    node_b = create_node(500, node_b_id)
    node_c = create_node(1000, node_c_id)

    topology.add_connection(node_a_id, node_b_id, create_connection(1))
    topology.add_connection(node_b_id, node_a_id, create_connection(2))
    topology.add_connection(node_a_id, node_c_id, create_connection(3))
    topology.add_connection(node_c_id, node_b_id, create_connection(4))
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
    topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
    topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
    topology.add_connection(create_connection(node_b_id, node_a_id, 5004))

    # act
    hosts = get_hosts_from_subgraph(topology)
@@ -259,47 +259,108 @@ def test_get_hosts_from_subgraph():
    # assert
    assert len(hosts) == 3
    expected_hosts = [
        Host(ip=("169.254.0.2"), port=1234),
        Host(ip=("169.254.0.3"), port=1234),
        Host(ip=("169.254.0.4"), port=1234),
        Host(ip=("169.254.0.2"), port=5001),
        Host(ip=("169.254.0.3"), port=5002),
        Host(ip=("169.254.0.4"), port=5003),
    ]
    for expected_host in expected_hosts:
        assert expected_host in hosts


def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    topology.add_connection(node_a_id, node_b_id, create_connection(1))
    topology.add_connection(node_b_id, node_a_id, create_connection(2))
    topology.add_connection(node_a_id, node_c_id, create_connection(3))
    topology.add_connection(node_c_id, node_b_id, create_connection(4))
    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
    conn_a_c = create_connection(node_a_id, node_c_id, 5006)

    conn_a_b = create_connection(1)
    conn_b_a = create_connection(2)
    conn_b_c = create_connection(3)
    conn_c_b = create_connection(4)
    conn_c_a = create_connection(5)
    conn_a_c = create_connection(6)
    # Update node profiles with network interfaces before adding to topology
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    topology.add_connection(node_a_id, node_b_id, conn_a_b)
    topology.add_connection(node_b_id, node_a_id, conn_b_a)
    topology.add_connection(node_b_id, node_c_id, conn_b_c)
    topology.add_connection(node_c_id, node_b_id, conn_c_b)
    topology.add_connection(node_c_id, node_a_id, conn_c_a)
    topology.add_connection(node_a_id, node_c_id, conn_a_c)
    node_a.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_a.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_a_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_a_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_a.node_profile.system,
    )
    node_b.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_b.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_b_a.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_b.node_profile.system,
    )
    node_c.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_c.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_c_a.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_c.node_profile.system,
    )

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_a)
    topology.add_connection(conn_b_c)
    topology.add_connection(conn_c_b)
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_a_c)

    cycle = [node_a, node_b, node_c]

    # act
    coordinators = get_mlx_jaccl_coordinators(
        node_a_id, coordinator_port=5000, cycle_digraph=topology
        cycle, coordinator_port=5000, cycle_digraph=topology
    )

    # assert
@@ -328,11 +389,11 @@ def test_get_mlx_jaccl_coordinators():

    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
    # node_b uses the IP from conn_b_a (node_b -> node_a)
    assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
        "node_b should use the IP from conn_b_a"
    )
    assert coordinators[node_b_id] == (
        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
    ), "node_b should use the IP from conn_b_a"

    # node_c uses the IP from conn_c_a (node_c -> node_a)
    assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
        "node_c should use the IP from conn_c_a"
    )
    assert coordinators[node_c_id] == (
        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
    ), "node_c should use the IP from conn_c_a"

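For orientation, a condensed sketch of the call shape the updated coordinator test exercises; the names refer to the fixtures above, and the expected mapping is paraphrased from the asserts rather than taken from the library's docs:

# cycle is a list of NodeInfo; cycle[0] is treated as rank 0.
coordinators = get_mlx_jaccl_coordinators(
    cycle, coordinator_port=5000, cycle_digraph=topology
)
# Every node maps to "<ip>:<coordinator_port>"; each non-rank-0 node resolves
# the coordinator through the send-back address of its connection to rank 0,
# e.g. coordinators[node_b_id] == f"{conn_b_a.send_back_multiaddr.ip_address}:5000".
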
@@ -1,14 +1,13 @@
import pytest

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
    MemoryUsage,
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo


@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:


@pytest.fixture
def connection() -> SocketConnection:
    return SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
    return Connection(
        local_node_id=NodeId(),
        send_back_node_id=NodeId(),
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
        connection_profile=ConnectionProfile(
            throughput=1000, latency=1000, jitter=1000
        ),
    )


@pytest.fixture
def node_profile() -> NodePerformanceProfile:
    memory_profile = MemoryUsage.from_bytes(
    memory_profile = MemoryPerformanceProfile.from_bytes(
        ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
    )
    system_profile = SystemPerformanceProfile()
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
    )


def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
    return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)


def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
    # arrange
    node_id = NodeId()

    # act
    topology.add_node(node_id)
    topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))

    # assert
    assert topology.node_is_leaf(node_id)
    data = topology.get_node_profile(node_id)
    assert data == node_profile


def test_add_connection(topology: Topology, connection: SocketConnection):
def test_add_connection(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(node_a, node_b, connection)
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    data = list(conn for _, _, conn in topology.list_connections())
    data = topology.get_connection_profile(connection)

    # assert
    assert data == [connection]
    assert data == connection.connection_profile

    assert topology.node_is_leaf(node_a)
    assert topology.node_is_leaf(node_b)

def test_update_node_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=MemoryPerformanceProfile.from_bytes(
            ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
    )

    # act
    topology.update_node_profile(
        connection.local_node_id, node_profile=new_node_profile
    )

    # assert
    data = topology.get_node_profile(connection.local_node_id)
    assert data == new_node_profile


def test_update_connection_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_connection_profile = ConnectionProfile(
        throughput=2000, latency=2000, jitter=2000
    )
    connection = Connection(
        local_node_id=connection.local_node_id,
        send_back_node_id=connection.send_back_node_id,
        send_back_multiaddr=connection.send_back_multiaddr,
        connection_profile=new_connection_profile,
    )

    # act
    topology.update_connection_profile(connection)

    # assert
    data = topology.get_connection_profile(connection)
    assert data == new_connection_profile


def test_remove_connection_still_connected(
    topology: Topology, connection: SocketConnection
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(node_a, node_b, connection)
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    topology.remove_connection(node_a, node_b, connection)
    topology.remove_connection(connection)

    # assert
    assert list(topology.get_all_connections_between(node_a, node_b)) == []
    assert topology.get_connection_profile(connection) is None


def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
def test_remove_node_still_connected(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(node_a, node_b, connection)
    assert list(topology.out_edges(node_a)) == [(node_b, connection)]
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    topology.remove_node(node_b)
    topology.remove_node(connection.local_node_id)

    # assert
    assert list(topology.out_edges(node_a)) == []
    assert topology.get_node_profile(connection.local_node_id) is None


def test_list_nodes(topology: Topology, connection: SocketConnection):
def test_list_nodes(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(node_a, node_b, connection)
    assert list(topology.out_edges(node_a)) == [(node_b, connection)]
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    nodes = list(topology.list_nodes())

    # assert
    assert len(nodes) == 2
    assert all(isinstance(node, NodeId) for node in nodes)
    assert {node for node in nodes} == {node_a, node_b}
    assert all(isinstance(node, NodeInfo) for node in nodes)
    assert {node.node_id for node in nodes} == {
        connection.local_node_id,
        connection.send_back_node_id,
    }

@@ -11,8 +11,10 @@ from exo.shared.types.events (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    NodeCreated,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeMemoryMeasured,
    NodePerformanceMeasured,
    NodeTimedOut,
    RunnerDeleted,
    RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events (
    TopologyEdgeCreated,
    TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
    MacmonMetrics,
    MacTBConnections,
    MacTBIdentifiers,
    MemoryUsage,
    MiscData,
    NodeConfig,
    NodeNetworkInterfaces,
    StaticNodeInformation,
)


def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case NodeCreated():
            return apply_topology_node_created(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodePerformanceMeasured():
            return apply_node_performance_measured(event, state)
        case NodeDownloadProgress():
            return apply_node_download_progress(event, state)
        case NodeGatheredInfo():
            return apply_node_gathered_info(event, state)
        case NodeMemoryMeasured():
            return apply_node_memory_measured(event, state)
        case RunnerDeleted():
            return apply_runner_deleted(event, state)
        case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:


def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology = copy.copy(state.topology)
    state.topology.remove_node(event.node_id)
    node_profiles = {
        key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    last_seen = {
        key: value for key, value in state.last_seen.items() if key != event.node_id
    }
    downloads = {
        key: value for key, value in state.downloads.items() if key != event.node_id
    }
    return state.model_copy(
        update={
            "downloads": downloads,
            "topology": topology,
            "node_profiles": node_profiles,
            "last_seen": last_seen,
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    )


def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.add_node(event.node_id)
    info = event.info
    profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
    # TODO: should be broken up into individual events instead of this monster
    match info:
        case MacmonMetrics():
            profile.system = info.system_profile
            profile.memory = info.memory
        case MemoryUsage():
            profile.memory = info
        case NodeConfig():
            pass
        case MiscData():
            profile.friendly_name = info.friendly_name
        case StaticNodeInformation():
            profile.model_id = info.model
            profile.chip_id = info.chip
        # TODO: makes me slightly sad
        case NodeNetworkInterfaces():
            profile.network_interfaces = info.ifaces
        case MacTBIdentifiers():
            profile.tb_interfaces = info.idents
        case MacTBConnections():
            conn_map = {
                tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
                for nid in state.node_profiles
                for tb_ident in state.node_profiles[nid].tb_interfaces
            }
            as_rdma_conns = [
                (
                    conn_map[tb_conn.sink_uuid][0],
                    RDMAConnection(
                        source_rdma_iface=conn_map[tb_conn.source_uuid][1],
                        sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
                    ),
                )
                for tb_conn in info.conns
                if tb_conn.source_uuid in conn_map
                if tb_conn.sink_uuid in conn_map
            ]
            topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)

    last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
    new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
    event: NodePerformanceMeasured, state: State
) -> State:
    new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: event.node_profile,
    }
    last_seen: Mapping[NodeId, datetime] = {
        **state.last_seen,
        event.node_id: datetime.fromisoformat(event.when),
    }
    state = state.model_copy(update={"node_profiles": new_profiles})
    topology = copy.copy(state.topology)
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, event.node_profile)
    return state.model_copy(
        update={
            "node_profiles": new_profiles,
            "last_seen": last_seen,
            "topology": topology,
            "last_seen": last_seen,
        }
    )


def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
    existing = state.node_profiles.get(event.node_id)
    topology = copy.copy(state.topology)

    if existing is None:
        created = NodePerformanceProfile(
            model_id="unknown",
            chip_id="unknown",
            friendly_name="Unknown",
            memory=event.memory,
            network_interfaces=[],
            system=SystemPerformanceProfile(
                # TODO: flops_fp16=0.0,
                gpu_usage=0.0,
                temp=0.0,
                sys_power=0.0,
                pcpu_usage=0.0,
                ecpu_usage=0.0,
                ane_power=0.0,
            ),
        )
        created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
            **state.node_profiles,
            event.node_id: created,
        }
        last_seen: Mapping[NodeId, datetime] = {
            **state.last_seen,
            event.node_id: datetime.fromisoformat(event.when),
        }
        if not topology.contains_node(event.node_id):
            topology.add_node(NodeInfo(node_id=event.node_id))
        # TODO: NodeCreated
        topology.update_node_profile(event.node_id, created)
        return state.model_copy(
            update={
                "node_profiles": created_profiles,
                "topology": topology,
                "last_seen": last_seen,
            }
        )

    updated = existing.model_copy(update={"memory": event.memory})
    updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: updated,
    }
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, updated)
    return state.model_copy(
        update={"node_profiles": updated_profiles, "topology": topology}
    )


def apply_topology_node_created(event: NodeCreated, state: State) -> State:
    topology = copy.copy(state.topology)
    topology.add_node(NodeInfo(node_id=event.node_id))
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.add_connection(event.source, event.sink, event.edge)
    topology = copy.copy(state.topology)
    topology.add_connection(event.edge)
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.remove_connection(event.sink, event.source, event.edge)
    topology = copy.copy(state.topology)
    if not topology.contains_connection(event.edge):
        return state
    topology.remove_connection(event.edge)
    # TODO: Clean up removing the reverse connection
    return state.model_copy(update={"topology": topology})

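Each reducer above follows the same copy-on-write shape: shallow-copy the topology, mutate the copy, and return a `model_copy` of the immutable state. A minimal sketch of that shape, using only API that appears in this diff (the event argument is a stand-in for any node-bearing event):

import copy

def apply_example(event, state):
    # Shallow copy, mirroring the handlers above (deepcopy was dropped in this diff).
    topology = copy.copy(state.topology)
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    return state.model_copy(update={"topology": topology})
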
@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"

# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"

# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

@@ -24,8 +24,6 @@ class _InterceptHandler(logging.Handler):
        except ValueError:
            level = record.levelno

        return

        logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())



@@ -43,4 +43,7 @@ def test_apply_two_node_download_progress():
        NodeDownloadProgress(download_progress=event2), state
    )

    # TODO: This test is failing. We should support the following:
    # 1. Downloading multiple models concurrently on the same node (one per runner is fine).
    # 2. Downloading a model, it completes, then downloading a different model on the same node.
    assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection


def test_state_serialization_roundtrip() -> None:
@@ -11,21 +11,17 @@ def test_state_serialization_roundtrip() -> None:
    node_a = NodeId("node-a")
    node_b = NodeId("node-b")

    connection = SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
    connection = Connection(
        local_node_id=node_a,
        send_back_node_id=node_b,
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
    )

    state = State()
    state.topology.add_connection(node_a, node_b, connection)
    state.topology.add_connection(connection)

    json_repr = state.model_dump_json()
    restored_state = State.model_validate_json(json_repr)

    assert (
        state.topology.to_snapshot().nodes
        == restored_state.topology.to_snapshot().nodes
    )
    assert set(state.topology.to_snapshot().connections) == set(
        restored_state.topology.to_snapshot().connections
    )
    assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
    assert restored_state.model_dump_json() == json_repr

@@ -1,215 +1,203 @@
|
||||
import contextlib
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable
|
||||
|
||||
import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
|
||||
|
||||
class TopologySnapshot(BaseModel):
|
||||
nodes: Sequence[NodeId]
|
||||
connections: Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
|
||||
nodes: list[NodeInfo]
|
||||
connections: list[Connection]
|
||||
|
||||
model_config = ConfigDict(frozen=True, extra="forbid")
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Topology:
|
||||
# the _graph can be used as a int -> NodeId map.
|
||||
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
|
||||
init=False, default_factory=rx.PyDiGraph
|
||||
)
|
||||
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
|
||||
def __init__(self) -> None:
|
||||
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
||||
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
|
||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
|
||||
|
||||
def to_snapshot(self) -> TopologySnapshot:
|
||||
return TopologySnapshot(
|
||||
nodes=list(self.list_nodes()), connections=self.list_connections()
|
||||
nodes=list(self.list_nodes()),
|
||||
connections=list(self.list_connections()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
|
||||
topology = cls()
|
||||
|
||||
for node_id in snapshot.nodes:
|
||||
for node in snapshot.nodes:
|
||||
with contextlib.suppress(ValueError):
|
||||
topology.add_node(node_id)
|
||||
topology.add_node(node)
|
||||
|
||||
for source, sink, conn in snapshot.connections:
|
||||
topology.add_connection(source, sink, conn)
|
||||
for connection in snapshot.connections:
|
||||
topology.add_connection(connection)
|
||||
|
||||
return topology
|
||||
|
||||
def add_node(self, node_id: NodeId) -> None:
|
||||
if node_id in self._vertex_indices:
|
||||
def add_node(self, node: NodeInfo) -> None:
|
||||
if node.node_id in self._node_id_to_rx_id_map:
|
||||
return
|
||||
rx_id = self._graph.add_node(node_id)
|
||||
self._vertex_indices[node_id] = rx_id
|
||||
rx_id = self._graph.add_node(node)
|
||||
self._node_id_to_rx_id_map[node.node_id] = rx_id
|
||||
self._rx_id_to_node_id_map[rx_id] = node.node_id
|
||||
|
||||
def node_is_leaf(self, node_id: NodeId) -> bool:
|
||||
return (
|
||||
node_id in self._vertex_indices
|
||||
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
|
||||
node_id in self._node_id_to_rx_id_map
|
||||
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
|
||||
)
|
||||
|
||||
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
||||
return [
|
||||
self._graph[rx_id]
|
||||
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
|
||||
self._rx_id_to_node_id_map[rx_id]
|
||||
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
|
||||
]
|
||||
|
||||
def out_edges(
|
||||
self, node_id: NodeId
|
||||
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
|
||||
if node_id not in self._vertex_indices:
|
||||
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return []
|
||||
return (
|
||||
(self._graph[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
|
||||
)
|
||||
return [
|
||||
(self._rx_id_to_node_id_map[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(
|
||||
self._node_id_to_rx_id_map[node_id]
|
||||
)
|
||||
]
|
||||
|
||||
def contains_node(self, node_id: NodeId) -> bool:
|
||||
return node_id in self._vertex_indices
|
||||
return node_id in self._node_id_to_rx_id_map
|
||||
|
||||
def contains_connection(self, connection: Connection) -> bool:
|
||||
return connection in self._edge_id_to_rx_id_map
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
source: NodeId,
|
||||
sink: NodeId,
|
||||
connection: SocketConnection | RDMAConnection,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
if connection in self.get_all_connections_between(source, sink):
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.local_node_id))
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
if source not in self._vertex_indices:
|
||||
self.add_node(source)
|
||||
if sink not in self._vertex_indices:
|
||||
self.add_node(sink)
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
|
||||
src_id = self._vertex_indices[source]
|
||||
sink_id = self._vertex_indices[sink]
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
|
||||
_ = self._graph.add_edge(src_id, sink_id, connection)
|
||||
def list_nodes(self) -> Iterable[NodeInfo]:
|
||||
return (self._graph[i] for i in self._graph.node_indices())
|
||||
|
||||
def get_all_connections_between(
|
||||
self, source: NodeId, sink: NodeId
|
||||
) -> Iterable[SocketConnection | RDMAConnection]:
|
||||
if source not in self._vertex_indices:
|
||||
return []
|
||||
if sink not in self._vertex_indices:
|
||||
return []
|
||||
def list_connections(self) -> Iterable[Connection]:
|
||||
return (connection for _, _, connection in self._graph.weighted_edge_list())
|
||||
|
||||
src_id = self._vertex_indices[source]
|
||||
sink_id = self._vertex_indices[sink]
|
||||
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
|
||||
try:
|
||||
return self._graph.get_all_edge_data(src_id, sink_id)
|
||||
except rx.NoEdgeBetweenNodes:
|
||||
return []
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
return self._graph.get_node_data(rx_idx).node_profile
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def list_nodes(self) -> Iterable[NodeId]:
|
||||
return self._graph.nodes()
|
||||
def update_node_profile(
|
||||
self, node_id: NodeId, node_profile: NodePerformanceProfile
|
||||
) -> None:
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph[rx_idx].node_profile = node_profile
|
||||
|
||||
def map_connections(
|
||||
self,
|
||||
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
|
||||
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
|
||||
for src_id, sink_id, connection in self._graph.weighted_edge_list():
|
||||
source = self._graph[src_id]
|
||||
sink = self._graph[sink_id]
|
||||
if source not in base:
|
||||
base[source] = {}
|
||||
if sink not in base[source]:
|
||||
base[source][sink] = []
|
||||
base[source][sink].append(connection)
|
||||
return base
|
||||
def update_connection_profile(self, connection: Connection) -> None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.update_edge_by_index(rx_idx, connection)
|
||||
|
||||
def list_connections(
|
||||
self,
|
||||
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
|
||||
return (
|
||||
(
|
||||
self._graph[src_id],
|
||||
self._graph[sink_id],
|
||||
connection,
|
||||
)
|
||||
for src_id, sink_id, connection in self._graph.weighted_edge_list()
|
||||
)
|
||||
def get_connection_profile(
|
||||
self, connection: Connection
|
||||
) -> ConnectionProfile | None:
|
||||
try:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def remove_node(self, node_id: NodeId) -> None:
|
||||
if node_id not in self._vertex_indices:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
rx_idx = self._vertex_indices[node_id]
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_id
|
||||
or connection.send_back_node_id == node_id
|
||||
):
|
||||
self.remove_connection(connection)
|
||||
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph.remove_node(rx_idx)
|
||||
|
||||
del self._vertex_indices[node_id]
|
||||
del self._node_id_to_rx_id_map[node_id]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
|
||||
def replace_all_out_tb_connections(
|
||||
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)

def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]

def get_cycles(self) -> list[list[NodeId]]:
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)

return cycles

def get_cycles_tb(self) -> list[list[NodeId]]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]

tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())

for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)

cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)

return cycles

def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology

def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:
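For reference, the pattern used above — copy the nodes into a fresh digraph, keep only edges passing a predicate, then enumerate simple cycles — can be sketched standalone with rustworkx (a toy illustration, not exo's Topology class; the string payloads and the "thunderbolt" tag are invented here):

import rustworkx as rx

g: rx.PyDiGraph[str, str] = rx.PyDiGraph()
a, b, c = g.add_nodes_from(["a", "b", "c"])
g.add_edge(a, b, "thunderbolt")
g.add_edge(b, a, "thunderbolt")
g.add_edge(b, c, "ethernet")

tb: rx.PyDiGraph[str, str] = rx.PyDiGraph()
tb.add_nodes_from(g.nodes())  # same insertion order, so indices line up
for u, v, kind in g.weighted_edge_list():
    if kind == "thunderbolt":
        tb.add_edge(u, v, kind)

print([[tb[i] for i in cyc] for cyc in rx.simple_cycles(tb)])  # only the a<->b loop survives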
@@ -2,14 +2,14 @@ from datetime import datetime

from pydantic import Field

from exo.shared.topology import SocketConnection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId


# TODO
class NodeCreated(BaseEvent):
node_id: NodeId


class NodeTimedOut(BaseEvent):
node_id: NodeId


# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo  # NB: this model is UNTAGGED!!! be warned for ser/de errors.
node_profile: NodePerformanceProfile


class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile


class NodeDownloadProgress(BaseEvent):
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):


class TopologyEdgeCreated(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection


class TopologyEdgeDeleted(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection


Event = (
@@ -119,8 +125,10 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated
@@ -1,11 +1,10 @@
import re
from typing import ClassVar

from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator


class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str

PATTERNS: ClassVar[list[str]] = [
@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self

import psutil

from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel


class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0


class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):


class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile


class ConnectionProfile(CamelCaseModel):
pass
throughput: float
latency: float
jitter: float
@@ -1,75 +0,0 @@
import anyio
from pydantic import BaseModel, Field

from exo.utils.pydantic_ext import CamelCaseModel


class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str


class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str


## Intentionally minimal, only collecting data we care about - there's a lot more


class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None


class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None


class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag | None = None

def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces  # doesn't need to be an assertion but I'm confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)

def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return

sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None

return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)


class TBConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[TBConnectivityData] = []

@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType
@@ -1,32 +1,37 @@
from enum import Enum

from loguru import logger

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel


class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None


class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None

def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)

def is_thunderbolt(self) -> bool:
logger.warning("duh")
return True


# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"


class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr

def __hash__(self):
return hash(self.sink_multiaddr.ip_address)

def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
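One design note on the Connection model above: connection_profile is deliberately excluded from both __hash__ and __eq__, so re-profiling a link never changes its identity in sets or maps keyed by Connection. A small sketch of the consequence (n1 and n2 stand in for real NodeId values):

a = Connection(
    local_node_id=n1,
    send_back_node_id=n2,
    send_back_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/52415"),
)
b = a.model_copy(
    update={"connection_profile": ConnectionProfile(throughput=1e9, latency=0.2, jitter=0.01)}
)
assert a == b and len({a, b}) == 1  # a profile update is not a "new" edge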
@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):


class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
src/exo/shared/types/worker/resource_monitor.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable

from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)


class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...


class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...


class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...


class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]

async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)

async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)
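The monitor above fans all collectors out through asyncio.gather and then pushes every profile through every handler. A minimal usage sketch, assuming concrete collect() implementations (the collectors above are stubs with empty bodies) and an illustrative print handler:

import asyncio

async def demo() -> None:
    monitor = ResourceMonitor()
    monitor.data_collectors = [SystemResourceCollector(), MemoryResourceCollector()]
    monitor.effect_handlers = {lambda profile: print(type(profile).__name__)}
    await monitor.collect()  # gather all collectors, then run each handler per profile

asyncio.run(demo())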
@@ -50,7 +50,9 @@ class RunnerReady(BaseRunnerStatus):


class RunnerRunning(BaseRunnerStatus):
pass
"""Runner is processing requests and can accept more (continuous batching)."""

active_requests: int = 0


class RunnerShuttingDown(BaseRunnerStatus):
@@ -1,232 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast

import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger

from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel

from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces

IS_DARWIN = sys.platform == "darwin"


class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""

model: str
chip: str

@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)


class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]


class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]


class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]


class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""

# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None


class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""

friendly_name: str

@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())


async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None

ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports


GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)


@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)

async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)

nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)

def shutdown(self):
self._tg.cancel_scope.cancel()

async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)

async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return

old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None

idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents

conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))

await anyio.sleep(self.system_profiler_interval)

async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)

async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)

async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)
@@ -1,70 +0,0 @@
from typing import Self

from pydantic import BaseModel

from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel


class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""

cpu_temp_avg: float
gpu_temp_avg: float


class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""

ram_total: int
ram_usage: int
swap_total: int
swap_usage: int


class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.

Unknown fields are ignored for forward-compatibility.
"""

timestamp: str  # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float]  # freq mhz, usage %
pcpu_usage: tuple[int, float]  # freq mhz, usage %
gpu_usage: tuple[int, float]  # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float


class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage

@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)

@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
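from_raw_json above simply validates against RawMacmonMetrics and re-shapes the result; an illustrative payload (all field values invented, but matching the declared schema) parses like this:

sample = """{
  "timestamp": "2024-01-01T00:00:00Z",
  "temp": {"cpu_temp_avg": 45.0, "gpu_temp_avg": 42.0},
  "memory": {"ram_total": 17179869184, "ram_usage": 8589934592, "swap_total": 0, "swap_usage": 0},
  "ecpu_usage": [1200, 0.15], "pcpu_usage": [3200, 0.40], "gpu_usage": [1400, 0.10],
  "all_power": 12.0, "ane_power": 0.0, "cpu_power": 6.0,
  "gpu_power": 4.0, "gpu_ram_power": 1.0, "ram_power": 1.0, "sys_power": 14.0
}"""
metrics = MacmonMetrics.from_raw_json(sample)
print(metrics.system_profile.sys_power)  # 14.0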
@@ -1,24 +0,0 @@
import sys

import pytest

from exo.shared.types.thunderbolt import (
TBConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
_gather_iface_map,  # pyright: ignore[reportPrivateUsage]
)


@pytest.mark.anyio
@pytest.mark.skipif(
sys.platform != "darwin", reason="TB info can only be gathered on macos"
)
async def test_tb_parsing():
data = await TBConnectivity.gather()
ifaces = await _gather_iface_map()
assert ifaces
assert data
for datum in data:
datum.ident(ifaces)
datum.conn()
@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)


class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)


class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):
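The frozen=True in FrozenModel above is what lets these models serve as set members and graph edge payloads: pydantic v2 generates __hash__ for frozen models and rejects mutation after construction. A toy sketch of the behavior (the Edge model is invented for illustration):

from pydantic import BaseModel, ConfigDict

class Edge(BaseModel):
    model_config = ConfigDict(frozen=True)
    addr: str

e = Edge(addr="/ip4/169.254.0.2/tcp/52415")
assert e in {e}      # hashable, so usable in sets and as dict keys
# e.addr = "other"   # would raise a ValidationError: the instance is frozen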
@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
send.close()


# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))
src/exo/worker/engines/mlx/generator/batch_engine.py (new file, 282 lines)
@@ -0,0 +1,282 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""

import time
from dataclasses import dataclass, field

import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper

from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger


@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""

command_id: CommandId
task_id: TaskId
uid: int  # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)


@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""

command_id: CommandId
task_id: TaskId
response: GenerationResponse


class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""

def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
self._pending_completions: list[int] = []  # UIDs completed but not yet synced/removed

self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1

sampler = make_sampler(temp=0.7, top_p=1.0)

eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])

self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)

logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)

def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.

In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})")

def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.

This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.

Batches all pending inserts into a single batch_gen.insert() call for
efficient prefill batching.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]

if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)

if synced_data is None:
self._pending_inserts.clear()
return []

inserts_to_process = synced_data

if not inserts_to_process:
self._pending_inserts.clear()
return []

# Prepare all requests for batched insertion
all_tokens: list[list[int]] = []
all_max_tokens: list[int] = []
all_prompt_tokens: list[int] = []
request_info: list[tuple[CommandId, TaskId]] = []

for cmd_id, task_id, params in inserts_to_process:
prompt_str = apply_chat_template(self.tokenizer, params)
tokens: list[int] = self.tokenizer.encode(prompt_str, add_special_tokens=False)
max_tokens = params.max_tokens or self.max_tokens

all_tokens.append(tokens)
all_max_tokens.append(max_tokens)
all_prompt_tokens.append(len(tokens))
request_info.append((cmd_id, task_id))

# Single batched insert for efficient prefill
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)

# Track all inserted requests
for i, uid in enumerate(uids):
cmd_id, task_id = request_info[i]
self.active_requests[uid] = ActiveRequest(
command_id=cmd_id,
task_id=task_id,
uid=uid,
detokenizer=self.tokenizer.detokenizer,
prompt_tokens=all_prompt_tokens[i],
)
logger.info(f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}")

self._pending_inserts.clear()
return uids

def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
responses = self.batch_gen.next()
if not responses:
return []

results: list[BatchedGenerationResponse] = []

for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue

req.tokens_generated += 1

# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment

stats: GenerationStats | None = None
finish_reason: FinishReason | None = None

raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment

elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0

stats = GenerationStats(
prompt_tps=0.0,  # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)

if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"

# Track completion but don't remove yet - wait for sync_completions()
self._pending_completions.append(uid)
logger.info(f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}")

results.append(BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(text=text, token=token, finish_reason=finish_reason, stats=stats),
))

# In non-distributed mode, clean up completions immediately
if not self.is_distributed:
self._remove_completed()

return results

def sync_completions(self) -> None:
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
if not self._pending_completions:
return

if self.is_distributed:
assert self.group is not None
# Broadcast completions from rank 0 to ensure all ranks are in sync
synced_uids = share_object(
self._pending_completions if self.rank == 0 else None,
self.rank,
self.group,
)
if synced_uids:
self._pending_completions = synced_uids

self._remove_completed()

def _remove_completed(self) -> None:
"""Remove completed requests from tracking."""
for uid in self._pending_completions:
if uid in self.active_requests:
del self.active_requests[uid]
self._pending_completions.clear()

@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)

@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)

@property
def active_count(self) -> int:
return len(self.active_requests)

@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)

@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)

@property
def has_pending_completions(self) -> bool:
return bool(self._pending_completions)
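Taken together, the intended driving pattern is: rank 0 queues, every rank then performs identical inserts, decodes under a shared time budget, and agrees on completions at the boundary. A sketch under the APIs above (model, tokenizer, group, cmd_id, task_id, params, and emit_chunk are placeholders, and TimeBudget comes from the time_budget module further down this diff):

engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=group)
if engine.rank == 0:
    engine.queue_request(command_id=cmd_id, task_id=task_id, task_params=params)

while engine.has_active_requests or engine.has_pending_inserts:
    engine.sync_and_insert_pending()      # every rank inserts the same prompts
    for _ in TimeBudget(budget=0.5, group=engine.group):
        for tagged in engine.step():      # one decode step across the whole batch
            emit_chunk(tagged)            # placeholder for ChunkGenerated emission
    engine.sync_completions()             # ranks agree on finished uids here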
src/exo/worker/engines/mlx/generator/distributed_sync.py (new file, 42 lines)
@@ -0,0 +1,42 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""

# pyright: reportAny=false

import pickle
from typing import TypeVar, cast

import mlx.core as mx

from exo.worker.runner.bootstrap import logger

T = TypeVar("T")


def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
logger.debug(f"share_object: rank={rank}, obj_type={type(obj).__name__ if obj else 'None'}")
if rank == 0:
if obj is None:
logger.debug("share_object: rank 0 broadcasting None (size=0)")
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
logger.debug("share_object: rank 0 broadcast None complete")
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
logger.debug(f"share_object: rank 0 broadcasting size={data.size}")
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
logger.debug("share_object: rank 0 broadcasting data")
mx.eval(mx.distributed.all_sum(data, group=group))
logger.debug("share_object: rank 0 broadcast complete")
return obj
else:
logger.debug("share_object: non-rank-0 waiting for size")
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
logger.debug(f"share_object: non-rank-0 received size={size}")
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
logger.debug("share_object: non-rank-0 waiting for data")
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
logger.debug("share_object: non-rank-0 received data")
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
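The all_sum here doubles as a broadcast: every rank except 0 contributes an all-zero array, so the elementwise sum each rank receives is exactly rank 0's payload. The same trick, stripped down to raw bytes (broadcast_bytes is an invented name for illustration, mirroring share_object above):

import mlx.core as mx

def broadcast_bytes(payload: bytes | None, rank: int, group: mx.distributed.Group) -> bytes:
    if rank == 0:
        assert payload is not None
        data = mx.array(list(payload), dtype=mx.uint8)
        mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))  # phase 1: size
    else:
        size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
        data = mx.zeros(size, dtype=mx.uint8)  # zero contribution preserves rank 0's bytes
    out = mx.distributed.all_sum(data, group=group)  # phase 2: data
    mx.eval(out)
    return bytes(out.tolist())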
src/exo/worker/engines/mlx/generator/time_budget.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.

Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""

import time
from typing import Iterator

import mlx.core as mx

from exo.worker.runner.bootstrap import logger

generation_stream = mx.new_stream(mx.default_device())


class TimeBudget(Iterator[None]):
"""Controls generation loop timing, syncing across ranks periodically.

In distributed mode, periodically syncs timing across all ranks to
dynamically adjust iteration count based on actual performance.

In non-distributed mode, simply runs for the time budget.

Usage:
for _ in TimeBudget(budget=0.5):
batch_engine.step()
# ... process responses ...
"""

def __init__(
self,
budget: float = 0.5,
iterations: int = 25,
sync_frequency: int = 10,
group: mx.distributed.Group | None = None,
):
"""Initialize TimeBudget.

Args:
budget: Time budget in seconds before yielding control
iterations: Initial number of iterations per budget period (distributed only)
sync_frequency: How often to sync timing across ranks (distributed only)
group: Distributed group, or None for non-distributed mode
"""
self._budget = budget
self._iterations = iterations
self._sync_frequency = sync_frequency
self._group = group
self._is_distributed = group is not None and group.size() > 1

# Runtime state
self._start: float = 0.0
self._current_iterations: int = 0
self._loops: int = 0
self._time_spent: float = 0.0

def __iter__(self) -> "TimeBudget":
self._start = time.perf_counter()
self._current_iterations = 0
return self

def __next__(self) -> None:
if not self._is_distributed:
# Non-distributed: just check time budget
if time.perf_counter() - self._start > self._budget:
raise StopIteration()
return None

# Distributed mode: iteration-based with periodic timing sync
self._current_iterations += 1
if self._current_iterations > self._iterations:
self._loops += 1
self._time_spent += time.perf_counter() - self._start

if self._loops % self._sync_frequency == 0:
# Sync timing across all ranks
assert self._group is not None
with mx.stream(generation_stream):
time_array = mx.array([self._time_spent], dtype=mx.float32)
total_time = mx.distributed.all_sum(time_array, group=self._group)
mx.eval(total_time)
loop_time = float(total_time.item())

avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)

if avg_loop_time > 0:
factor = self._budget / avg_loop_time
self._iterations = max(round(self._iterations * factor), 1)
logger.debug(f"TimeBudget adjusted iterations to {self._iterations}")

self._loops = 0
self._time_spent = 0.0

raise StopIteration()

return None

@property
def iterations(self) -> int:
"""Current iterations per budget period."""
return self._iterations
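For a concrete feel of the rescaling step above: with budget=0.5, sync_frequency=10, and two ranks, a summed elapsed time of 12 s means each budget period actually took 0.6 s on average, so the iteration count shrinks proportionally (numbers invented for illustration):

budget, sync_frequency, world_size = 0.5, 10, 2
iterations = 25
total_time = 12.0  # all_sum of per-rank elapsed time over 10 budget periods
avg_loop_time = total_time / (world_size * sync_frequency)       # 0.6 s per period
iterations = max(round(iterations * budget / avg_loop_time), 1)  # 25 * 0.833 -> 21
print(iterations)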
@@ -144,26 +144,20 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)

case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)

with open(coordination_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)

jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
@@ -16,7 +16,8 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
@@ -24,6 +25,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -32,7 +34,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -43,14 +45,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable


class Worker:
@@ -84,7 +86,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None

self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -96,13 +98,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")

info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)

async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)

# END CLEANUP

async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)

tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +142,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()

async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)

async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -146,6 +161,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -236,7 +252,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)

def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()

def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -253,24 +270,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)

case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)

async def _nack_request(self, since_idx: int) -> None:
@@ -319,6 +336,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner

@@ -381,6 +399,7 @@ class Worker:
last_progress_time = current_time()

self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)

async def _forward_events(self) -> None:
@@ -403,9 +422,7 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(
self.state.topology, self.state.node_profiles, self.node_id
)
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
@@ -413,31 +430,26 @@ class Worker:
f"Loopback connection should not happen: {ip=} for {nid=}"
)

edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))

for nid, conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))

await anyio.sleep(10)
@@ -277,12 +277,14 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
continue

# TODO: Check ordering aligns with MLX distributed's expectations.

if isinstance(runner.status, RunnerReady) and all(
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
@@ -19,7 +19,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
@@ -1,6 +1,8 @@
import gc
import time

import mlx.core as mx
from anyio import WouldBlock

from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -11,6 +13,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.common import CommandId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -18,12 +21,11 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -39,7 +41,10 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.batch_engine import BatchGenerationEngine
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.generator.generate import warmup_inference
from exo.worker.engines.mlx.generator.time_budget import TimeBudget
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
@@ -69,142 +74,258 @@ def main(
model = None
tokenizer = None
group = None
batch_engine: BatchGenerationEngine | None = None
pending_shutdown: Shutdown | None = None

current_status: RunnerStatus = RunnerIdle()

def send_status(status: RunnerStatus) -> None:
event_sender.send(RunnerStatusUpdated(runner_id=runner_id, runner_status=status))

logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
send_status(current_status)

logger.info("runner connected")
current_status = RunnerConnected()
def handle_task(task: Task, is_deferred: bool = False) -> bool:
nonlocal current_status, model, tokenizer, group, batch_engine, pending_shutdown

# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# For Shutdown, check if we need to defer BEFORE sending Running/Acknowledged
if isinstance(task, Shutdown) and not is_deferred:
if batch_engine is not None and (batch_engine.has_active_requests or batch_engine.has_pending_inserts):
logger.info("deferring shutdown until active requests complete")
pending_shutdown = task
return True

model, tokenizer = load_mlx_items(bound_instance, group)
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running))
event_sender.send(TaskAcknowledged(task_id=task.task_id))

current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
send_status(current_status)
group = initialize_mlx(bound_instance)

logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
logger.info("runner connected")
current_status = RunnerConnected()
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
send_status(current_status)

model, tokenizer = load_mlx_items(bound_instance, group)

current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model is not None
assert tokenizer is not None
current_status = RunnerWarmingUp()
logger.info("runner warming up")
send_status(current_status)

logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(model=model, tokenizer=tokenizer)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(f"runner initialized in {time.time() - setup_start_time} seconds")

batch_engine = BatchGenerationEngine(model=model, tokenizer=tokenizer, group=group)

current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
send_status(current_status)

case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, (RunnerReady, RunnerRunning)):
assert batch_engine is not None

# In distributed mode, only rank 0 should queue requests
# Other ranks should skip - they'll participate in sync_and_insert_pending()
is_distributed_mode = group is not None and isinstance(group.size(), int) and group.size() > 1
if is_distributed_mode and shard_metadata.device_rank != 0:
logger.debug(f"Rank {shard_metadata.device_rank} skipping ChatCompletionTask (only rank 0 queues)")
return True

if task_params.messages and task_params.messages[0].content is not None:
_check_for_debug_prompts(task_params.messages[0].content)

# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
# Queue the request - actual insertion happens in sync_and_insert_pending()
batch_engine.queue_request(command_id=command_id, task_id=task.task_id, task_params=task_params)

current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
# Status will be updated after actual insertion in the main loop
|
||||
# For now, set to RunnerRunning to indicate we're processing
|
||||
current_status = RunnerRunning(active_requests=batch_engine.active_count + batch_engine.pending_insert_count)
|
||||
send_status(current_status)
|
||||
|
||||
gc.collect()
|
||||
break
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
send_status(current_status)
|
||||
event_sender.send(TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete))
|
||||
current_status = RunnerShutdown()
|
||||
send_status(current_status)
|
||||
return False
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
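The `send_status` helper that the new code calls throughout is not part of this diff. Judging by the `event_sender.send(RunnerStatusUpdated(...))` calls it replaces, it is presumably a small closure along these lines (a sketch, not the repository's actual definition):

```python
# Hypothetical sketch of send_status, inferred from the calls it replaces.
# The real helper is defined elsewhere in runner.py (it closes over
# event_sender and runner_id) and is not shown in this diff.
def send_status(status: RunnerStatus) -> None:
    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=status)
    )
```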
    with task_receiver as tasks:
        running = True
        is_rank_0 = shard_metadata.device_rank == 0
        is_distributed = group is not None and group.size() > 1

        # Accumulated responses to send at budget boundary
        pending_responses: list[tuple[CommandId, TaskId, GenerationResponse]] = []

        while running:
            if is_distributed:
                assert group is not None
                assert batch_engine is not None

                # Distributed mode: sync at time budget boundaries, not per-token
                # Step 1: Only rank 0 checks for new tasks
                should_shutdown = False
                if is_rank_0:
                    while True:
                        try:
                            task = tasks.receive_nowait()
                            task_result = handle_task(task)
                            if not task_result:
                                should_shutdown = True
                                break
                        except WouldBlock:
                            break

                    # Check for deferred shutdown when no more active requests
                    if pending_shutdown is not None and not batch_engine.has_active_requests and not batch_engine.has_pending_inserts:
                        should_shutdown = True

                # Step 2: Sync shutdown flag across all ranks
                synced_shutdown = share_object(should_shutdown if is_rank_0 else None, shard_metadata.device_rank, group)
                if synced_shutdown:
                    running = False
                    if is_rank_0 and pending_shutdown is not None:
                        handle_task(pending_shutdown, is_deferred=True)
                    continue

                # Step 3: Sync and insert any pending requests
                # All ranks must participate in sync_and_insert_pending (collective op)
                # If rank 0 has no pending, it broadcasts empty list and all return early
                inserted = batch_engine.sync_and_insert_pending()
                if is_rank_0 and inserted:
                    current_status = RunnerRunning(active_requests=batch_engine.active_count)
                    send_status(current_status)

                # Step 4: Run generation for time budget (no per-token sync)
                if batch_engine.has_active_requests:
                    time_budget = TimeBudget(budget=0.5, group=group)
                    for _ in time_budget:
                        if not batch_engine.has_active_requests:
                            break
                        for resp in batch_engine.step():
                            # Accumulate responses to send at budget boundary
                            pending_responses.append((resp.command_id, resp.task_id, resp.response))

                # Step 5: Sync completions at budget boundary
                batch_engine.sync_completions()

                # Step 6: Send accumulated responses (only rank 0)
                if is_rank_0:
                    for command_id, task_id, response in pending_responses:
                        event_sender.send(ChunkGenerated(
                            command_id=command_id,
                            chunk=TokenChunk(
                                idx=response.token,
                                model=shard_metadata.model_meta.model_id,
                                text=response.text,
                                token_id=response.token,
                                finish_reason=response.finish_reason,
                                stats=response.stats,
                            ),
                        ))
                        if response.finish_reason is not None:
                            event_sender.send(TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Complete))

                    if batch_engine.has_active_requests:
                        current_status = RunnerRunning(active_requests=batch_engine.active_count)
                    else:
                        current_status = RunnerReady()
                    send_status(current_status)

                pending_responses.clear()

                # Short sleep if nothing to do (will check for new tasks on next iteration)
                if not batch_engine.has_active_requests and not batch_engine.has_pending_inserts:
                    time.sleep(0.001)

            else:
                # Non-distributed mode: original logic with queue + insert
                while True:
                    try:
                        task = tasks.receive_nowait()
                        running = handle_task(task)
                        if not running:
                            break
                    except WouldBlock:
                        break

                if not running:
                    break

                # Insert any queued requests (non-distributed just inserts directly)
                # Status was already sent in handle_task when queueing
                if batch_engine is not None and batch_engine.has_pending_inserts:
                    batch_engine.sync_and_insert_pending()

                if batch_engine is not None and batch_engine.has_active_requests:
                    for resp in batch_engine.step():
                        if shard_metadata.device_rank == 0:
                            event_sender.send(ChunkGenerated(
                                command_id=resp.command_id,
                                chunk=TokenChunk(
                                    idx=resp.response.token,
                                    model=shard_metadata.model_meta.model_id,
                                    text=resp.response.text,
                                    token_id=resp.response.token,
                                    finish_reason=resp.response.finish_reason,
                                    stats=resp.response.stats,
                                ),
                            ))
                            if resp.response.finish_reason is not None:
                                event_sender.send(TaskStatusUpdated(task_id=resp.task_id, task_status=TaskStatus.Complete))

                    if batch_engine.has_active_requests:
                        current_status = RunnerRunning(active_requests=batch_engine.active_count)
                    else:
                        current_status = RunnerReady()
                    send_status(current_status)

                    # Process deferred shutdown after all requests complete
                    if pending_shutdown is not None and not batch_engine.has_active_requests and not batch_engine.has_pending_inserts:
                        running = handle_task(pending_shutdown, is_deferred=True)
                else:
                    task = tasks.receive()
                    running = handle_task(task)

    # Cleanup
    del model, tokenizer, group, batch_engine
    mx.clear_cache()
    gc.collect()

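The `TimeBudget` class that paces Step 4 is not included in this diff. Below is a minimal sketch of the iteration contract it appears to satisfy, assuming the budget is plain wall-clock time and `group` is only kept so an implementation could align the deadline across ranks:

```python
import time


class TimeBudget:
    """Sketch only - not the repository's class. Yields until roughly
    `budget` seconds elapse, so each rank runs generation steps for the
    same wall-clock window between collective syncs."""

    def __init__(self, budget: float, group: object | None = None) -> None:
        self.budget = budget
        self.group = group  # the real class may use this to agree on a deadline

    def __iter__(self):
        deadline = time.monotonic() + self.budget
        while time.monotonic() < deadline:
            yield None
```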
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

@@ -105,7 +105,7 @@ class RunnerSupervisor:
             return

         # This is overkill but it's not technically bad, just unnecessary.
-        logger.warning("Runner process didn't shutdown succesfully, terminating")
+        logger.warning("Runner process didn't shutdown successfully, terminating")
         self.runner_process.terminate()
         await to_thread.run_sync(self.runner_process.join, 5)
         if not self.runner_process.is_alive():

@@ -128,9 +128,11 @@ class RunnerSupervisor:

     async def start_task(self, task: Task):
         if task.task_id in self.completed:
-            logger.info(
-                f"Skipping invalid task {task} as it has already been completed"
-            )
+            logger.info(f"Skipping task {task.task_id} - already completed")
+            return
+        if task.task_id in self.pending:
+            logger.info(f"Skipping task {task.task_id} - already pending")
             return
         logger.info(f"Starting task {task}")
         event = anyio.Event()
         self.pending[task.task_id] = event

@@ -149,13 +151,17 @@ class RunnerSupervisor:
             if isinstance(event, RunnerStatusUpdated):
                 self.status = event.runner_status
             if isinstance(event, TaskAcknowledged):
-                self.pending.pop(event.task_id).set()
+                # Just set the event to unblock start_task, but keep in pending
+                # to prevent duplicate forwarding until completion
+                if event.task_id in self.pending:
+                    self.pending[event.task_id].set()
                 continue
-            if (
-                isinstance(event, TaskStatusUpdated)
-                and event.task_status == TaskStatus.Complete
+            if isinstance(event, TaskStatusUpdated) and event.task_status in (
+                TaskStatus.Complete,
+                TaskStatus.TimedOut,
+                TaskStatus.Failed,
             ):
-                # If a task has just been completed, we should be working on it.
+                # If a task has just finished, we should be working on it.
                 assert isinstance(
                     self.status,
                     (

@@ -166,6 +172,8 @@ class RunnerSupervisor:
                         RunnerShuttingDown,
                     ),
                 )
+                # Now safe to remove from pending and add to completed
+                self.pending.pop(event.task_id, None)
                 self.completed.add(event.task_id)
                 await self._event_sender.send(event)
         except (ClosedResourceError, BrokenResourceError) as e:

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
     bound_instance: BoundInstance
     status: RunnerStatus
     completed: set[TaskId] = field(default_factory=set)
+    pending: dict[TaskId, object] = field(default_factory=dict)


 class OtherTask(BaseTask):

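Taken together, these hunks change the supervisor's dedup rule: a task now stays in `pending` after `TaskAcknowledged` (its event is merely set) and only moves to `completed` on a terminal status. A condensed sketch of the resulting forwarding decision (illustrative only, not code from this commit):

```python
# Illustrative only: a task is forwarded iff it is neither in flight
# (still in pending) nor already finished (in completed).
def should_forward(task_id: str, pending: dict[str, object], completed: set[str]) -> bool:
    return task_id not in pending and task_id not in completed
```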
@@ -0,0 +1,315 @@
"""
Tests for continuous batching behavior in the runner.

These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
"""

# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false

from typing import Any
from unittest.mock import MagicMock

import pytest

import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
    Event,
    RunnerStatusUpdated,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import (
    ChatCompletion,
    ChatCompletionTaskParams,
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
    BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance


class FakeBatchEngineWithTokens:
    """
    Fake batch engine that generates a specified number of tokens per request.

    This simulates realistic batch generation behavior where:
    - Requests are queued on insert
    - Each step() call generates one token for all active requests
    - Requests complete when they've generated all their tokens
    """

    def __init__(self, *_args: Any, **_kwargs: Any):
        self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
        self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
        self._uid_counter = 0
        self._tokens_per_request = 3  # Default: generate 3 tokens before completing
        self.rank = 0  # Fake rank for testing

    def queue_request(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: ChatCompletionTaskParams,
    ) -> None:
        """Queue a request for insertion."""
        self._pending_inserts.append((command_id, task_id, task_params))

    def sync_and_insert_pending(self) -> list[int]:
        """Insert all pending requests."""
        uids: list[int] = []
        for command_id, task_id, task_params in self._pending_inserts:
            uid = self._do_insert(command_id, task_id, task_params)
            uids.append(uid)
        self._pending_inserts.clear()
        return uids

    @property
    def has_pending_inserts(self) -> bool:
        return len(self._pending_inserts) > 0

    def _do_insert(
        self,
        command_id: CommandId,
        task_id: TaskId,
        task_params: ChatCompletionTaskParams | None,
    ) -> int:
        uid = self._uid_counter
        self._uid_counter += 1
        # Track: (command_id, task_id, tokens_generated, max_tokens)
        max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
        self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
        return uid

    def step(self) -> list[BatchedGenerationResponse]:
        results: list[BatchedGenerationResponse] = []
        uids_to_remove: list[int] = []

        for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
            self._active_requests.items()
        ):
            tokens_gen += 1
            finish_reason = "stop" if tokens_gen >= max_tokens else None
            text = f"token{tokens_gen}"

            if finish_reason:
                uids_to_remove.append(uid)
            else:
                self._active_requests[uid] = (
                    command_id,
                    task_id,
                    tokens_gen,
                    max_tokens,
                )

            results.append(
                BatchedGenerationResponse(
                    command_id=command_id,
                    task_id=task_id,
                    response=GenerationResponse(
                        token=tokens_gen,
                        text=text,
                        finish_reason=finish_reason,
                    ),
                )
            )

        for uid in uids_to_remove:
            del self._active_requests[uid]

        return results

    @property
    def has_active_requests(self) -> bool:
        return len(self._active_requests) > 0

    @property
    def active_count(self) -> int:
        return len(self._active_requests)

    @property
    def pending_insert_count(self) -> int:
        return len(self._pending_inserts)


def make_nothin[T, U, V](res: T):
    def nothin(*_1: U, **_2: V) -> T:
        return res

    return nothin


@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
    """Patch MLX dependencies and use FakeBatchEngineWithTokens."""
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
    monkeypatch.setattr(
        mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
    )
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)


def _run_with_tasks(tasks: list[Task]) -> list[Event]:
    """
    Run tasks through the runner, adding shutdown at the end.

    Tasks are sent in order, with shutdown sent last.
    The batch engine processes between task handling.
    """
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        runner_id=RUNNER_1_ID,
        node_id=NodeId(NODE_A),
    )

    task_sender, task_receiver = mp_channel[Task]()
    event_sender, event_receiver = mp_channel[Event]()

    shutdown_task = Shutdown(
        task_id=TaskId("shutdown"),
        instance_id=INSTANCE_1_ID,
        runner_id=RUNNER_1_ID,
    )

    with task_sender, event_receiver:
        # Send all tasks including shutdown
        for t in tasks:
            task_sender.send(t)
        task_sender.send(shutdown_task)

        # Disable cleanup methods to prevent issues
        event_sender.close = lambda: None
        event_sender.join = lambda: None
        task_receiver.close = lambda: None
        task_receiver.join = lambda: None

        mlx_runner.main(bound_instance, event_sender, task_receiver)

        return event_receiver.collect()


INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)


def make_chat_task(
    task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
    return ChatCompletion(
        task_id=TaskId(task_id),
        command_id=CommandId(command_id),
        task_params=ChatCompletionTaskParams(
            model=str(MODEL_A_ID),
            messages=[ChatCompletionMessage(role="user", content="hello")],
            stream=True,
            max_tokens=max_tokens,
        ),
        instance_id=INSTANCE_1_ID,
    )


def test_single_request_generates_tokens(patch_batch_engine: None):
    """
    Verify a single request generates the expected tokens through the batch path.

    Note: With the current non-blocking design, shutdown is processed before
    batch steps run when all tasks are queued together. This test verifies
    the runner status reflects active requests.
    """
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find RunnerRunning status events - this shows the request was inserted
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
    assert running_events[0].runner_status.active_requests == 1


def test_runner_status_reflects_active_requests(patch_batch_engine: None):
    """Verify RunnerRunning status includes active_requests count."""
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find RunnerRunning status events
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    assert len(running_events) > 0, "Expected at least one RunnerRunning event"
    assert running_events[0].runner_status.active_requests == 1


def test_chat_task_acknowledged(patch_batch_engine: None):
    """Verify chat completion task is acknowledged with proper status updates."""
    chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])

    # Find the chat task status events
    chat_running = [
        e
        for e in events
        if isinstance(e, TaskStatusUpdated)
        and e.task_id == TaskId("chat1")
        and e.task_status == TaskStatus.Running
    ]

    assert len(chat_running) == 1, "Expected exactly one chat task Running status"


def test_multiple_requests_tracked(patch_batch_engine: None):
    """Verify multiple concurrent requests are tracked in active_requests."""
    chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
    chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
    events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])

    # Find RunnerRunning status events
    running_events = [
        e
        for e in events
        if isinstance(e, RunnerStatusUpdated)
        and isinstance(e.runner_status, RunnerRunning)
    ]

    # Should have at least 2 RunnerRunning events (one per request inserted)
    assert len(running_events) >= 2, f"Expected at least 2 RunnerRunning events, got {len(running_events)}"

    # First should have 1 active request, second should have 2
    assert running_events[0].runner_status.active_requests == 1
    assert running_events[1].runner_status.active_requests == 2
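For intuition, the fake engine above can also be driven by hand. A short sketch using only names defined in this test module (the printed sequence assumes the default `token{n}` texts):

```python
engine = FakeBatchEngineWithTokens()
params = ChatCompletionTaskParams(
    model=str(MODEL_A_ID),
    messages=[ChatCompletionMessage(role="user", content="hello")],
    stream=True,
    max_tokens=2,
)
engine.queue_request(CommandId("cmd1"), TaskId("t1"), params)
engine.sync_and_insert_pending()
while engine.has_active_requests:
    for resp in engine.step():
        print(resp.response.text, resp.response.finish_reason)
# expected output: "token1 None" then "token2 stop"
```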
@@ -1,11 +1,16 @@
 # Check tasks are complete before runner is ever ready.

+# pyright: reportAny=false
+
 from collections.abc import Iterable
-from typing import Callable
+from typing import Any, Callable
+from unittest.mock import MagicMock

 import pytest

 import exo.worker.runner.runner as mlx_runner
 from exo.shared.types.api import ChatCompletionMessage
+from exo.shared.types.common import CommandId
 from exo.shared.types.chunks import TokenChunk
 from exo.shared.types.events import (
     ChunkGenerated,
@@ -22,6 +27,7 @@ from exo.shared.types.tasks import (
     Shutdown,
     StartWarmup,
     Task,
+    TaskId,
     TaskStatus,
 )
 from exo.shared.types.worker.runner_response import GenerationResponse
@@ -38,6 +44,9 @@ from exo.shared.types.worker.runners import (
     RunnerWarmingUp,
 )
 from exo.utils.channels import mp_channel
+from exo.worker.engines.mlx.generator.batch_engine import (
+    BatchedGenerationResponse,
+)

 from ...constants import (
     CHAT_COMPLETION_TASK_ID,
@@ -107,18 +116,85 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
         assert test_event == true_event, f"{test_event} != {true_event}"


+class FakeBatchEngine:
+    """
+    Fake batch engine for testing.
+
+    Queues requests on insert, returns one token per step.
+    The runner's non-blocking loop drains all tasks before running batch steps,
+    so this engine queues requests and has_active_requests returns True only
+    after at least one request has been inserted.
+    """
+
+    def __init__(self, *_args: Any, **_kwargs: Any):
+        self._active_requests: dict[int, tuple[CommandId, TaskId]] = {}
+        self._pending_inserts: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]] = []
+        self._uid_counter = 0
+        self.rank = 0  # Fake rank for testing
+
+    def queue_request(
+        self,
+        command_id: CommandId,
+        task_id: TaskId,
+        task_params: ChatCompletionTaskParams,
+    ) -> None:
+        """Queue a request for insertion."""
+        self._pending_inserts.append((command_id, task_id, task_params))
+
+    def sync_and_insert_pending(self) -> list[int]:
+        """Insert all pending requests."""
+        uids: list[int] = []
+        for command_id, task_id, _task_params in self._pending_inserts:
+            uid = self._uid_counter
+            self._uid_counter += 1
+            self._active_requests[uid] = (command_id, task_id)
+            uids.append(uid)
+        self._pending_inserts.clear()
+        return uids
+
+    @property
+    def has_pending_inserts(self) -> bool:
+        return len(self._pending_inserts) > 0
+
+    def step(self) -> list[BatchedGenerationResponse]:
+        results: list[BatchedGenerationResponse] = []
+        # Process all active requests - return one token and complete
+        for uid, (command_id, task_id) in list(self._active_requests.items()):
+            results.append(
+                BatchedGenerationResponse(
+                    command_id=command_id,
+                    task_id=task_id,
+                    response=GenerationResponse(
+                        token=0,
+                        text="hi",
+                        finish_reason="stop",
+                    ),
+                )
+            )
+            del self._active_requests[uid]
+        return results
+
+    @property
+    def has_active_requests(self) -> bool:
+        return len(self._active_requests) > 0
+
+    @property
+    def active_count(self) -> int:
+        return len(self._active_requests)
+
+    @property
+    def pending_insert_count(self) -> int:
+        return len(self._pending_inserts)
+
+
 @pytest.fixture
 def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
-    # initialize_mlx returns a "group" equal to 1
-    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
+    # initialize_mlx returns a fake "group" (non-None for state machine)
+    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MagicMock()))
+    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock())))
     monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
     monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

     def fake_generate(*_1: object, **_2: object):
         yield GenerationResponse(token=0, text="hi", finish_reason="stop")

     monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+    monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngine)


 def _run(tasks: Iterable[Task]):
@@ -148,7 +224,8 @@ def _run(tasks: Iterable[Task]):
     return event_receiver.collect()


-def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
+def test_chat_completion_generates_and_completes(patch_out_mlx: pytest.MonkeyPatch):
+    """Verify chat completion generates tokens, completes, and runner returns to Ready."""
     events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])

     expected_chunk = ChunkGenerated(
@@ -191,7 +268,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
             task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
         ),
         TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
-        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
+        RunnerStatusUpdated(
+            runner_id=RUNNER_1_ID, runner_status=RunnerRunning(active_requests=1)
+        ),
         expected_chunk,
         TaskStatusUpdated(
             task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -206,7 +285,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
         TaskStatusUpdated(
             task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
         ),
-        # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
         RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
     ],
 )
src/exo/worker/utils/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics

__all__ = [
    "start_polling_node_metrics",
    "start_polling_memory_metrics",
]
src/exo/worker/utils/macmon.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import platform
import shutil
from subprocess import CalledProcessError
from typing import cast

from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError


class MacMonError(Exception):
    """Exception raised for errors in the MacMon functions."""


def _get_binary_path() -> str:
    """
    Get the path to the macmon binary.

    Raises:
        MacMonError: If the binary doesn't exist or can't be made executable.
    """
    # Check for macOS with ARM chip
    system = platform.system().lower()
    machine = platform.machine().lower()

    if system != "darwin" or not (
        "arm" in machine or "m1" in machine or "m2" in machine
    ):
        raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")

    path = shutil.which("macmon")

    if path is None:
        raise MacMonError("MacMon not found in PATH")

    return path


class TempMetrics(BaseModel):
    """Temperature-related metrics returned by macmon."""

    cpu_temp_avg: float
    gpu_temp_avg: float

    model_config = ConfigDict(extra="ignore")


class Metrics(BaseModel):
    """Complete set of metrics returned by macmon.

    Unknown fields are ignored for forward-compatibility.
    """

    all_power: float
    ane_power: float
    cpu_power: float
    ecpu_usage: tuple[int, float]
    gpu_power: float
    gpu_ram_power: float
    gpu_usage: tuple[int, float]
    pcpu_usage: tuple[int, float]
    ram_power: float
    sys_power: float
    temp: TempMetrics
    timestamp: str

    model_config = ConfigDict(extra="ignore")


async def get_metrics_async() -> Metrics:
    """
    Asynchronously run the macmon binary and return the parsed metrics.

    Returns:
        A Metrics instance containing system metrics.

    Raises:
        MacMonError: If there's an error running the binary or parsing its output.
    """
    path = _get_binary_path()

    try:
        # TODO: Keep Macmon running in the background?
        result = await run_process([path, "pipe", "-s", "1"])

        return Metrics.model_validate_json(result.stdout.decode().strip())

    except ValidationError as e:
        raise MacMonError(f"Error parsing JSON output: {e}") from e
    except CalledProcessError as e:
        stderr_msg = "no stderr"
        stderr_output = cast(bytes | str | None, e.stderr)
        if stderr_output is not None:
            stderr_msg = (
                stderr_output.decode()
                if isinstance(stderr_output, bytes)
                else str(stderr_output)
            )
        raise MacMonError(
            f"MacMon failed with return code {e.returncode}: {stderr_msg}"
        ) from e
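A minimal usage sketch for the module above (assumes an Apple Silicon Mac with the `macmon` binary on PATH):

```python
import anyio

from exo.worker.utils.macmon import MacMonError, get_metrics_async


async def main() -> None:
    try:
        metrics = await get_metrics_async()
        print(f"GPU: {metrics.gpu_power:.2f} W, system: {metrics.sys_power:.2f} W")
    except MacMonError as e:
        print(f"macmon unavailable: {e}")


anyio.run(main)
```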
@@ -1,12 +1,10 @@
 import http.client
-from collections.abc import Mapping

 from anyio import create_task_group, to_thread
 from loguru import logger

 from exo.shared.topology import Topology
 from exo.shared.types.common import NodeId
-from exo.shared.types.profiling import NodePerformanceProfile


 async def check_reachability(
@@ -60,20 +58,19 @@ async def check_reachability(


 async def check_reachable(
-    topology: Topology,
-    profiles: Mapping[NodeId, NodePerformanceProfile],
-    self_node_id: NodeId,
+    topology: Topology, self_node_id: NodeId
 ) -> dict[NodeId, set[str]]:
     """Check which nodes are reachable and return their IPs."""
     reachable: dict[NodeId, set[str]] = {}
     async with create_task_group() as tg:
-        for node_id in topology.list_nodes():
-            if node_id not in profiles:
+        for node in topology.list_nodes():
+            if not node.node_profile:
                 continue
-            for iface in profiles[node_id].network_interfaces:
+            for iface in node.node_profile.network_interfaces:
                 tg.start_soon(
                     check_reachability,
                     iface.ip_address,
-                    node_id,
+                    node.node_id,
                     self_node_id,
                     reachable,
                 )
src/exo/worker/utils/profile.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine

import anyio
from loguru import logger

from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)

from .macmon import (
    MacMonError,
    Metrics,
)
from .macmon import (
    get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
    get_friendly_name,
    get_model_and_chip,
    get_network_interfaces,
)


async def get_metrics_async() -> Metrics | None:
    """Return detailed Metrics on macOS; None on other platforms."""

    if platform.system().lower() == "darwin":
        return await macmon_get_metrics_async()
    return None


def get_memory_profile() -> MemoryPerformanceProfile:
    """Construct a MemoryPerformanceProfile using psutil."""
    override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
    override_memory: int | None = (
        Memory.from_mb(int(override_memory_env)).in_bytes
        if override_memory_env
        else None
    )

    return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)


async def start_polling_memory_metrics(
    callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
    *,
    poll_interval_s: float = 0.5,
) -> None:
    """Continuously poll and emit memory-only metrics at a faster cadence.

    Parameters
    - callback: coroutine called with a fresh MemoryPerformanceProfile each tick
    - poll_interval_s: interval between polls
    """
    while True:
        try:
            mem = get_memory_profile()
            await callback(mem)
        except MacMonError as e:
            logger.opt(exception=e).error("Memory Monitor encountered error")
        finally:
            await anyio.sleep(poll_interval_s)


async def start_polling_node_metrics(
    callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
    poll_interval_s = 1.0
    while True:
        try:
            metrics = await get_metrics_async()
            if metrics is None:
                return

            network_interfaces = get_network_interfaces()
            # these awaits could be joined but realistically they should be cached
            model_id, chip_id = await get_model_and_chip()
            friendly_name = await get_friendly_name()

            # do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
            memory_profile = get_memory_profile()

            await callback(
                NodePerformanceProfile(
                    model_id=model_id,
                    chip_id=chip_id,
                    friendly_name=friendly_name,
                    network_interfaces=network_interfaces,
                    memory=memory_profile,
                    system=SystemPerformanceProfile(
                        gpu_usage=metrics.gpu_usage[1],
                        temp=metrics.temp.gpu_temp_avg,
                        sys_power=metrics.sys_power,
                        pcpu_usage=metrics.pcpu_usage[1],
                        ecpu_usage=metrics.ecpu_usage[1],
                        ane_power=metrics.ane_power,
                    ),
                )
            )

        except asyncio.TimeoutError:
            logger.warning(
                "[resource_monitor] Operation timed out after 30s, skipping this cycle."
            )
        except MacMonError as e:
            logger.opt(exception=e).error("Resource Monitor encountered error")
            return
        finally:
            await anyio.sleep(poll_interval_s)
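A sketch of wiring the memory poller into a caller (the 2-second cancel scope is only there to end the demo; in exo the loop runs for the life of the worker):

```python
import anyio

from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.worker.utils import start_polling_memory_metrics


async def on_memory(profile: MemoryPerformanceProfile) -> None:
    print(profile)


async def main() -> None:
    with anyio.move_on_after(2):  # demo only: cancel the infinite poll loop
        await start_polling_memory_metrics(on_memory, poll_interval_s=0.5)


anyio.run(main)
```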
src/exo/worker/utils/tests/test_macmon.py (new file, 77 lines)
@@ -0,0 +1,77 @@
"""Tests for macmon error handling.

These tests verify that MacMon errors are handled gracefully without
crashing the application or spamming logs.
"""

import platform
from subprocess import CalledProcessError
from unittest.mock import AsyncMock, patch

import pytest

from exo.worker.utils.macmon import MacMonError, get_metrics_async


@pytest.mark.skipif(
    platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
    reason="MacMon only supports macOS with Apple Silicon",
)
class TestMacMonErrorHandling:
    """Test MacMon error handling."""

    async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
        """CalledProcessError should be wrapped as MacMonError."""
        mock_error = CalledProcessError(
            returncode=1,
            cmd=["macmon", "pipe", "-s", "1"],
            stderr=b"some error message",
        )

        with (
            patch(
                "exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
            ),
            patch(
                "exo.worker.utils.macmon.run_process", new_callable=AsyncMock
            ) as mock_run,
        ):
            mock_run.side_effect = mock_error

            with pytest.raises(MacMonError) as exc_info:
                await get_metrics_async()

            assert "MacMon failed with return code 1" in str(exc_info.value)
            assert "some error message" in str(exc_info.value)

    async def test_called_process_error_with_no_stderr(self) -> None:
        """CalledProcessError with no stderr should be handled gracefully."""
        mock_error = CalledProcessError(
            returncode=1,
            cmd=["macmon", "pipe", "-s", "1"],
            stderr=None,
        )

        with (
            patch(
                "exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
            ),
            patch(
                "exo.worker.utils.macmon.run_process", new_callable=AsyncMock
            ) as mock_run,
        ):
            mock_run.side_effect = mock_error

            with pytest.raises(MacMonError) as exc_info:
                await get_metrics_async()

            assert "MacMon failed with return code 1" in str(exc_info.value)
            assert "no stderr" in str(exc_info.value)

    async def test_macmon_not_found_raises_macmon_error(self) -> None:
        """When macmon is not found in PATH, MacMonError should be raised."""
        with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
            with pytest.raises(MacMonError) as exc_info:
                await get_metrics_async()

        assert "MacMon not found in PATH" in str(exc_info.value)
@@ -34,8 +34,7 @@ from exo.shared.types.worker.instances import (
 )
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
-from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
-from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
+from exo.utils.channels import MpReceiver, MpSender, mp_channel
 from exo.worker.download.impl_shard_downloader import (
     build_full_shard,
     exo_shard_downloader,
@@ -66,7 +65,6 @@ async def main():
     app = FastAPI()
     app.post("/ring")(ring_backend)
     app.post("/jaccl")(jaccl_backend)
-    app.post("/tb_detection")(tb_detection)
     shutdown = anyio.Event()
     await serve(
         app,  # type: ignore
@@ -78,15 +76,6 @@ async def main():
     shutdown.set()


-async def tb_detection():
-    send, recv = channel[GatheredInfo]()
-    ig = InfoGatherer(send)
-    with anyio.move_on_after(1):
-        await ig._monitor_system_profiler()  # pyright: ignore[reportPrivateUsage]
-    with recv:
-        return recv.collect()
-
-
 async def assert_downloads():
     sd = exo_shard_downloader()
     # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
@@ -220,16 +209,16 @@ async def jaccl_backend(test: Tests):
             break
     else:
         raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, jaccl_instance(test, iid), hn)
+    return await execute_test(test, jaccl_instance(test, iid, hn), hn)


-def jaccl_instance(test: Tests, iid: InstanceId):
+def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
     meta = MODEL_CARDS[test.model_id].metadata
     world_size = len(test.devs)

     return MlxJacclInstance(
         instance_id=iid,
-        jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
+        ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
         # rank 0 is always coordinator
         jaccl_coordinators={
             NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs