Compare commits

..

9 Commits

Author SHA1 Message Date
Jake Hillion
563e94ab6c feat(downloads): add ETA and speed tracking for torrent downloads
Use download_speed from rqbit (already exposed via PyO3 bindings) to
calculate and report real-time ETA for torrent downloads. Previously
these values were hardcoded to zero.
2026-01-12 15:12:56 +00:00
Jake Hillion
41c27832d9 feat(downloads): integrate torrent downloads with file filtering
- Add get_torrent_file_list() to parse torrent metadata and return file list
- Expose get_torrent_file_list via PyO3 bindings
- Update TorrentShardDownloader to download all torrent variants
- Apply same file filtering as HTTP downloads using resolve_allow_patterns()
- Fix RepoDownloadProgress constructor usage

When --enable-torrents is set, all embedded torrent variants (e.g., "small",
"large") are now downloaded with proper file filtering based on shard layers.
2026-01-12 15:12:56 +00:00
Jake Hillion
9c04350e55 feat(downloads): embed multiple torrent variants per model
- Add 52 torrent files (small and large variants) for all supported models
- Update get_embedded_torrents() to return Vec<(variant, data)> for all variants
- Remove old single-torrent API in favor of multi-variant API
- Update Python bindings to return list of (variant, bytes) tuples
2026-01-12 15:12:56 +00:00
Jake Hillion
1ca732da23 feat(downloads): switch to rqbit fork and implement fake tracker
- Update librqbit dependency to JakeHillion/rqbit fork (v9.0.0-beta.1)
- Update session.rs for new librqbit API:
  - SessionPersistenceConfig instead of bool
  - AddTorrentResponse enum handling
  - Updated TorrentStats field access
- Implement bencode module with AnnounceParams and BencodeValue encoding
- Implement fake tracker that generates bencoded announce responses
  from Exo topology data (PeerInfo, TopologyData, handle_announce)
- Fix PyO3 bindings for new librqbit types (remove allow_threads)
2026-01-12 15:12:56 +00:00
Jake Hillion
ba708c2ccd torrents! 2026-01-12 15:12:56 +00:00
Jake Hillion
c9d95859e8 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but provided no indication of how much disk space each model or the
total models on a node were using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 12:08:46 +00:00
Jake Hillion
3c5b7ea670 ci: add workflow_dispatch trigger to build-app
Build app is the most convenient way to get a DMG for testing, but
currently it's a bit limited. You have to push to test-app every time
which is far from ideal and requires a bit too much force pushing for my
liking.

Add the workflow_dispatch trigger. This adds a button in the actions UI
to trigger a workflow for a named branch, which means you can use your
normal dev branch instead of having to push to test-app. We'll leave
that behaviour there for now too, though it may change in future.

Filter on `"${{ github.event_name }}" == "workflow_dispatch"` and set
those to alpha as well. Will verify by pushing the first version from
`main` just in case. Unfortunately we do have to merge this before we
can test it.

Test plan:
- Looking really hard.
2026-01-12 12:14:21 +01:00
PG
b74a610537 Add a basic documentation to the api interface (#1122)
## Motivation

Adds basic api documentation

## Changes

- Add docs/api.md
- Modify README.md
2026-01-11 18:44:40 +00:00
Jake Hillion
18c4e49f91 nix: put treefmt in devshell
treefmt is a useful to be able to access directly for some formatters like
`jj fix`. Expose it in the devshell.

Test plan:
- Used with `jj fix` on a large branch. It worked.
2026-01-09 17:53:50 +01:00
116 changed files with 25898 additions and 1280 deletions

1605
Cargo.lock generated
View File

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
resolver = "3"
members = [
"rust/networking",
"rust/downloads",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
@@ -25,6 +26,7 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
downloads = { path = "rust/downloads" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

@@ -99,7 +99,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -110,19 +110,9 @@ interface RawTopologyConnection {
| string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem =
| RawTopologyConnection
| [
string,
string,
{ sinkMultiaddr?: { ip_address?: string; address?: string } }?,
];
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -223,18 +213,9 @@ function transformTopology(
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -283,7 +264,7 @@ function transformTopology(
}
}
nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
@@ -311,39 +292,14 @@ function transformTopology(
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr:
| { multiaddr?: string; address?: string; ip_address?: string }
| string
| undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as
| { sinkMultiaddr?: { ip_address?: string; address?: string } }
| undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === "string") {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -355,8 +311,8 @@ function transformTopology(
}
edges.push({
source: localNodeId,
target: sendBackNodeId,
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp,
});
}

View File

@@ -895,7 +895,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});

View File

@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

40
rust/downloads/Cargo.toml Normal file
View File

@@ -0,0 +1,40 @@
[package]
name = "downloads"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "downloads"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# macro dependencies
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-util = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
itertools = { workspace = true }
# tracing/logging
log = { workspace = true }
# BitTorrent library
librqbit = { git = "https://github.com/JakeHillion/rqbit", rev = "c4e2ecf81d03bd8acd96a0803d06a70b34d5da19" }
# Embed torrent files
include_dir = "0.7"
# Serialization
serde = { version = "1.0", features = ["derive"] }

View File

@@ -0,0 +1,162 @@
//! Bencode encoding for BitTorrent tracker responses
//!
//! Implements the subset of bencoding needed for tracker announce responses.
use std::collections::BTreeMap;
/// Parameters from a tracker announce request
#[derive(Debug, Clone)]
pub struct AnnounceParams {
/// 20-byte info hash of the torrent
pub info_hash: [u8; 20],
/// 20-byte peer ID of the client
pub peer_id: [u8; 20],
/// Port the client is listening on
pub port: u16,
/// Total bytes uploaded
pub uploaded: u64,
/// Total bytes downloaded
pub downloaded: u64,
/// Bytes remaining to download
pub left: u64,
/// Whether to return compact peer list (6 bytes per peer)
pub compact: bool,
/// Optional event (started, stopped, completed)
pub event: Option<AnnounceEvent>,
}
/// Announce event types
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnnounceEvent {
Started,
Stopped,
Completed,
}
/// A bencoded value
#[derive(Debug, Clone)]
pub enum BencodeValue {
Integer(i64),
Bytes(Vec<u8>),
List(Vec<BencodeValue>),
Dict(BTreeMap<Vec<u8>, BencodeValue>),
}
impl BencodeValue {
/// Create a string value from a &str
#[inline]
pub fn string(s: &str) -> Self {
Self::Bytes(s.as_bytes().to_vec())
}
/// Create an integer value
#[inline]
pub fn integer(i: i64) -> Self {
Self::Integer(i)
}
/// Create an empty list
#[inline]
pub fn list() -> Self {
Self::List(Vec::new())
}
/// Create an empty dict
#[inline]
pub fn dict() -> Self {
Self::Dict(BTreeMap::new())
}
/// Add an item to a list (builder pattern)
#[inline]
pub fn push(mut self, value: BencodeValue) -> Self {
if let Self::List(ref mut list) = self {
list.push(value);
}
self
}
/// Insert a key-value pair into a dict (builder pattern)
#[inline]
pub fn insert(mut self, key: &str, value: BencodeValue) -> Self {
if let Self::Dict(ref mut dict) = self {
dict.insert(key.as_bytes().to_vec(), value);
}
self
}
/// Encode to bencoded bytes
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::new();
self.encode_into(&mut buf);
buf
}
/// Encode into an existing buffer
pub fn encode_into(&self, buf: &mut Vec<u8>) {
match self {
Self::Integer(i) => {
buf.push(b'i');
buf.extend_from_slice(i.to_string().as_bytes());
buf.push(b'e');
}
Self::Bytes(bytes) => {
buf.extend_from_slice(bytes.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(bytes);
}
Self::List(list) => {
buf.push(b'l');
for item in list {
item.encode_into(buf);
}
buf.push(b'e');
}
Self::Dict(dict) => {
buf.push(b'd');
// BTreeMap keeps keys sorted
for (key, value) in dict {
buf.extend_from_slice(key.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(key);
value.encode_into(buf);
}
buf.push(b'e');
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encode_integer() {
assert_eq!(BencodeValue::integer(42).encode(), b"i42e");
assert_eq!(BencodeValue::integer(-1).encode(), b"i-1e");
assert_eq!(BencodeValue::integer(0).encode(), b"i0e");
}
#[test]
fn test_encode_string() {
assert_eq!(BencodeValue::string("spam").encode(), b"4:spam");
assert_eq!(BencodeValue::string("").encode(), b"0:");
}
#[test]
fn test_encode_list() {
let list = BencodeValue::list()
.push(BencodeValue::string("spam"))
.push(BencodeValue::integer(42));
assert_eq!(list.encode(), b"l4:spami42ee");
}
#[test]
fn test_encode_dict() {
let dict = BencodeValue::dict()
.insert("bar", BencodeValue::string("spam"))
.insert("foo", BencodeValue::integer(42));
assert_eq!(dict.encode(), b"d3:bar4:spam3:fooi42ee");
}
}

View File

@@ -0,0 +1,108 @@
//! Embedded torrent file access
//!
//! Provides access to .torrent files embedded in the binary at compile time.
//! Each model/revision can have multiple torrent variants (e.g., "small", "large").
use include_dir::{Dir, include_dir};
/// Embedded torrent files directory
static TORRENTS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/torrents");
/// Get all embedded torrent variants for a model_id and revision
///
/// # Arguments
/// * `model_id` - Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
/// * `revision` - Git commit hash
///
/// # Returns
/// Vec of (variant_name, torrent_data) tuples, e.g., [("small", data), ("large", data)]
/// Returns empty Vec if no torrents found for this model/revision.
#[inline]
pub fn get_embedded_torrents(model_id: &str, revision: &str) -> Vec<(String, Vec<u8>)> {
let dir_path = format!("{model_id}");
let Some(model_dir) = TORRENTS.get_dir(&dir_path) else {
return Vec::new();
};
let mut results = Vec::new();
let prefix = format!("{revision}.");
let suffix = ".torrent";
for file in model_dir.files() {
let Some(name) = file.path().file_name().and_then(|n| n.to_str()) else {
continue;
};
// Match files like "{revision}.small.torrent" or "{revision}.large.torrent"
if name.starts_with(&prefix) && name.ends_with(suffix) {
// Extract variant: "{revision}.{variant}.torrent" -> "{variant}"
let middle = &name[prefix.len()..name.len() - suffix.len()];
// Skip plain "{revision}.torrent" files (wrong format)
if middle.is_empty() {
continue;
}
results.push((middle.to_string(), file.contents().to_vec()));
}
}
// Sort by variant name for consistent ordering
results.sort_by(|a, b| a.0.cmp(&b.0));
results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_embedded_torrents() {
// Test with the Qwen3 torrent we have
let result = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
assert!(!result.is_empty(), "Expected to find embedded torrents");
// Should have both small and large variants
let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
assert!(
variants.contains(&"small"),
"Expected 'small' variant, got: {variants:?}"
);
assert!(
variants.contains(&"large"),
"Expected 'large' variant, got: {variants:?}"
);
// Verify data is not empty
for (variant, data) in &result {
assert!(!data.is_empty(), "Torrent data for '{variant}' should not be empty");
}
}
#[test]
fn test_missing_torrent() {
let result = get_embedded_torrents("nonexistent/model", "abc123");
assert!(result.is_empty(), "Expected empty Vec for missing torrent");
}
#[test]
fn test_variant_ordering() {
let result = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
if result.len() >= 2 {
// Verify alphabetical ordering
let variants: Vec<&str> = result.iter().map(|(v, _)| v.as_str()).collect();
let mut sorted = variants.clone();
sorted.sort();
assert_eq!(variants, sorted, "Variants should be sorted alphabetically");
}
}
}

22
rust/downloads/src/lib.rs Normal file
View File

@@ -0,0 +1,22 @@
//! BitTorrent-based download system for model shards using rqbit
//!
//! This crate provides:
//! - Torrent session management via rqbit
//! - Embedded torrent file access
//! - Private tracker announce handling
//! - Selective file download based on shard layer ranges
#![allow(clippy::missing_inline_in_public_items)]
pub mod bencode;
pub mod embedded;
pub mod progress;
pub mod session;
pub mod torrent_files;
pub mod tracker;
pub use bencode::AnnounceParams;
pub use embedded::get_embedded_torrents;
pub use session::{DownloadProgress, TorrentSession};
pub use torrent_files::{get_torrent_file_list, TorrentFileInfo};
pub use tracker::{handle_announce, PeerInfo, TopologyData};

View File

@@ -0,0 +1,77 @@
//! Download progress tracking
//!
//! Types for tracking and reporting download progress to Python
use std::collections::HashMap;
/// Progress update for a torrent download
#[derive(Debug, Clone)]
pub struct DownloadProgress {
/// Total bytes to download
pub total_bytes: u64,
/// Bytes downloaded so far
pub downloaded_bytes: u64,
/// Number of pieces completed
pub pieces_completed: usize,
/// Total number of pieces
pub total_pieces: usize,
/// Number of peers connected
pub peers_connected: usize,
/// Download speed in bytes/second
pub speed_bytes_per_sec: f64,
/// Estimated time remaining in seconds
pub eta_seconds: Option<f64>,
/// Per-file progress
pub files: HashMap<String, FileProgress>,
}
#[derive(Debug, Clone)]
pub struct FileProgress {
/// Total file size
pub total_bytes: u64,
/// Bytes downloaded for this file
pub downloaded_bytes: u64,
/// Whether the file is complete
pub complete: bool,
}
impl DownloadProgress {
#[inline]
pub fn new(total_bytes: u64, total_pieces: usize) -> Self {
Self {
total_bytes,
downloaded_bytes: 0,
pieces_completed: 0,
total_pieces,
peers_connected: 0,
speed_bytes_per_sec: 0.0,
eta_seconds: None,
files: HashMap::new(),
}
}
#[inline]
pub fn progress_fraction(&self) -> f64 {
if self.total_bytes == 0 {
0.0
} else {
#[allow(clippy::cast_precision_loss)]
let fraction = self.downloaded_bytes as f64 / self.total_bytes as f64;
fraction
}
}
#[inline]
pub fn is_complete(&self) -> bool {
self.pieces_completed >= self.total_pieces
}
}

View File

@@ -0,0 +1,166 @@
//! Torrent session management using rqbit
//!
//! Provides a wrapper around rqbit's Session for managing torrent downloads
//! with persistent seeding and selective file downloads.
use anyhow::{Context, Result};
use librqbit::{AddTorrent, AddTorrentOptions, AddTorrentResponse, Api, ManagedTorrent, Session, SessionOptions, SessionPersistenceConfig};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Download progress information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadProgress {
pub downloaded_bytes: u64,
pub total_bytes: u64,
pub download_speed: f64,
pub upload_speed: f64,
pub peers_connected: usize,
pub is_finished: bool,
}
/// Torrent session handle for managing multiple torrents
pub struct TorrentSession {
session: Arc<Session>,
api: Arc<Api>,
session_dir: PathBuf,
torrents: Arc<RwLock<HashMap<String, Arc<ManagedTorrent>>>>,
}
impl TorrentSession {
/// Create a new torrent session
///
/// # Arguments
/// * `session_dir` - Directory to store session state and downloaded files
pub async fn new(session_dir: PathBuf) -> Result<Self> {
std::fs::create_dir_all(&session_dir).context("Failed to create session directory")?;
let opts = SessionOptions {
disable_dht: false,
disable_dht_persistence: false,
dht_config: None,
persistence: Some(SessionPersistenceConfig::Json { folder: None }),
fastresume: true,
..Default::default()
};
let session = Session::new_with_opts(session_dir.clone(), opts)
.await
.context("Failed to create rqbit session")?;
let api = Api::new(Arc::clone(&session), None);
Ok(Self {
session,
api: Arc::new(api),
session_dir,
torrents: Arc::new(RwLock::new(HashMap::new())),
})
}
/// Add a torrent from raw bytes
///
/// # Arguments
/// * `torrent_data` - Raw .torrent file contents
/// * `save_path` - Where to save the downloaded files
/// * `file_indices` - Optional list of file indices to download (None = all files)
///
/// # Returns
/// Info hash as hex string
pub async fn add_torrent(
&self,
torrent_data: Vec<u8>,
save_path: PathBuf,
file_indices: Option<Vec<usize>>,
) -> Result<String> {
let opts = AddTorrentOptions {
overwrite: false,
only_files_regex: None,
only_files: file_indices,
output_folder: Some(save_path.to_string_lossy().to_string()),
..Default::default()
};
let add_torrent = AddTorrent::from_bytes(torrent_data);
let response = self
.session
.add_torrent(add_torrent, Some(opts))
.await
.context("Failed to add torrent")?;
let handle = match response {
AddTorrentResponse::Added(_, handle) => handle,
AddTorrentResponse::AlreadyManaged(_, handle) => handle,
AddTorrentResponse::ListOnly(_) => anyhow::bail!("Torrent was list-only, not added"),
};
let info_hash = handle.info_hash().as_string();
self.torrents
.write()
.await
.insert(info_hash.clone(), handle);
Ok(info_hash)
}
/// Get download progress for a torrent
pub async fn get_progress(&self, info_hash: &str) -> Result<DownloadProgress> {
let torrents = self.torrents.read().await;
let handle = torrents.get(info_hash).context("Torrent not found")?;
let stats = handle.stats();
Ok(DownloadProgress {
downloaded_bytes: stats.progress_bytes,
total_bytes: stats.total_bytes,
download_speed: stats.live.as_ref().map_or(0.0, |l| l.download_speed.mbps * 1024.0 * 1024.0),
upload_speed: stats.live.as_ref().map_or(0.0, |l| l.upload_speed.mbps * 1024.0 * 1024.0),
peers_connected: stats.live.as_ref().map_or(0, |l| l.snapshot.peer_stats.live as usize),
is_finished: stats.finished,
})
}
/// Wait until torrent download is completed
pub async fn wait_until_completed(&self, info_hash: &str) -> Result<()> {
let torrents = self.torrents.read().await;
let handle = torrents.get(info_hash).context("Torrent not found")?;
handle
.wait_until_completed()
.await
.context("Failed to wait for completion")?;
Ok(())
}
/// Enable seeding for a completed torrent
///
/// Note: rqbit seeds by default after completion, this is a no-op
/// but kept for API compatibility
pub async fn enable_seeding(&self, _info_hash: &str) -> Result<()> {
// rqbit automatically seeds after download completion
// This is kept for API compatibility
Ok(())
}
/// Remove a torrent from the session
pub async fn remove_torrent(&self, info_hash: &str) -> Result<()> {
let mut torrents = self.torrents.write().await;
if let Some(handle) = torrents.remove(info_hash) {
drop(handle);
}
Ok(())
}
/// Get list of all torrent info hashes in the session
pub async fn list_torrents(&self) -> Vec<String> {
self.torrents.read().await.keys().cloned().collect()
}
}

View File

@@ -0,0 +1,100 @@
//! Torrent file list parsing
//!
//! Provides functionality to extract file information from torrent metadata
//! without adding the torrent to a session.
use anyhow::{Context, Result};
use librqbit::torrent_from_bytes;
use serde::{Deserialize, Serialize};
/// Information about a file in a torrent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorrentFileInfo {
/// File index (0-based)
pub index: usize,
/// File path relative to torrent root
pub path: String,
/// File size in bytes
pub size: u64,
}
/// Get the list of files in a torrent from its raw bytes
///
/// # Arguments
/// * `torrent_data` - Raw .torrent file contents
///
/// # Returns
/// List of file information (index, path, size)
pub fn get_torrent_file_list(torrent_data: &[u8]) -> Result<Vec<TorrentFileInfo>> {
let torrent_meta = torrent_from_bytes(torrent_data).context("Failed to parse torrent")?;
// Access the data inside WithRawBytes wrapper
let info = &torrent_meta.info.data;
let mut files = Vec::new();
// Handle both single-file and multi-file torrents
if let Some(ref file_list) = info.files {
// Multi-file torrent
for (index, file) in file_list.iter().enumerate() {
let path = file
.path
.iter()
.map(|buf| String::from_utf8_lossy(buf.0).to_string())
.collect::<Vec<_>>()
.join("/");
files.push(TorrentFileInfo {
index,
path,
size: file.length,
});
}
} else {
// Single-file torrent
let name = match &info.name {
Some(n) => String::from_utf8_lossy(n.0).to_string(),
None => String::new(),
};
files.push(TorrentFileInfo {
index: 0,
path: name,
size: info.length.unwrap_or(0),
});
}
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::get_embedded_torrents;
#[test]
fn test_get_torrent_file_list() {
// Use an embedded torrent for testing
let torrents = get_embedded_torrents(
"mlx-community/Qwen3-30B-A3B-4bit",
"d388dead1515f5e085ef7a0431dd8fadf0886c57",
);
assert!(!torrents.is_empty(), "Expected to find embedded torrents");
for (variant, data) in torrents {
let files = get_torrent_file_list(&data).expect("Failed to parse torrent");
assert!(!files.is_empty(), "Expected files in {variant} variant");
// Verify file info makes sense
for file in &files {
assert!(!file.path.is_empty(), "File path should not be empty");
assert!(file.size > 0, "File size should be positive");
}
println!("Variant '{variant}' has {} files", files.len());
for file in files.iter().take(5) {
println!(" [{}] {} ({} bytes)", file.index, file.path, file.size);
}
}
}
}

View File

@@ -0,0 +1,185 @@
//! Fake tracker implementation for Exo topology-based peer discovery
//!
//! Instead of contacting real BitTorrent trackers, this module generates
//! tracker announce responses using Exo's cluster topology data.
use std::net::Ipv4Addr;
use anyhow::Result;
use crate::bencode::{AnnounceParams, BencodeValue};
/// Information about a peer in the Exo topology
#[derive(Debug, Clone)]
pub struct PeerInfo {
/// Unique node identifier in the Exo cluster
pub node_id: String,
/// IPv4 address of the peer
pub ip: Ipv4Addr,
/// BitTorrent listening port
pub port: u16,
/// Whether this peer has the complete torrent
pub has_complete: bool,
/// Priority for peer selection (higher = prefer)
pub priority: i32,
}
/// Topology data containing available peers
#[derive(Debug, Clone)]
pub struct TopologyData {
/// List of peers in the topology
pub peers: Vec<PeerInfo>,
}
/// Default announce interval in seconds
const DEFAULT_INTERVAL: i64 = 1800;
/// Handle a tracker announce request using Exo topology data
///
/// Returns a bencoded tracker response containing peers from the topology.
///
/// # Arguments
/// * `params` - Announce request parameters
/// * `topology` - Current Exo cluster topology
///
/// # Returns
/// Bencoded announce response as bytes
pub fn handle_announce(params: &AnnounceParams, topology: &TopologyData) -> Result<Vec<u8>> {
// Sort peers by priority (descending) for better peer selection
let mut peers: Vec<_> = topology.peers.iter().collect();
peers.sort_by(|a, b| b.priority.cmp(&a.priority));
let response = if params.compact {
// Compact format: 6 bytes per peer (4 IP + 2 port)
let mut peer_data = Vec::with_capacity(peers.len() * 6);
for peer in &peers {
peer_data.extend_from_slice(&peer.ip.octets());
peer_data.extend_from_slice(&peer.port.to_be_bytes());
}
BencodeValue::dict()
.insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
.insert("peers", BencodeValue::Bytes(peer_data))
} else {
// Non-compact format: list of dicts
let mut peer_list = BencodeValue::list();
for peer in &peers {
let peer_dict = BencodeValue::dict()
.insert("ip", BencodeValue::string(&peer.ip.to_string()))
.insert("port", BencodeValue::integer(i64::from(peer.port)))
.insert("peer id", BencodeValue::Bytes(vec![0u8; 20])); // Placeholder peer ID
peer_list = peer_list.push(peer_dict);
}
BencodeValue::dict()
.insert("interval", BencodeValue::integer(DEFAULT_INTERVAL))
.insert("peers", peer_list)
};
Ok(response.encode())
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_params(compact: bool) -> AnnounceParams {
AnnounceParams {
info_hash: [0u8; 20],
peer_id: [0u8; 20],
port: 6881,
uploaded: 0,
downloaded: 0,
left: 1000,
compact,
event: None,
}
}
fn make_test_topology() -> TopologyData {
TopologyData {
peers: vec![
PeerInfo {
node_id: "node1".to_string(),
ip: Ipv4Addr::new(192, 168, 1, 1),
port: 6881,
has_complete: true,
priority: 10,
},
PeerInfo {
node_id: "node2".to_string(),
ip: Ipv4Addr::new(192, 168, 1, 2),
port: 6882,
has_complete: false,
priority: 5,
},
],
}
}
#[test]
fn test_compact_response() {
let params = make_test_params(true);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// Should contain "interval" and "peers" keys
assert!(response.starts_with(b"d"));
assert!(response.ends_with(b"e"));
// Verify we have 12 bytes of peer data (2 peers * 6 bytes)
// The compact peers field should be "12:<12 bytes>"
let response_str = String::from_utf8_lossy(&response);
assert!(response_str.contains("8:interval"));
assert!(response_str.contains("5:peers"));
}
#[test]
fn test_non_compact_response() {
let params = make_test_params(false);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// Should contain peers as a list
let response_str = String::from_utf8_lossy(&response);
assert!(response_str.contains("8:interval"));
assert!(response_str.contains("5:peers"));
assert!(response_str.contains("2:ip"));
assert!(response_str.contains("4:port"));
}
#[test]
fn test_peer_priority_ordering() {
let params = make_test_params(true);
let topology = make_test_topology();
let response = handle_announce(&params, &topology).unwrap();
// In compact format, first peer should be node1 (priority 10)
// which is 192.168.1.1:6881
// Look for the peer data after "5:peers12:"
let peers_marker = b"5:peers12:";
let pos = response
.windows(peers_marker.len())
.position(|w| w == peers_marker)
.unwrap();
let peer_data = &response[pos + peers_marker.len()..pos + peers_marker.len() + 6];
// First peer should be 192.168.1.1 (node1 with higher priority)
assert_eq!(&peer_data[0..4], &[192, 168, 1, 1]);
}
#[test]
fn test_empty_topology() {
let params = make_test_params(true);
let topology = TopologyData { peers: vec![] };
let response = handle_announce(&params, &topology).unwrap();
// Should still be valid bencoded response with empty peers
assert!(response.starts_with(b"d"));
assert!(response.ends_with(b"e"));
}
}

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi884e4:pathl9:README.mdeed6:lengthi1249e4:pathl19:chat_template.jinjaeed6:lengthi1848e4:pathl11:config.jsoneed6:lengthi10652e4:pathl25:configuration_deepseek.pyeed6:lengthi52e4:pathl22:generation_config.jsoneed6:lengthi221164e4:pathl28:model.safetensors.index.jsoneed6:lengthi75769e4:pathl20:modeling_deepseek.pyeed6:lengthi760e4:pathl23:special_tokens_map.jsoneed6:lengthi11330e4:pathl20:tokenization_kimi.pyeed6:lengthi2738e4:pathl21:tokenizer_config.jsoneee4:name40:91fb4f9fd1de100104925196d62b8ee06fd2ad6012:piece lengthi262144e6:pieces40:<3A>C<EFBFBD>t:<3A><>I_<49>i*xg<78><04>s|,<2C>4S<34><53><EFBFBD>j<EFBFBD><6A><EFBFBD>S<EFBFBD><03>|d<>e8:url-list63:https://huggingface.co/mlx-community/Kimi-K2-Instruct-4bit/raw/e

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi864e4:pathl9:README.mdeed6:lengthi3442e4:pathl19:chat_template.jinjaeed6:lengthi3445e4:pathl11:config.jsoneed6:lengthi10652e4:pathl25:configuration_deepseek.pyeed6:lengthi53e4:pathl22:generation_config.jsoneed6:lengthi129766e4:pathl28:model.safetensors.index.jsoneed6:lengthi75769e4:pathl20:modeling_deepseek.pyeed6:lengthi760e4:pathl23:special_tokens_map.jsoneed6:lengthi12597e4:pathl20:tokenization_kimi.pyeed6:lengthi4047e4:pathl21:tokenizer_config.jsoneee4:name40:035a0cdd221ae0dca6b03120e20704a251a7bc9b12:piece lengthi262144e6:pieces20:<3A>^<5E>9`<60>C<18><>Y<EFBFBD>-L<><4C>*EC*e8:url-list58:https://huggingface.co/mlx-community/Kimi-K2-Thinking/raw/e

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16485e4:pathl9:README.mdeed6:lengthi1123e4:pathl11:config.jsoneed6:lengthi158327e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55425e4:pathl21:tokenizer_config.jsoneee4:name40:de2dfaf56839b7d0e834157d2401dee02726874d12:piece lengthi262144e6:pieces20:<3A>*_<1F><><EFBFBD><18>Tij<04><>+<2B>]<5D><>e8:url-list69:https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-4bit/raw/e

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16485e4:pathl9:README.mdeed6:lengthi1123e4:pathl11:config.jsoneed6:lengthi158327e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55425e4:pathl21:tokenizer_config.jsoneee4:name40:c5bfd839cd4cda0e5a39a97e00218d9c56e468af12:piece lengthi262144e6:pieces20:܌!<0E><><EFBFBD>TO<54><4F>4<><34><EFBFBD>P<EFBFBD>_Qe8:url-list69:https://huggingface.co/mlx-community/Llama-3.3-70B-Instruct-8bit/raw/e

View File

@@ -0,0 +1,2 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi1033e4:pathl9:README.mdeed6:lengthi707e4:pathl17:added_tokens.jsoneed6:lengthi6722e4:pathl19:chat_template.jinjaeed6:lengthi1222e4:pathl11:config.jsoneed6:lengthi180e4:pathl22:generation_config.jsoneed6:lengthi1671853e4:pathl10:merges.txteed6:lengthi154390e4:pathl28:model.safetensors.index.jsoneed6:lengthi28881e4:pathl24:qwen3_xml_tool_parser.pyeed6:lengthi613e4:pathl23:special_tokens_map.jsoneed6:lengthi5405e4:pathl21:tokenizer_config.jsoneed6:lengthi2776833e4:pathl10:vocab.jsoneee4:name40:ca8dbf41071f579fbe3260f20bbe1ab896f7903112:piece lengthi262144e6:pieces360:<3A>3<EFBFBD>\<5C>PDE<44><45><17><><EFBFBD><06><06><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>c+<2B>h{" <0B><>_
m<EFBFBD> 7<><37><EFBFBD><EFBFBD>.<2E>h<14><>fm<66><6D>,<2C>w<EFBFBD><77>nOМ<4F><11><>"<22><><EFBFBD><EFBFBD>&j<><6A>_<EFBFBD><5F>"F<><46><EFBFBD>u<18>gU<67><08><><EFBFBD>QW<51><57><EFBFBD><EFBFBD>@qiiq<69><71>T<EFBFBD><54><EFBFBD>P<>lSJƤ<4A> \<5C><><EFBFBD>R!<21>=<3D><>v<EFBFBD><76><EFBFBD>F<EFBFBD>q9<71><39><EFBFBD><EFBFBD><01><><EFBFBD><EFBFBD><av<61>B@<40><> <09>z

View File

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi75789955e4:pathl17:model.safetensorseee4:name40:f56bc6adfb74c794203dc8ca94e0bccfe2bcd6cc12:piece lengthi16777216e6:pieces100:QM0Ts@Ev<>XԄ=<3D>6_xhњU4=<3D><>7<EFBFBD>j<EFBFBD><6A><EFBFBD><18>F<EFBFBD>M<EFBFBD>q<EFBFBD><71><EFBFBD><EFBFBD>m>a<><61>H°*'<27>5<EFBFBD><35>/9B<39><42>^V<>4H9m<39><6D><EFBFBD><EFBFBD>0<EFBFBD>^z<><7A>+YS*<2A>M<EFBFBD><4D>G<EFBFBD>+<2B>.<02>h<EFBFBD>5e8:url-list62:https://huggingface.co/mlx-community/SmolLM-135M-4bit/resolve/e

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi845e4:pathl9:README.mdeed6:lengthi16738e4:pathl19:chat_template.jinjaeed6:lengthi50145e4:pathl11:config.jsoneed6:lengthi177e4:pathl22:generation_config.jsoneed6:lengthi100431e4:pathl28:model.safetensors.index.jsoneed6:lengthi440e4:pathl23:special_tokens_map.jsoneed6:lengthi4200e4:pathl21:tokenizer_config.jsoneee4:name40:81e5ac3ad0af6efb1298a8e8c7a10ed2990c137b12:piece lengthi262144e6:pieces20:ME<4D>TVE@ͯ<><4E>8<><38><EFBFBD>`e8:url-list63:https://huggingface.co/mlx-community/gpt-oss-120b-MXFP4-Q8/raw/e

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi838e4:pathl9:README.mdeed6:lengthi33998e4:pathl11:config.jsoneed6:lengthi177e4:pathl22:generation_config.jsoneed6:lengthi67046e4:pathl28:model.safetensors.index.jsoneed6:lengthi440e4:pathl23:special_tokens_map.jsoneed6:lengthi21694e4:pathl21:tokenizer_config.jsoneee4:name40:f356f2747216d7e98fee755df25987459fc1908912:piece lengthi262144e6:pieces20:<3A><><EFBFBD><EFBFBD>ͥ<><CDA5><EFBFBD>g#`<60><>f<EFBFBD>x<EFBFBD><78>e8:url-list62:https://huggingface.co/mlx-community/gpt-oss-20b-MXFP4-Q4/raw/e

View File

@@ -0,0 +1,6 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1519e4:pathl14:.gitattributeseed6:lengthi956e4:pathl9:README.mdeed6:lengthi207e4:pathl17:added_tokens.jsoneed6:lengthi848e4:pathl11:config.jsoneed6:lengthi441810e4:pathl10:merges.txteed6:lengthi25864e4:pathl28:model.safetensors.index.jsoneed6:lengthi801e4:pathl23:special_tokens_map.jsoneed6:lengthi3476578e4:pathl14:tokenizer.jsoneed6:lengthi9935e4:pathl21:tokenizer_config.jsoneed6:lengthi776995e4:pathl10:vocab.jsoneee4:name40:39b35eaa97282c34db81f61a983b4b83344e10f112:piece lengthi262144e6:pieces380:<3A>ih֨
[׬-<2D><><EFBFBD>}<7D><19><>U<EFBFBD>){[<5B>+<2B>7PU<><13>nR`<60><>g<EFBFBD> <0C>vH<76><78>q<EFBFBD><71>Lz<4C> <>Q<>Ĉ|Š<><C5A0><EFBFBD><EFBFBD>\<5C><>ehۢ<68>S<EFBFBD> <0B>#<23>g<EFBFBD>Y%@D<><D2A9><EFBFBD>}ޥXO<><4F><EFBFBD><EFBFBD><EFBFBD><06> <0C><><EFBFBD><EFBFBD><EFBFBD><1B>Y<EFBFBD>"<22><>|<7C>JH<4A> <0C>w<EFBFBD><05>MH<4D>*k<>@R<><52> 1i<31>|<7C>y<H<02><>H {<7B><14>
<EFBFBD><1B><P<><50>@<40><><16><>E<<3C><><EFBFBD>S<EFBFBD><53>|<7C><><EFBFBD>A
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>_<><5F>
;<3B>Rg<52><67><EFBFBD><>$<24><><EFBFBD>|@`X<58><7F>#<23><><EFBFBD>M<EFBFBD>$n-<2D><10><>i<EFBFBD>
9<>6ɝ@t<><74>j<EFBFBD><16>n<EFBFBD><6E><EFBFBD><EFBFBD>ɃH<C983><48>,<2C><>

View File

@@ -0,0 +1 @@
d8:announce42:udp://tracker.opentrackr.org:1337/announce10:created by13:mktorrent 1.14:infod5:filesld6:lengthi1570e4:pathl14:.gitattributeseed6:lengthi16447e4:pathl9:README.mdeed6:lengthi970e4:pathl11:config.jsoneed6:lengthi62518e4:pathl28:model.safetensors.index.jsoneed6:lengthi454e4:pathl23:special_tokens_map.jsoneed6:lengthi55421e4:pathl21:tokenizer_config.jsoneee4:name40:8103891b028a8933068e47751bc2acc10bb59aa212:piece lengthi262144e6:pieces20:<3A>l<EFBFBD>f<EFBFBD>7<>.<2E><><EFBFBD> <0B> a<><61>e8:url-list69:https://huggingface.co/mlx-community/llama-3.3-70b-instruct-fp16/raw/e

View File

@@ -23,6 +23,7 @@ workspace = true
[dependencies]
networking = { workspace = true }
downloads = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [

View File

@@ -0,0 +1,334 @@
//! Downloads module - BitTorrent downloads PyO3 bindings
use crate::ext::*;
use downloads::bencode::AnnounceParams;
use downloads::tracker::{PeerInfo, TopologyData, handle_announce as rust_handle_announce};
use downloads::{DownloadProgress, TorrentSession};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use std::net::Ipv4Addr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Handle a tracker announce request
///
/// Args:
/// params: Dictionary with announce parameters (info_hash, peer_id, port, etc.)
/// peers: List of peer dictionaries (node_id, ip, port, has_complete, priority)
///
/// Returns:
/// Bencoded announce response as bytes
#[pyfunction]
fn handle_tracker_announce(
py: Python<'_>,
params: &Bound<'_, PyDict>,
peers: &Bound<'_, pyo3::types::PyList>,
) -> PyResult<Py<PyBytes>> {
// Parse announce params
let info_hash = {
let info_hash_item = params
.get_item("info_hash")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing info_hash"))?;
let info_hash_bytes: &[u8] = info_hash_item.extract()?;
if info_hash_bytes.len() != 20 {
return Err(pyo3::exceptions::PyValueError::new_err(
"info_hash must be 20 bytes",
));
}
let mut info_hash = [0u8; 20];
info_hash.copy_from_slice(info_hash_bytes);
info_hash
};
let peer_id = {
let peer_id_item = params
.get_item("peer_id")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing peer_id"))?;
let peer_id_bytes: &[u8] = peer_id_item.extract()?;
if peer_id_bytes.len() != 20 {
return Err(pyo3::exceptions::PyValueError::new_err(
"peer_id must be 20 bytes",
));
}
let mut peer_id = [0u8; 20];
peer_id.copy_from_slice(peer_id_bytes);
peer_id
};
let port: u16 = params
.get_item("port")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
.extract()?;
let uploaded: u64 = params
.get_item("uploaded")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing uploaded"))?
.extract()?;
let downloaded: u64 = params
.get_item("downloaded")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing downloaded"))?
.extract()?;
let left: u64 = params
.get_item("left")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing left"))?
.extract()?;
let compact: bool = params
.get_item("compact")?
.map(|v| v.extract().unwrap_or(true))
.unwrap_or(true);
let announce_params = AnnounceParams {
info_hash,
peer_id,
port,
uploaded,
downloaded,
left,
compact,
event: None, // TODO: parse event if needed
};
// Parse peer list
let peer_infos: Result<Vec<PeerInfo>, PyErr> = peers
.iter()
.map(|peer_item| {
let peer_dict: &Bound<'_, PyDict> = peer_item.downcast()?;
let node_id: String = peer_dict
.get_item("node_id")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing node_id"))?
.extract()?;
let ip_str: String = peer_dict
.get_item("ip")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing ip"))?
.extract()?;
let ip: Ipv4Addr = ip_str
.parse()
.map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid IP address"))?;
let port: u16 = peer_dict
.get_item("port")?
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Missing port"))?
.extract()?;
let has_complete: bool = peer_dict
.get_item("has_complete")?
.map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(false))
.unwrap_or(false);
let priority: i32 = peer_dict
.get_item("priority")?
.map(|v: Bound<'_, pyo3::PyAny>| v.extract().unwrap_or(0))
.unwrap_or(0);
Ok(PeerInfo {
node_id,
ip,
port,
has_complete,
priority,
})
})
.collect();
let peer_infos = peer_infos?;
let topology = TopologyData { peers: peer_infos };
// Call Rust tracker handler
let response_bytes = rust_handle_announce(&announce_params, &topology).pyerr()?;
// Return as Python bytes
Ok(PyBytes::new(py, &response_bytes).unbind())
}
/// Get all embedded torrent variants for a model
///
/// Args:
/// model_id: Model identifier (e.g., "mlx-community/Qwen3-30B-A3B-4bit")
/// revision: Git commit hash
///
/// Returns:
/// List of (variant_name, torrent_data) tuples, e.g., [("small", bytes), ("large", bytes)]
/// Returns empty list if no torrents found.
#[pyfunction]
fn get_embedded_torrents(
py: Python<'_>,
model_id: String,
revision: String,
) -> PyResult<Vec<(String, Py<PyBytes>)>> {
let torrents = downloads::get_embedded_torrents(&model_id, &revision);
Ok(torrents
.into_iter()
.map(|(variant, data)| (variant, PyBytes::new(py, &data).unbind()))
.collect())
}
/// Get file list from torrent data
///
/// Args:
/// torrent_data: Raw .torrent file contents
///
/// Returns:
/// List of (index, path, size_bytes) tuples for each file in the torrent
#[pyfunction]
fn get_torrent_file_list(torrent_data: Vec<u8>) -> PyResult<Vec<(usize, String, u64)>> {
let files = downloads::get_torrent_file_list(&torrent_data).pyerr()?;
Ok(files
.into_iter()
.map(|f| (f.index, f.path, f.size))
.collect())
}
/// Python wrapper for TorrentSession
#[pyclass]
struct TorrentSessionHandle {
session: Arc<Mutex<TorrentSession>>,
}
#[pymethods]
impl TorrentSessionHandle {
/// Create a new torrent session
///
/// Args:
/// session_dir: Directory to store session state and downloads
#[new]
fn new(session_dir: String) -> PyResult<Self> {
let session_path = PathBuf::from(session_dir);
let session = tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { TorrentSession::new(session_path).await })
.pyerr()?;
Ok(Self {
session: Arc::new(Mutex::new(session)),
})
}
/// Add a torrent from bytes
///
/// Args:
/// torrent_data: Raw .torrent file contents
/// save_path: Where to save downloaded files
/// file_indices: Optional list of file indices to download
///
/// Returns:
/// Info hash as hex string
fn add_torrent(
&self,
_py: Python<'_>,
torrent_data: Vec<u8>,
save_path: String,
file_indices: Option<Vec<usize>>,
) -> PyResult<String> {
let session = Arc::clone(&self.session);
let save_path = PathBuf::from(save_path);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async {
session
.lock()
.await
.add_torrent(torrent_data, save_path, file_indices)
.await
})
.pyerr()
}
/// Get download progress for a torrent
///
/// Args:
/// info_hash: Torrent info hash
///
/// Returns:
/// Dictionary with progress information
fn get_progress(&self, py: Python<'_>, info_hash: String) -> PyResult<Py<PyDict>> {
let session = Arc::clone(&self.session);
let progress: DownloadProgress = tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.get_progress(&info_hash).await })
.pyerr()?;
let dict = PyDict::new(py);
dict.set_item("downloaded_bytes", progress.downloaded_bytes)?;
dict.set_item("total_bytes", progress.total_bytes)?;
dict.set_item("download_speed", progress.download_speed)?;
dict.set_item("upload_speed", progress.upload_speed)?;
dict.set_item("peers_connected", progress.peers_connected)?;
dict.set_item("is_finished", progress.is_finished)?;
Ok(dict.unbind())
}
/// Wait until torrent download is completed
///
/// Args:
/// info_hash: Torrent info hash
fn wait_until_completed(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.wait_until_completed(&info_hash).await })
.pyerr()
}
/// Enable seeding for a torrent
///
/// Args:
/// info_hash: Torrent info hash
fn enable_seeding(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.enable_seeding(&info_hash).await })
.pyerr()
}
/// Remove a torrent from the session
///
/// Args:
/// info_hash: Torrent info hash
fn remove_torrent(&self, _py: Python<'_>, info_hash: String) -> PyResult<()> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { session.lock().await.remove_torrent(&info_hash).await })
.pyerr()
}
/// List all torrents in the session
///
/// Returns:
/// List of info hashes
fn list_torrents(&self, _py: Python<'_>) -> PyResult<Vec<String>> {
let session = Arc::clone(&self.session);
tokio::runtime::Runtime::new()
.pyerr()?
.block_on(async { Ok(session.lock().await.list_torrents().await) })
}
}
/// Downloads submodule
pub(crate) fn downloads_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(handle_tracker_announce, m)?)?;
m.add_function(wrap_pyfunction!(get_embedded_torrents, m)?)?;
m.add_function(wrap_pyfunction!(get_torrent_file_list, m)?)?;
m.add_class::<TorrentSessionHandle>()?;
Ok(())
}

View File

@@ -17,10 +17,12 @@
extern crate core;
mod allow_threading;
pub(crate) mod downloads;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::downloads::downloads_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
@@ -207,6 +209,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
downloads_submodule(m)?;
// top-level constructs
// TODO: ...

58
scripts/mktorrent.sh Executable file
View File

@@ -0,0 +1,58 @@
#!/usr/bin/env nix-shell
#!nix-shell -i bash -p mktorrent -p python3Packages.huggingface-hub -p git -p git-lfs
set -euo pipefail
set -x
MODEL="$1"
mkdir -p "$MODEL"
# Step 1: Clone/fetch the repo and get the hash of head
mkdir -p "$MODEL"
if test -d "$MODEL/git"; then
# Assert that the origin is correct
git -C "$MODEL/git" fetch
else
git clone "https://huggingface.co/$MODEL" "$MODEL/git"
fi
HASH=$(git -C "$MODEL/git" rev-parse origin/main)
LARGE_FILES=$(git -C "$MODEL/git" lfs ls-files --all --name-only)
SMALL_DIR="$MODEL/$HASH-small"
LARGE_DIR="$MODEL/$HASH-large"
mkdir -p "$SMALL_DIR" "$LARGE_DIR"
# Step 2: Prepare files. Two torrents: one for large files and one for metadata.
git -C "$MODEL/git" archive "$HASH" | tar -x -C "$SMALL_DIR"
echo "$LARGE_FILES" | xargs -I{} rm "$SMALL_DIR/{}"
echo "$LARGE_FILES" | xargs hf download "$MODEL" --revision "$HASH" --local-dir "$LARGE_DIR" --cache-dir "$(realpath .cache)" --include
if test -d "$LARGE_DIR/.cache"; then
echo ".cache created against our wishes, deleting it..."
rm -r "$LARGE_DIR/.cache"
fi
# Step 3: Create both torrents
mkdir -p "torrents/$MODEL/"
SMALL_TORRENT_PATH="torrents/$MODEL/${HASH}.small.torrent"
LARGE_TORRENT_PATH="torrents/$MODEL/${HASH}.large.torrent"
mktorrent "$SMALL_DIR/" --output="$SMALL_TORRENT_PATH" \
-n "$HASH" \
--web-seed="https://huggingface.co/$MODEL/raw/" \
--no-date \
--announce="udp://tracker.opentrackr.org:1337/announce"
# --private
mktorrent "$LARGE_DIR/" --output="$LARGE_TORRENT_PATH" \
-n "$HASH" \
--web-seed="https://huggingface.co/$MODEL/resolve/" \
--piece-length=24 \
--no-date \
--announce="udp://tracker.opentrackr.org:1337/announce"
# --private
echo "Successfully created torrent files in:"
echo "$SMALL_TORRENT_PATH"
echo "$LARGE_TORRENT_PATH"

View File

@@ -35,6 +35,7 @@ class Node:
api: API | None
node_id: NodeId
enable_torrents: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -66,7 +67,8 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
exo_shard_downloader(enable_torrents=args.enable_torrents),
initial_connection_messages=[],
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
@@ -74,7 +76,6 @@ class Node:
)
else:
worker = None
# We start every node with a master
master = Master(
node_id,
@@ -98,7 +99,7 @@ class Node:
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
return cls(router, worker, election, er_recv, master, api, node_id, args.enable_torrents)
async def run(self):
async with self._tg as tg:
@@ -175,7 +176,7 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
exo_shard_downloader(enable_torrents=self.enable_torrents),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -215,6 +216,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
enable_torrents: bool = False
@classmethod
def parse(cls) -> Self:
@@ -256,6 +258,12 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--enable-torrents",
action="store_true",
dest="enable_torrents",
help="Enable BitTorrent-based downloads (experimental)",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -5,9 +5,9 @@ from typing import cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
@@ -178,6 +178,7 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.get("/_internal/announce")(self.tracker_announce)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -236,7 +237,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -292,7 +292,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -601,8 +600,9 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -623,6 +623,65 @@ class API:
]
)
async def tracker_announce(self, request: Request) -> Response:
"""BitTorrent tracker announce endpoint for private tracker."""
try:
from exo_pyo3_bindings import handle_tracker_announce # type: ignore
except ImportError as e:
raise HTTPException(
status_code=501,
detail="Torrent support not available (exo_pyo3_bindings not installed)",
) from e
# Parse announce parameters from query string
query_params = dict(request.query_params)
# Extract required parameters
try:
info_hash_hex = query_params.get("info_hash", "")
peer_id_hex = query_params.get("peer_id", "")
# URL decode and convert to bytes
info_hash = bytes.fromhex(info_hash_hex) if info_hash_hex else b""
peer_id = bytes.fromhex(peer_id_hex) if peer_id_hex else b""
if len(info_hash) != 20 or len(peer_id) != 20:
raise ValueError("info_hash and peer_id must be 20 bytes")
params = {
"info_hash": info_hash,
"peer_id": peer_id,
"port": int(query_params.get("port", "6881")),
"uploaded": int(query_params.get("uploaded", "0")),
"downloaded": int(query_params.get("downloaded", "0")),
"left": int(query_params.get("left", "0")),
"compact": query_params.get("compact", "1") == "1",
}
except (ValueError, KeyError) as e:
raise HTTPException(
status_code=400,
detail=f"Invalid announce parameters: {e}",
) from e
# Build peer list from topology
# TODO: Implement _build_peer_list_from_topology() to extract peers from self.state.topology
peers = [] # For now, return empty peer list
# Call Rust tracker handler
try:
response_bytes: bytes = handle_tracker_announce(params, peers) # type: ignore
return Response(
content=response_bytes,
media_type="text/plain",
headers={"Content-Type": "text/plain"},
)
except Exception as e:
logger.error(f"Tracker announce error: {e}")
raise HTTPException(
status_code=500,
detail=f"Tracker announce failed: {e}",
) from e
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -158,7 +158,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -201,7 +200,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,10 +6,9 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
@@ -20,11 +19,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,16 +52,19 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
cycles = topology.get_cycles() + [[node] for node in all_nodes]
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
@@ -93,15 +94,13 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -110,7 +109,11 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node.node_profile.memory.ram_available for node in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
@@ -119,16 +122,14 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -136,18 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
selected_cycle: list[NodeInfo],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
@@ -195,7 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_jaccl_devices_matrix(
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,7 +204,6 @@ def get_mlx_jaccl_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -215,38 +214,71 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeWithProfile,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -255,7 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node: NodeWithProfile, other_node: NodeWithProfile, cycle_digraph: Topology
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -266,7 +298,7 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node.node_id, other_node.node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
@@ -292,7 +324,7 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeWithProfile],
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
@@ -329,7 +361,7 @@ def get_mlx_ring_hosts_by_node(
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -343,30 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, coordinator, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}

View File

@@ -1,36 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
# TODO: this is a hack to get the port for the send_back_multiaddr
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -19,13 +19,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -154,7 +163,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)

View File

@@ -1,3 +1,5 @@
from typing import Callable
import pytest
from loguru import logger
@@ -5,20 +7,14 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -76,36 +77,34 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
topology.add_connection(node_id_c, node_id_b, create_connection(4))
topology.add_connection(node_id_a, node_id_c, create_connection(5))
topology.add_connection(node_id_b, node_id_a, create_connection(6))
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
@@ -131,11 +130,12 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -146,7 +146,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -157,11 +157,12 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -172,7 +173,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -183,11 +184,12 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -200,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -245,103 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
model_meta.storage_size = Memory.from_bytes(1000)
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
logger.info(list(topology.list_connections()))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
# act
placements = place_instance(cic, topology, {}, profiles)
# Act
placements = place_instance(cic, topology, {})
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -351,7 +429,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -359,10 +437,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
@@ -371,15 +449,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3

View File

@@ -1,48 +1,56 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -176,25 +179,19 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -204,11 +201,7 @@ def test_get_shard_assignments(
hidden_size=1000,
supports_tensor=True,
)
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
cycles = topology.get_cycles()
selected_cycle = cycles[0]
# act
@@ -237,21 +230,28 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -259,47 +259,108 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -328,11 +389,11 @@ def test_get_mlx_jaccl_coordinators():
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"

View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(conn for _, _, conn in topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(node_a, node_b, connection)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -11,8 +11,10 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacTBConnections,
MacTBIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacTBIdentifiers():
profile.tb_interfaces = info.idents
case MacTBConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

View File

@@ -24,8 +24,6 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})
@@ -39,4 +43,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -11,16 +11,17 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(node_a, node_b, connection)
state.topology.add_connection(connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert state.topology.to_snapshot().nodes == restored_state.topology.to_snapshot().nodes
assert set(state.topology.to_snapshot().connections) == set(restored_state.topology.to_snapshot().connections)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,215 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
# the _graph can be used as a int -> NodeId map.
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.list_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source, sink, conn in snapshot.connections:
topology.add_connection(source, sink, conn)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
connection: Connection,
) -> None:
if connection in self.get_all_connections_between(source, sink):
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeId]]:
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeId]]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import SocketConnection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this naem
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
Event = (
@@ -119,8 +125,10 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -16,3 +16,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
revision: str | None = None # Git commit hash for torrent lookup

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
pass
throughput: float
latency: float
jitter: float

View File

@@ -1,75 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces # doesn't need to be an assertion but im confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[TBConnectivityData] = []
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType

View File

@@ -1,32 +1,37 @@
from enum import Enum
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
pass
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)

View File

@@ -1,232 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]
class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

Some files were not shown because too many files have changed in this diff Show More