Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-18 10:58:35 -05:00)

Compare commits: evan/mlnix...gather-lin (67 commits)
Commits:

5579751843, d48a2700d9, 3e9dd9dc92, 3516736317, c6877623c2, c3bcd2036f,
401d94c6f0, 678d318a12, 380dd0be38, 46b77583cd, fdf983b334, e3cdc98c10,
12367af37a, da770293c5, 5870bb0bf9, 1f5bbada79, 101adbeac7, bf2938c275,
ea76134477, a4a33a4137, 6d43d87a2a, 0b4af0c195, f85020a7df, ea625964d1,
6a3d5198a7, 76b07040fc, 72477a3b76, 24e1ca4bef, 096ad4fb6c, c204d2897f,
3d8c7203f9, 9c0f5074da, 576c9375d6, a19fd52fb6, e0dee9b48b, c98d97c056,
9e2bbe70b3, f4fcfdac16, c516ecdb75, 6f76223e15, 8178f5b173, ec8ce1a8ef,
5f7e429965, 75cc3a0e07, 6f5f8d5337, 1bed99d17c, 04cb7ff30a, b35aa9f6ff,
a07b402f72, 83c5285a80, 39ee2bf7bd, 991adfbd6f, 4b3de6b984, c8de3b90ea,
6e6567a802, a735dad667, aaf4e36bc3, 3e623ccf0d, c22dad8a7d, 4bc4d50685,
e0aab46fd8, 82ba42bae9, 3671528fa4, e6434ec446, bdb43e1dbb, e4a01e2b0e,
1200a7db64
.github/workflows/build-app.yml (vendored; 17 lines changed)

@@ -113,11 +113,22 @@ jobs:
          uv python install
          uv sync --locked

      - name: Install Nix
        uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - name: Configure Cachix
        uses: cachix/cachix-action@v14
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Build dashboard
        run: |
          cd dashboard
          npm ci
          npm run build
          DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
          mkdir -p dashboard/build
          cp -r "$DASHBOARD_OUT"/* dashboard/build/

      - name: Install Sparkle CLI
        run: |
Cargo.lock (generated; 19 lines changed)

@@ -4340,25 +4340,6 @@ dependencies = [
 "libc",
]

[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
 "delegate",
 "derive_more",
 "either",
 "extend",
 "futures",
 "futures-timer",
 "impl-trait-for-tuples",
 "keccak-const",
 "log",
 "thiserror 2.0.17",
 "tokio",
 "tracing-subscriber",
 "util",
]

[[package]]
name = "tagptr"
version = "0.2.0"
Cargo.toml

@@ -3,7 +3,6 @@ resolver = "3"
members = [
    "rust/networking",
    "rust/exo_pyo3_bindings",
    "rust/system_custodian",
    "rust/util",
]

@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }

# Proc-macro authoring tools
TODO.md (1 line changed)

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When an API HTTP request gets cancelled, it should cancel the corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer

Potential refactors:
@@ -56,6 +56,11 @@ struct ContentView: View {
    }

    private var shouldShowLocalNetworkWarning: Bool {
        // Show warning if local network is not working and EXO is running.
        // The checker uses a longer timeout on first launch to allow time for
        // the permission prompt, so this correctly handles both:
        // 1. User denied permission on first launch
        // 2. Permission broke after restart (macOS TCC bug)
        if case .notWorking = localNetworkChecker.status {
            return controller.status != .stopped
        }
@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
    enum Status: Equatable {

@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
    }

    private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
    private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"

    @Published private(set) var status: Status = .unknown
    @Published private(set) var lastConnectionState: String = "none"

    private var connection: NWConnection?
    private var checkTask: Task<Void, Never>?

    /// Whether we've completed at least one check (stored in UserDefaults)
    private var hasCompletedInitialCheck: Bool {
        get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
        set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
    }

    /// Checks if local network access is working.
    func check() {
        checkTask?.cancel()
        status = .checking
        lastConnectionState = "connecting"

        // Use longer timeout on first launch to allow time for permission prompt
        let isFirstCheck = !hasCompletedInitialCheck
        let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000

        checkTask = Task { [weak self] in
            guard let self else { return }
            let result = await self.performCheck()

            Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
            let result = await self.checkConnectivity(timeout: timeout)
            self.status = result
            self.hasCompletedInitialCheck = true

            Self.logger.info("Local network check complete: \(result.displayText)")
        }
    }

    private func performCheck() async -> Status {
        Self.logger.info("Checking local network access via UDP multicast")

    /// Checks connectivity using NWConnection to mDNS multicast.
    /// The connection attempt triggers the permission prompt if not yet shown.
    private func checkConnectivity(timeout: UInt64) async -> Status {
        connection?.cancel()
        connection = nil

@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
            continuation.resume(returning: status)
        }

        conn.stateUpdateHandler = { [weak self] state in
            let stateStr: String
            switch state {
            case .setup: stateStr = "setup"
            case .preparing: stateStr = "preparing"
            case .ready: stateStr = "ready"
            case .waiting(let e): stateStr = "waiting(\(e))"
            case .failed(let e): stateStr = "failed(\(e))"
            case .cancelled: stateStr = "cancelled"
            @unknown default: stateStr = "unknown"
            }

            Task { @MainActor in
                self?.lastConnectionState = stateStr
            }

        conn.stateUpdateHandler = { state in
            switch state {
            case .ready:
                resumeOnce(.working)

@@ -108,6 +106,7 @@ final class LocalNetworkChecker: ObservableObject {
                if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
                    resumeOnce(.notWorking(reason: "Connection blocked"))
                }
                // Otherwise keep waiting - might be showing permission prompt
            case .failed(let error):
                let errorStr = "\(error)"
                if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")

@@ -127,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
        conn.start(queue: .main)

        Task {
            try? await Task.sleep(nanoseconds: 3_000_000_000)
            try? await Task.sleep(nanoseconds: timeout)
            let state = conn.state
            switch state {
            case .ready:
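The probe idea above, restated as a hedged Python sketch (illustrative only, not part of this diff; the app itself uses NWConnection): send a UDP datagram toward the mDNS multicast endpoint and map errno 54/65 style failures to a blocked local network.

import errno
import socket

def probe_local_network() -> bool:
    # Hypothetical analogue of the NWConnection check: a bare UDP send toward
    # 224.0.0.251:5353 (mDNS). When macOS local-network permission is denied,
    # the send fails with ECONNRESET (errno 54) or EHOSTUNREACH (errno 65),
    # matching the error strings the Swift handler looks for.
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(b"", ("224.0.0.251", 5353))
        return True
    except OSError as e:
        return e.errno not in (errno.ECONNRESET, errno.EHOSTUNREACH)
    finally:
        sock.close()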
@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
    private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
    private static let daemonLabel = "io.exo.networksetup"
    private static let scriptDestination =
        "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
        "/Library/Application Support/EXO/disable_bridge.sh"
    private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
    private static let requiredStartInterval: Int = 1791

@@ -28,35 +28,6 @@ enum NetworkSetupHelper {
    # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
    /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

    networksetup -listlocations | grep -q exo || {
        networksetup -createlocation exo
    }

    networksetup -switchtolocation exo
    networksetup -listallhardwareports \\
        | awk -F': ' '/Hardware Port: / {print $2}' \\
        | while IFS=":" read -r name; do
            case "$name" in
                "Ethernet Adapter"*)
                    ;;
                "Thunderbolt Bridge")
                    ;;
                "Thunderbolt "*)
                    networksetup -listallnetworkservices \\
                        | grep -q "EXO $name" \\
                        || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
                        || continue
                    networksetup -setdhcp "EXO $name"
                    ;;
                *)
                    networksetup -listallnetworkservices \\
                        | grep -q "$name" \\
                        || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
                        || continue
                    ;;
            esac
        done

    networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
        networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
    } || true
@@ -241,6 +241,9 @@ class PromptSizer:
            ids = tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True
            )
            # Fix for transformers 5.x
            if hasattr(ids, "input_ids"):
                ids = ids.input_ids
            return int(len(ids))

        return count_fn
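The hasattr guard covers both return shapes of apply_chat_template: older transformers return a plain list of token ids, while (per the comment) 5.x can return a BatchEncoding-style object that carries them in .input_ids. A minimal sketch of the same normalization, using a stand-in wrapper class rather than transformers itself:

from dataclasses import dataclass

@dataclass
class FakeEncoding:
    # Stand-in for the BatchEncoding-like wrapper newer transformers can return.
    input_ids: list[int]

def prompt_length(ids) -> int:
    # Same normalization as the diff: unwrap .input_ids when present.
    if hasattr(ids, "input_ids"):
        ids = ids.input_ids
    return int(len(ids))

assert prompt_length([1, 2, 3]) == 3
assert prompt_length(FakeEncoding(input_ids=[1, 2, 3, 4])) == 4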
dashboard/dashboard.nix (new file, 60 lines)

@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
  # Read and parse the lock file
  rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");

  # For packages with bundleDependencies, filter out deps that are bundled
  # (bundled deps are inside the tarball, not separate lockfile entries)
  fixedPackages = lib.mapAttrs
    (path: entry:
      if entry ? bundleDependencies && entry.bundleDependencies != [ ]
      then entry // {
        dependencies = lib.filterAttrs
          (name: _: !(lib.elem name entry.bundleDependencies))
          (entry.dependencies or { });
      }
      else entry
    )
    (rawLockFile.packages or { });

  fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
  imports = [
    dream2nix.modules.dream2nix.nodejs-package-lock-v3
    dream2nix.modules.dream2nix.nodejs-granular-v3
  ];

  name = "exo-dashboard";
  version = "1.0.0";

  mkDerivation = {
    src = config.deps.dashboardSrc;

    buildPhase = ''
      runHook preBuild
      npm run build
      runHook postBuild
    '';

    installPhase = ''
      runHook preInstall
      cp -r build $out/build
      runHook postInstall
    '';
  };

  deps = { nixpkgs, ... }: {
    inherit (nixpkgs) stdenv;
    dashboardSrc = null; # Injected by parts.nix
  };

  nodejs-package-lock-v3 = {
    # Don't use packageLockFile - provide the fixed lock content directly
    packageLock = fixedLockFile;
  };
}
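The lockfile fix-up above, restated in Python for clarity (an illustrative sketch, not part of this diff): entries that declare bundleDependencies get those names removed from their dependencies map, since the bundled copies ship inside the package tarball rather than as separate lockfile entries.

import json

def fix_lockfile(path: str) -> dict:
    # Mirrors dashboard.nix: drop bundled dep names from each entry's
    # "dependencies", because they live inside the tarball, not the lockfile.
    with open(path) as f:
        lock = json.load(f)
    for entry in lock.get("packages", {}).values():
        bundled = entry.get("bundleDependencies") or []
        if bundled:
            deps = entry.get("dependencies", {})
            entry["dependencies"] = {
                name: spec for name, spec in deps.items() if name not in bundled
            }
    return lock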
dashboard/package-lock.json (generated; 9 lines changed)

@@ -863,6 +863,7 @@
  "integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
  "dev": true,
  "license": "MIT",
  "peer": true,
  "dependencies": {
    "@standard-schema/spec": "^1.0.0",
    "@sveltejs/acorn-typescript": "^1.0.5",

@@ -902,6 +903,7 @@
  "integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
  "dev": true,
  "license": "MIT",
  "peer": true,
  "dependencies": {
    "@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
    "debug": "^4.4.1",

@@ -1518,6 +1520,7 @@
  "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
  "dev": true,
  "license": "MIT",
  "peer": true,
  "dependencies": {
    "undici-types": "~6.21.0"
  }

@@ -1527,6 +1530,7 @@
  "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
  "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
  "license": "MIT",
  "peer": true,
  "bin": {
    "acorn": "bin/acorn"
  },

@@ -1939,6 +1943,7 @@
  "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
  "dev": true,
  "license": "ISC",
  "peer": true,
  "engines": {
    "node": ">=12"
  }

@@ -2646,6 +2651,7 @@
  "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
  "dev": true,
  "license": "MIT",
  "peer": true,
  "engines": {
    "node": ">=12"
  },

@@ -2833,6 +2839,7 @@
  "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
  "integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
  "license": "MIT",
  "peer": true,
  "dependencies": {
    "@jridgewell/remapping": "^2.3.4",
    "@jridgewell/sourcemap-codec": "^1.5.0",

@@ -2977,6 +2984,7 @@
  "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
  "dev": true,
  "license": "Apache-2.0",
  "peer": true,
  "bin": {
    "tsc": "bin/tsc",
    "tsserver": "bin/tsserver"

@@ -2998,6 +3006,7 @@
  "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
  "dev": true,
  "license": "MIT",
  "peer": true,
  "dependencies": {
    "esbuild": "^0.25.0",
    "fdir": "^6.4.4",
dashboard/parts.nix (new file, 44 lines)

@@ -0,0 +1,44 @@
{ inputs, ... }:
{
  perSystem =
    { pkgs, lib, ... }:
    let
      # Filter source to only include dashboard directory
      src = lib.cleanSourceWith {
        src = inputs.self;
        filter =
          path: type:
          let
            baseName = builtins.baseNameOf path;
            inDashboardDir =
              (lib.hasInfix "/dashboard/" path)
              || (lib.hasSuffix "/dashboard" (builtins.dirOf path))
              || (baseName == "dashboard" && type == "directory");
          in
          inDashboardDir;
      };

      # Build the dashboard with dream2nix (includes node_modules in output)
      dashboardFull = inputs.dream2nix.lib.evalModules {
        packageSets.nixpkgs = pkgs;
        modules = [
          ./dashboard.nix
          {
            paths.projectRoot = inputs.self;
            paths.projectRootFile = "flake.nix";
            paths.package = inputs.self + "/dashboard";
          }
          # Inject the filtered source
          {
            deps.dashboardSrc = lib.mkForce "${src}/dashboard";
          }
        ];
      };
    in
    {
      # Extract just the static site from the full build
      packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
        cp -r ${dashboardFull}/build $out
      '';
    };
}
@@ -60,12 +60,39 @@
  return models;
});

// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();

// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
  const models = availableModels();
  if (models.length > 0 && !currentModel) {
    setSelectedChatModel(models[0].id);
  const currentModelIds = new Set(models.map(m => m.id));

  if (models.length > 0) {
    // Find newly added models (in current but not in previous)
    const newModels = models.filter(m => !previousModelIds.has(m.id));

    // If no model selected, select the first available
    if (!currentModel) {
      setSelectedChatModel(models[0].id);
    }
    // If current model is stale (no longer has a running instance), reset to first available
    else if (!models.some(m => m.id === currentModel)) {
      setSelectedChatModel(models[0].id);
    }
    // If a new model was just added, select it
    else if (newModels.length > 0 && previousModelIds.size > 0) {
      setSelectedChatModel(newModels[0].id);
    }
  } else {
    // No instances running - clear the selected model
    if (currentModel) {
      setSelectedChatModel('');
    }
  }

  // Update previous model IDs for next comparison
  previousModelIds = currentModelIds;
});

function getInstanceModelId(instanceWrapped: unknown): string {
@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
  const nodeArray = nodeList();
  if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
  if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };

  const numNodes = nodeArray.length;
  const iconSize = numNodes === 1 ? 50 : 36;
@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';

interface Props {
  class?: string;

@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {

function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
  if (!ip) return { label: '?', missing: true };

  // Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
  const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

  // Helper to check a node's interfaces
  function checkNode(node: typeof data.nodes[string]): string | null {
  function checkNode(node: NodeInfo | undefined): string | null {
    if (!node) return null;

    const matchFromInterfaces = node.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );

@@ -39,17 +39,19 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
      return matchFromInterfaces.name;
    }

    const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
    if (mapped && mapped.trim().length > 0) {
      return mapped;
    if (node.ip_to_interface) {
      const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
      if (mapped && mapped.trim().length > 0) {
        return mapped;
      }
    }
    return null;
  }

  // Try specified node first
  const result = checkNode(data?.nodes?.[nodeId]);
  if (result) return { label: result, missing: false };

  // Fallback: search all nodes for this IP
  for (const [, otherNode] of Object.entries(data?.nodes || {})) {
    const otherResult = checkNode(otherNode);

@@ -255,21 +257,24 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');

const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
edges.forEach(edge => {
  if (!edge.source || !edge.target || edge.source === edge.target) return;
  if (!positionById[edge.source] || !positionById[edge.target]) return;

  const a = edge.source < edge.target ? edge.source : edge.target;
  const b = edge.source < edge.target ? edge.target : edge.source;
  const key = `${a}|${b}`;
  const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };

  if (edge.source === a) entry.aToB = true;
  else entry.bToA = true;

  const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
  const ip = edge.sendBackIp || '?';
  const ifaceInfo = getInterfaceLabel(edge.source, ip);
  entry.connections.push({
    from: edge.source,

@@ -338,9 +343,8 @@ function wrapLine(text: string, maxLen: number): string[] {
  // Determine which side of viewport based on edge midpoint
  const isLeft = mx < centerX;
  const isTop = my < safeCenterY;

  // Store for batch rendering after all edges processed
  if (!debugEdgeLabels) debugEdgeLabels = [];
  debugEdgeLabels.push({
    connections: entry.connections,
    isLeft,

@@ -381,32 +385,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}

// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, typeof debugEdgeLabels> = {
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
  topLeft: [],
  topRight: [],
  bottomLeft: [],
  bottomRight: []
};

debugEdgeLabels.forEach(edge => {
  const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
  quadrants[key].push(edge);
});

// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, edges]) => {
  if (edges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
  if (quadrantEdges.length === 0) return;

  const isLeft = quadrant.includes('Left');
  const isTop = quadrant.includes('top');

  let baseX = isLeft ? padding : width - padding;
  let baseY = isTop ? padding : height - padding;
  const textAnchor = isLeft ? 'start' : 'end';

  let currentY = baseY;

  edges.forEach(edge => {
  quadrantEdges.forEach(edge => {
    edge.connections.forEach(conn => {
      const arrow = getArrow(conn.from, conn.to);
      const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
@@ -99,20 +99,36 @@ interface RawNodeProfile {

interface RawTopologyNode {
  nodeId: string;
  nodeProfile: RawNodeProfile;
  nodeProfile?: RawNodeProfile;
}

interface RawTopologyConnection {
  localNodeId: string;
  sendBackNodeId: string;
  sendBackMultiaddr?:
    | { multiaddr?: string; address?: string; ip_address?: string }
    | string;
// New connection edge types from Python SocketConnection/RDMAConnection
interface RawSocketConnection {
  sinkMultiaddr?: {
    address?: string;
    // Multiaddr uses snake_case (no camelCase alias)
    ip_address?: string;
    ipAddress?: string; // fallback in case it changes
    address_type?: string;
    port?: number;
  };
}

interface RawRDMAConnection {
  sourceRdmaIface?: string;
  sinkRdmaIface?: string;
}

type RawConnectionEdge = RawSocketConnection | RawRDMAConnection;

// New nested mapping format: { source: { sink: [edge1, edge2, ...] } }
type RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;

interface RawTopology {
  nodes: RawTopologyNode[];
  connections?: RawTopologyConnection[];
  // nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
  nodes: (string | RawTopologyNode)[];
  // New nested mapping format
  connections?: RawConnectionsMap;
}

type RawNodeProfiles = Record<string, RawNodeProfile>;

@@ -213,9 +229,18 @@ function transformTopology(
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];

// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
  const mergedProfile = profiles?.[node.nodeId];
  const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
  // Determine the node ID - could be a string or an object with nodeId property
  const nodeId = typeof node === "string" ? node : node.nodeId;
  if (!nodeId) continue;

  // Get the profile - from the separate profiles map or from the node object itself
  const profileFromMap = profiles?.[nodeId];
  const profileFromNode =
    typeof node === "object" ? node.nodeProfile : undefined;
  const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };

  const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
  const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
  const ramUsage = Math.max(ramTotal - ramAvailable, 0);

@@ -264,7 +289,7 @@ function transformTopology(
    }
  }

  nodes[node.nodeId] = {
  nodes[nodeId] = {
    system_info: {
      model_id: profile?.modelId ?? "Unknown",
      chip: profile?.chipId,

@@ -292,29 +317,34 @@ function transformTopology(
  };
}

for (const conn of raw.connections || []) {
  if (!conn.localNodeId || !conn.sendBackNodeId) continue;
  if (conn.localNodeId === conn.sendBackNodeId) continue;
  if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
// Handle connections - nested mapping format { source: { sink: [edges] } }
const connections = raw.connections;
if (connections && typeof connections === "object") {
  for (const [source, sinks] of Object.entries(connections)) {
    if (!sinks || typeof sinks !== "object") continue;
    for (const [sink, edgeList] of Object.entries(sinks)) {
      if (!Array.isArray(edgeList)) continue;
      for (const edge of edgeList) {
        // Extract IP from SocketConnection (uses snake_case: ip_address)
        let sendBackIp: string | undefined;
        if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
          const multiaddr = edge.sinkMultiaddr;
          if (multiaddr) {
            // Try both snake_case (actual) and camelCase (in case it changes)
            sendBackIp =
              multiaddr.ip_address ||
              multiaddr.ipAddress ||
              extractIpFromMultiaddr(multiaddr.address);
          }
        }
        // RDMAConnection (sourceRdmaIface/sinkRdmaIface) has no IP - edge just shows connection exists

  let sendBackIp: string | undefined;
  if (conn.sendBackMultiaddr) {
    const multi = conn.sendBackMultiaddr;
    if (typeof multi === "string") {
      sendBackIp = extractIpFromMultiaddr(multi);
    } else {
      sendBackIp =
        multi.ip_address ||
        extractIpFromMultiaddr(multi.multiaddr) ||
        extractIpFromMultiaddr(multi.address);
        if (nodes[source] && nodes[sink] && source !== sink) {
          edges.push({ source, target: sink, sendBackIp });
        }
      }
    }
  }
}

  edges.push({
    source: conn.localNodeId,
    target: conn.sendBackNodeId,
    sendBackIp,
  });
}

return { nodes, edges };
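A compact Python model of the traversal above (hypothetical sample data; field names follow the diff): connections map source to sink to a list of edges, where socket edges carry a sinkMultiaddr with a snake_case ip_address and RDMA edges carry no IP at all.

# Hypothetical walk over the nested { source: { sink: [edge, ...] } } mapping.
connections = {
    "node-a": {"node-b": [{"sinkMultiaddr": {"ip_address": "192.168.1.7", "port": 52415}}]},
    "node-b": {"node-a": [{"sourceRdmaIface": "ib0", "sinkRdmaIface": "ib1"}]},
}

edges = []
for source, sinks in connections.items():
    for sink, edge_list in sinks.items():
        for edge in edge_list:
            multiaddr = edge.get("sinkMultiaddr") or {}
            ip = multiaddr.get("ip_address")  # None for RDMA edges
            if source != sink:
                edges.append({"source": source, "target": sink, "sendBackIp": ip})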
@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  const errorText = await response.text();
  console.error('Failed to launch instance:', errorText);
} else {
  // Auto-select the launched model only if no model is currently selected
  if (!selectedChatModel()) {
    setSelectedChatModel(modelId);
  }
  // Always auto-select the newly launched model so the user chats to what they just launched
  setSelectedChatModel(modelId);

  // Scroll to the bottom of instances container to show the new instance
  // Use multiple attempts to ensure DOM has updated with the new instance

@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
  if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;

  // Get the model ID of the instance being deleted before we delete it
  const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
  const wasSelected = selectedChatModel() === deletedInstanceModelId;

  try {
    const response = await fetch(`/instance/${instanceId}`, {
      method: 'DELETE',

@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

    if (!response.ok) {
      console.error('Failed to delete instance:', response.status);
    } else if (wasSelected) {
      // If we deleted the currently selected model, switch to another available model
      // Find another instance that isn't the one we just deleted
      const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
      if (remainingInstances.length > 0) {
        // Select the last instance (most recently added, since objects preserve insertion order)
        const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
        const newModelId = getInstanceModelId(lastInstance);
        if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
          setSelectedChatModel(newModelId);
        } else {
          // Clear selection if no valid model found
          setSelectedChatModel('');
        }
      } else {
        // No more instances, clear the selection
        setSelectedChatModel('');
      }
    }
  } catch (error) {
    console.error('Error deleting instance:', error);

@@ -895,7 +915,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
  const [tag, shard] = getTagged(shardWrapped);
  const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
  const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
  const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
  return { runnerId, tag, deviceRank };
});
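The reselection rule in deleteInstance leans on insertion order, which Python dicts share with JavaScript objects; a hedged sketch with hypothetical instance data:

# Hypothetical instance map; like JS objects, Python dicts preserve insertion
# order, so the last remaining entry is the most recently launched instance.
instances = {"a1": "llama-3", "b2": "qwen-2.5", "c3": "gpt-oss-20b"}
deleted_id = "b2"

remaining = {k: v for k, v in instances.items() if k != deleted_id}
new_model = next(reversed(remaining.values()), "")  # "" clears the selection
assert new_model == "gpt-oss-20b"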
flake.lock (generated; 162 lines changed)

@@ -1,5 +1,42 @@
{
  "nodes": {
    "crane": {
      "locked": {
        "lastModified": 1767744144,
        "narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
        "owner": "ipetkov",
        "repo": "crane",
        "rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
        "type": "github"
      },
      "original": {
        "owner": "ipetkov",
        "repo": "crane",
        "type": "github"
      }
    },
    "dream2nix": {
      "inputs": {
        "nixpkgs": [
          "nixpkgs"
        ],
        "purescript-overlay": "purescript-overlay",
        "pyproject-nix": "pyproject-nix"
      },
      "locked": {
        "lastModified": 1765953015,
        "narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
        "owner": "nix-community",
        "repo": "dream2nix",
        "rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
        "type": "github"
      },
      "original": {
        "owner": "nix-community",
        "repo": "dream2nix",
        "type": "github"
      }
    },
    "fenix": {
      "inputs": {
        "nixpkgs": [

@@ -8,11 +45,11 @@
        "rust-analyzer-src": "rust-analyzer-src"
      },
      "locked": {
        "lastModified": 1761893049,
        "narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
        "lastModified": 1768287139,
        "narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
        "owner": "nix-community",
        "repo": "fenix",
        "rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
        "rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
        "type": "github"
      },
      "original": {

@@ -21,6 +58,22 @@
        "type": "github"
      }
    },
    "flake-compat": {
      "flake": false,
      "locked": {
        "lastModified": 1696426674,
        "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-parts": {
      "inputs": {
        "nixpkgs-lib": [

@@ -43,11 +96,11 @@
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1761672384,
        "narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
        "lastModified": 1768127708,
        "narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
        "rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
        "type": "github"
      },
      "original": {

@@ -57,22 +110,85 @@
        "type": "github"
      }
    },
    "nixpkgs-swift": {
      "locked": {
        "lastModified": 1761672384,
        "narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
        "type": "github"
      }
    },
    "purescript-overlay": {
      "inputs": {
        "flake-compat": "flake-compat",
        "nixpkgs": [
          "dream2nix",
          "nixpkgs"
        ],
        "slimlock": "slimlock"
      },
      "locked": {
        "lastModified": 1728546539,
        "narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
        "owner": "thomashoneyman",
        "repo": "purescript-overlay",
        "rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
        "type": "github"
      },
      "original": {
        "owner": "thomashoneyman",
        "repo": "purescript-overlay",
        "type": "github"
      }
    },
    "pyproject-nix": {
      "inputs": {
        "nixpkgs": [
          "dream2nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1763017646,
        "narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
        "owner": "pyproject-nix",
        "repo": "pyproject.nix",
        "rev": "47bd6f296502842643078d66128f7b5e5370790c",
        "type": "github"
      },
      "original": {
        "owner": "pyproject-nix",
        "repo": "pyproject.nix",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "crane": "crane",
        "dream2nix": "dream2nix",
        "fenix": "fenix",
        "flake-parts": "flake-parts",
        "nixpkgs": "nixpkgs",
        "nixpkgs-swift": "nixpkgs-swift",
        "treefmt-nix": "treefmt-nix"
      }
    },
    "rust-analyzer-src": {
      "flake": false,
      "locked": {
        "lastModified": 1761849405,
        "narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
        "lastModified": 1768224240,
        "narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
        "owner": "rust-lang",
        "repo": "rust-analyzer",
        "rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
        "rev": "725349602e525df37f377701e001fe8aab807878",
        "type": "github"
      },
      "original": {

@@ -82,6 +198,28 @@
        "type": "github"
      }
    },
    "slimlock": {
      "inputs": {
        "nixpkgs": [
          "dream2nix",
          "purescript-overlay",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1688756706,
        "narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
        "owner": "thomashoneyman",
        "repo": "slimlock",
        "rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
        "type": "github"
      },
      "original": {
        "owner": "thomashoneyman",
        "repo": "slimlock",
        "type": "github"
      }
    },
    "treefmt-nix": {
      "inputs": {
        "nixpkgs": [

@@ -89,11 +227,11 @@
        ]
      },
      "locked": {
        "lastModified": 1762938485,
        "narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
        "lastModified": 1768158989,
        "narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
        "owner": "numtide",
        "repo": "treefmt-nix",
        "rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
        "rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
        "type": "github"
      },
      "original": {
flake.nix (63 lines changed)

@@ -9,6 +9,8 @@
      inputs.nixpkgs-lib.follows = "nixpkgs";
    };

    crane.url = "github:ipetkov/crane";

    fenix = {
      url = "github:nix-community/fenix";
      inputs.nixpkgs.follows = "nixpkgs";

@@ -18,6 +20,14 @@
      url = "github:numtide/treefmt-nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    dream2nix = {
      url = "github:nix-community/dream2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    # Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
    nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
  };

  nixConfig = {

@@ -36,12 +46,16 @@

      imports = [
        inputs.treefmt-nix.flakeModule
        ./dashboard/parts.nix
        ./rust/parts.nix
      ];

      perSystem =
        { config, inputs', pkgs, lib, ... }:
        { config, self', inputs', pkgs, lib, system, ... }:
        let
          fenixToolchain = inputs'.fenix.packages.complete;
          # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
        in
        {
          treefmt = {

@@ -54,13 +68,16 @@
            };
            rustfmt = {
              enable = true;
              package = fenixToolchain.rustfmt;
              package = config.rust.toolchain;
            };
            prettier = {
              enable = true;
              includes = [ "*.ts" ];
            };
            swift-format.enable = true;
            swift-format = {
              enable = true;
              package = pkgsSwift.swiftPackages.swift-format;
            };
          };
        };

@@ -71,6 +88,8 @@
          '';

          devShells.default = with pkgs; pkgs.mkShell {
            inputsFrom = [ self'.checks.cargo-build ];

            packages =
              [
                # FORMATTING

@@ -83,14 +102,8 @@
                basedpyright

                # RUST
                (fenixToolchain.withComponents [
                  "cargo"
                  "rustc"
                  "clippy"
                  "rustfmt"
                  "rust-src"
                ])
                rustup # Just here to make RustRover happy
                config.rust.toolchain
                maturin

                # NIX
                nixpkgs-fmt

@@ -102,30 +115,20 @@
                just
                jq
              ]
              ++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
                # IFCONFIG
              ++ lib.optionals stdenv.isLinux [
                unixtools.ifconfig

                # Build dependencies for Linux
                pkg-config
                openssl
              ])
              ++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
                # MACMON
              ]
              ++ lib.optionals stdenv.isDarwin [
                macmon
              ]);
              ];

            OPENSSL_NO_VENDOR = "1";

            shellHook = ''
              # PYTHON
              export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
              ${lib.optionalString pkgs.stdenv.isLinux ''
                # Build environment for Linux
                export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
                export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
              export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
              ${lib.optionalString stdenv.isLinux ''
                export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
              ''}
              echo
              echo "🍎🍎 Run 'just <recipe>' to get started"
              just --list
            '';
          };
        };
justfile (2 lines changed)

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"

fmt:
    nix fmt
pyproject.toml

@@ -23,6 +23,7 @@ dependencies = [
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
    "httpx>=0.28.1",
]

[project.scripts]
rust/parts.nix (new file, 145 lines)

@@ -0,0 +1,145 @@
{ inputs, ... }:
{
  perSystem =
    { config, self', inputs', pkgs, lib, ... }:
    let
      # Fenix nightly toolchain with all components
      fenixPkgs = inputs'.fenix.packages;
      rustToolchain = fenixPkgs.complete.withComponents [
        "cargo"
        "rustc"
        "clippy"
        "rustfmt"
        "rust-src"
        "rust-analyzer"
      ];

      # Crane with fenix toolchain
      craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;

      # Source filtering - only include rust/ directory and root Cargo files
      # This ensures changes to Python/docs/etc don't trigger Rust rebuilds
      src = lib.cleanSourceWith {
        src = inputs.self;
        filter =
          path: type:
          let
            baseName = builtins.baseNameOf path;
            parentDir = builtins.dirOf path;
            inRustDir =
              (lib.hasInfix "/rust/" path)
              || (lib.hasSuffix "/rust" parentDir)
              || (baseName == "rust" && type == "directory");
            isRootCargoFile =
              (baseName == "Cargo.toml" || baseName == "Cargo.lock")
              && (builtins.dirOf path == toString inputs.self);
          in
          isRootCargoFile
          || (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
      };

      # Common arguments for all Rust builds
      commonArgs = {
        inherit src;
        pname = "exo-rust";
        version = "0.0.1";
        strictDeps = true;

        nativeBuildInputs = [
          pkgs.pkg-config
          pkgs.python313 # Required for pyo3-build-config
        ];

        buildInputs = [
          pkgs.openssl
          pkgs.python313 # Required for pyo3 tests
        ];

        OPENSSL_NO_VENDOR = "1";

        # Required for pyo3 tests to find libpython
        LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
      };

      # Build dependencies once for caching
      cargoArtifacts = craneLib.buildDepsOnly (
        commonArgs
        // {
          cargoExtraArgs = "--workspace";
        }
      );
    in
    {
      # Export toolchain for use in treefmt and devShell
      options.rust = {
        toolchain = lib.mkOption {
          type = lib.types.package;
          default = rustToolchain;
          description = "The Rust toolchain to use";
        };
      };

      config = {
        packages = {
          # Python bindings wheel via maturin
          exo_pyo3_bindings = craneLib.buildPackage (
            commonArgs
            // {
              inherit cargoArtifacts;
              pname = "exo_pyo3_bindings";

              nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
                pkgs.maturin
              ];

              buildPhaseCargoCommand = ''
                maturin build \
                  --release \
                  --manylinux off \
                  --manifest-path rust/exo_pyo3_bindings/Cargo.toml \
                  --features "pyo3/extension-module,pyo3/experimental-async" \
                  --interpreter ${pkgs.python313}/bin/python \
                  --out dist
              '';

              # Don't use crane's default install behavior
              doNotPostBuildInstallCargoBinaries = true;

              installPhaseCommand = ''
                mkdir -p $out
                cp dist/*.whl $out/
              '';
            }
          );
        };

        checks = {
          # Full workspace build (all crates)
          cargo-build = craneLib.buildPackage (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );
          # Run tests with nextest
          cargo-nextest = craneLib.cargoNextest (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );

          # Build documentation
          cargo-doc = craneLib.cargoDoc (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );
        };
      };
    };
}
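The source filter's intent, restated as a Python predicate (an illustrative approximation; craneLib.filterCargoSources is approximated here by a suffix check): a path survives only if it is a root Cargo file, or lives under rust/ and looks cargo-relevant, so Python or docs changes don't invalidate the Rust build cache.

from pathlib import Path

def keep_for_rust_build(path: Path, repo_root: Path) -> bool:
    # Approximates the rust/parts.nix filter: keep root Cargo.toml/Cargo.lock,
    # plus .rs/.toml/.md files (and lockfiles) under the rust/ directory.
    rel = path.relative_to(repo_root)
    is_root_cargo = len(rel.parts) == 1 and rel.name in ("Cargo.toml", "Cargo.lock")
    in_rust_dir = rel.parts[:1] == ("rust",)
    cargo_like = rel.suffix in (".rs", ".toml", ".md") or rel.name == "Cargo.lock"
    return is_root_cargo or (in_rust_dir and cargo_like)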
rust/system_custodian/Cargo.toml (deleted)

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"

[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false

[lints]
workspace = true

[dependencies]
# datastructures
either = { workspace = true }

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }

# tracing/logging
log = { workspace = true }
rust/system_custodian/src/bin/main.rs (deleted)

@@ -1,4 +0,0 @@
//! TODO: documentation
//!

fn main() {}
rust/system_custodian/src/lib.rs (deleted)

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and is responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of the Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//!    logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//!    address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//!    interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
//!    over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//!    ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//!    handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//!    e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//!    for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//!    then this would be the place to carry out discovery and proper handshakes with devices
//!    on the other end of the link.
//!

// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]

pub(crate) mod private {
    // sealed traits support
    pub trait Sealed {}
    impl<T: ?Sized> Sealed for T {}
}

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}

/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}
@@ -13,12 +13,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -67,8 +61,6 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
@@ -236,6 +228,7 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -291,6 +284,7 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -381,35 +375,8 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
@@ -417,16 +384,10 @@ class API:
             self._chat_completion_queues[command_id], recv = channel[TokenChunk]()

             with recv as token_chunks:
-                if parse_gpt_oss:
-                    async for chunk in self._process_gpt_oss(token_chunks):
-                        yield chunk
-                        if chunk.finish_reason is not None:
-                            break
-                else:
-                    async for chunk in token_chunks:
-                        yield chunk
-                        if chunk.finish_reason is not None:
-                            break
+                async for chunk in token_chunks:
+                    yield chunk
+                    if chunk.finish_reason is not None:
+                        break

         except anyio.get_cancelled_exc_class():
             # TODO: TaskCancelled
@@ -442,11 +403,11 @@ class API:
         del self._chat_completion_queues[command_id]

     async def _generate_chat_stream(
-        self, command_id: CommandId, parse_gpt_oss: bool
+        self, command_id: CommandId
     ) -> AsyncGenerator[str, None]:
         """Generate chat completion stream as JSON strings."""

-        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
+        async for chunk in self._chat_chunk_stream(command_id):
             chunk_response: ChatCompletionResponse = chunk_to_response(
                 chunk, command_id
             )
@@ -458,7 +419,7 @@ class API:
         yield "data: [DONE]\n\n"

     async def _collect_chat_completion(
-        self, command_id: CommandId, parse_gpt_oss: bool
+        self, command_id: CommandId
     ) -> ChatCompletionResponse:
         """Collect all token chunks for a chat completion and return a single response."""

@@ -466,7 +427,7 @@ class API:
         model: str | None = None
         finish_reason: FinishReason | None = None

-        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
+        async for chunk in self._chat_chunk_stream(command_id):
             if model is None:
                 model = chunk.model

@@ -495,7 +456,7 @@ class API:
         )

     async def _collect_chat_completion_with_stats(
-        self, command_id: CommandId, parse_gpt_oss: bool
+        self, command_id: CommandId
     ) -> BenchChatCompletionResponse:
         text_parts: list[str] = []
         model: str | None = None
@@ -503,7 +464,7 @@ class API:

         stats: GenerationStats | None = None

-        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
+        async for chunk in self._chat_chunk_stream(command_id):
             if model is None:
                 model = chunk.model

@@ -544,8 +505,6 @@ class API:
         """Handle chat completions, supporting both streaming and non-streaming responses."""
         model_meta = await resolve_model_meta(payload.model)
         payload.model = model_meta.model_id
-        parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
-        logger.info(f"{parse_gpt_oss=}")

         if not any(
             instance.shard_assignments.model_id == payload.model
@@ -562,17 +521,16 @@ class API:
         await self._send(command)
         if payload.stream:
             return StreamingResponse(
-                self._generate_chat_stream(command.command_id, parse_gpt_oss),
+                self._generate_chat_stream(command.command_id),
                 media_type="text/event-stream",
             )

-        return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
+        return await self._collect_chat_completion(command.command_id)

     async def bench_chat_completions(
         self, payload: BenchChatCompletionTaskParams
     ) -> BenchChatCompletionResponse:
         model_meta = await resolve_model_meta(payload.model)
-        parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
         payload.model = model_meta.model_id

         if not any(
@@ -589,19 +547,15 @@ class API:
         command = ChatCompletion(request_params=payload)
         await self._send(command)

-        response = await self._collect_chat_completion_with_stats(
-            command.command_id,
-            parse_gpt_oss,
-        )
+        response = await self._collect_chat_completion_with_stats(command.command_id)
         return response

     def _calculate_total_available_memory(self) -> Memory:
         """Calculate total available memory across all nodes in bytes."""
         total_available = Memory()

-        for node in self.state.topology.list_nodes():
-            if node.node_profile is not None:
-                total_available += node.node_profile.memory.ram_available
+        for profile in self.state.node_profiles.values():
+            total_available += profile.memory.ram_available

         return total_available

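The thread running through these API hunks: per-node profile data now lives in a dedicated `state.node_profiles` mapping keyed by `NodeId`, rather than as an optional field on topology nodes, so aggregations no longer need a `None` guard per node. A toy sketch of the before/after shape (plain ints stand in for the `Memory` type):

```python
from typing import Mapping, Optional


def total_ram_old(nodes: list[tuple[str, Optional[int]]]) -> int:
    # Old shape: profiles hang off topology nodes and may be missing.
    return sum(ram for _, ram in nodes if ram is not None)


def total_ram_new(profiles: Mapping[str, int]) -> int:
    # New shape: a node either has a profile entry or is absent from the map.
    return sum(profiles.values())
```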
@@ -27,6 +27,7 @@ from exo.shared.types.events import (
     ForwarderEvent,
     IndexedEvent,
     InstanceDeleted,
+    NodeGatheredInfo,
     NodeTimedOut,
     TaskCreated,
     TaskDeleted,
@@ -158,6 +159,7 @@ class Master:
                 command,
                 self.state.topology,
                 self.state.instances,
+                self.state.node_profiles,
             )
             transition_events = get_transition_events(
                 self.state.instances, placement
@@ -200,9 +202,7 @@ class Master:
     async def _plan(self) -> None:
         while True:
             # kill broken instances
-            connected_node_ids = set(
-                [x.node_id for x in self.state.topology.list_nodes()]
-            )
+            connected_node_ids = set(self.state.topology.list_nodes())
             for instance_id, instance in self.state.instances.items():
                 for node_id in instance.shard_assignments.node_to_runner:
                     if node_id not in connected_node_ids:
@@ -237,6 +237,8 @@ class Master:
         self.state = apply(self.state, indexed)

         event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
+        if isinstance(event, NodeGatheredInfo):
+            event.when = str(datetime.now(tz=timezone.utc))

         self._event_log.append(event)
         await self._send_event(indexed)

@@ -6,9 +6,10 @@ from typing import Sequence
 from loguru import logger

 from exo.master.placement_utils import (
+    Cycle,
     filter_cycles_by_memory,
-    get_mlx_ibv_devices_matrix,
     get_mlx_jaccl_coordinators,
+    get_mlx_jaccl_devices_matrix,
     get_mlx_ring_hosts_by_node,
     get_shard_assignments,
     get_smallest_cycles,
@@ -19,10 +20,11 @@ from exo.shared.types.commands import (
     DeleteInstance,
     PlaceInstance,
 )
 from exo.shared.types.common import NodeId
 from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
+from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId
-from exo.shared.types.topology import NodeInfo
+from exo.shared.types.profiling import NodePerformanceProfile
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -52,19 +54,14 @@ def place_instance(
     command: PlaceInstance,
     topology: Topology,
     current_instances: Mapping[InstanceId, Instance],
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> dict[InstanceId, Instance]:
-    all_nodes = list(topology.list_nodes())
-
-    logger.info("finding cycles:")
     cycles = topology.get_cycles()
-    singleton_cycles = [[node] for node in all_nodes]
-    candidate_cycles = list(
-        filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
-    )
+    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
     cycles_with_sufficient_memory = filter_cycles_by_memory(
-        candidate_cycles, command.model_meta.storage_size
+        candidate_cycles, node_profiles, command.model_meta.storage_size
     )
-    if not cycles_with_sufficient_memory:
+    if len(cycles_with_sufficient_memory) == 0:
         raise ValueError("No cycles found with sufficient memory")

     if command.sharding == Sharding.Tensor:
@@ -92,44 +89,38 @@ def place_instance(
     smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)

     smallest_tb_cycles = [
-        cycle
-        for cycle in smallest_cycles
-        if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
+        cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
     ]

     if smallest_tb_cycles != []:
         smallest_cycles = smallest_tb_cycles

-    cycles_with_leaf_nodes: list[list[NodeInfo]] = [
+    cycles_with_leaf_nodes: list[Cycle] = [
         cycle
         for cycle in smallest_cycles
-        if any(topology.node_is_leaf(node.node_id) for node in cycle)
+        if any(topology.node_is_leaf(node_id) for node_id in cycle)
     ]

     selected_cycle = max(
         cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
         key=lambda cycle: sum(
-            (
-                node.node_profile.memory.ram_available
-                for node in cycle
-                if node.node_profile is not None
-            ),
+            (node_profiles[node_id].memory.ram_available for node_id in cycle),
             start=Memory(),
         ),
     )

     shard_assignments = get_shard_assignments(
-        command.model_meta, selected_cycle, command.sharding
+        command.model_meta, selected_cycle, command.sharding, node_profiles
     )

-    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
+    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)

     instance_id = InstanceId()
     target_instances = dict(deepcopy(current_instances))

     if len(selected_cycle) == 1:
         logger.warning(
-            "You have likely selected ibv for a single node instance; falling back to MlxRing"
+            "You have likely selected jaccl for a single node instance; falling back to MlxRing"
         )

         command.instance_meta = InstanceMeta.MlxRing
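To summarise the selection pipeline `place_instance` now implements: keep cycles meeting `min_nodes` whose profiled nodes jointly fit the model, shrink to the smallest such cycles, prefer all-Thunderbolt cycles, prefer cycles containing a leaf node, then break ties by total available RAM. A compact sketch of that ordering under simplified types (a cycle is a list of node ids, RAM is a plain int per node):

```python
from typing import Callable, Mapping, Sequence

Cycle = list[str]


def select_cycle(
    cycles: Sequence[Cycle],
    ram: Mapping[str, int],
    required: int,
    min_nodes: int,
    is_tb_cycle: Callable[[Cycle], bool],  # all links Thunderbolt?
    is_leaf: Callable[[str], bool],
) -> Cycle:
    candidates = [
        c for c in cycles
        if len(c) >= min_nodes
        and all(n in ram for n in c)
        and sum(ram[n] for n in c) >= required
    ]
    if len(candidates) == 0:
        raise ValueError("No cycles found with sufficient memory")
    smallest = [c for c in candidates if len(c) == min(map(len, candidates))]
    tb = [c for c in smallest if is_tb_cycle(c)]
    if tb:
        smallest = tb
    leafy = [c for c in smallest if any(is_leaf(n) for n in c)]
    return max(leafy or smallest, key=lambda c: sum(ram[n] for n in c))
```

Note that dropping the explicit singleton-cycle fallback means single nodes must now surface through `topology.get_cycles()` itself, which the updated tests confirm by filtering singletons back out.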
@@ -137,19 +128,20 @@ def place_instance(
     # TODO: Single node instances
     match command.instance_meta:
         case InstanceMeta.MlxJaccl:
-            mlx_ibv_devices = get_mlx_ibv_devices_matrix(
-                selected_cycle,
+            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
+                [node_id for node_id in selected_cycle],
                 cycle_digraph,
             )
             mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
-                selected_cycle,
+                coordinator=selected_cycle.node_ids[0],
                 coordinator_port=random_ephemeral_port(),
                 cycle_digraph=cycle_digraph,
+                node_profiles=node_profiles,
             )
             target_instances[instance_id] = MlxJacclInstance(
                 instance_id=instance_id,
                 shard_assignments=shard_assignments,
-                ibv_devices=mlx_ibv_devices,
+                jaccl_devices=mlx_jaccl_devices,
                 jaccl_coordinators=mlx_jaccl_coordinators,
             )
         case InstanceMeta.MlxRing:
@@ -158,6 +150,7 @@ def place_instance(
                 selected_cycle=selected_cycle,
                 cycle_digraph=cycle_digraph,
                 ephemeral_port=ephemeral_port,
+                node_profiles=node_profiles,
             )
             target_instances[instance_id] = MlxRingInstance(
                 instance_id=instance_id,

@@ -1,15 +1,13 @@
-from collections.abc import Generator
-from typing import TypeGuard, cast
+from collections.abc import Generator, Mapping

 from loguru import logger
-from pydantic import BaseModel

 from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelMetadata
 from exo.shared.types.profiling import NodePerformanceProfile
-from exo.shared.types.topology import NodeInfo
+from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments
 from exo.shared.types.worker.shards import (
     PipelineShardMetadata,
@@ -19,58 +17,55 @@ from exo.shared.types.worker.shards import (
 )


-class NodeWithProfile(BaseModel):
-    node_id: NodeId
-    node_profile: NodePerformanceProfile
-
-
-def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
-    return all(node.node_profile is not None for node in nodes)
-
-
 def filter_cycles_by_memory(
-    cycles: list[list[NodeInfo]], required_memory: Memory
-) -> list[list[NodeInfo]]:
-    filtered_cycles: list[list[NodeInfo]] = []
+    cycles: list[Cycle],
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
+    required_memory: Memory,
+) -> list[Cycle]:
+    filtered_cycles: list[Cycle] = []
     for cycle in cycles:
-        if not narrow_all_nodes(cycle):
+        if not all(node in node_profiles for node in cycle):
             continue

         total_mem = sum(
-            (node.node_profile.memory.ram_available for node in cycle), start=Memory()
+            (node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
+            start=Memory(),
         )
         if total_mem >= required_memory:
-            filtered_cycles.append(cast(list[NodeInfo], cycle))
+            filtered_cycles.append(cycle)
     return filtered_cycles


-def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
+def get_smallest_cycles(
+    cycles: list[Cycle],
+) -> list[Cycle]:
     min_nodes = min(len(cycle) for cycle in cycles)
     return [cycle for cycle in cycles if len(cycle) == min_nodes]


 def get_shard_assignments_for_pipeline_parallel(
     model_meta: ModelMetadata,
-    selected_cycle: list[NodeWithProfile],
+    cycle: Cycle,
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ):
     cycle_memory = sum(
-        (node.node_profile.memory.ram_available for node in selected_cycle),
+        (node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
         start=Memory(),
     )
     total_layers = model_meta.n_layers
-    world_size = len(selected_cycle)
+    world_size = len(cycle)
     runner_to_shard: dict[RunnerId, ShardMetadata] = {}
     node_to_runner: dict[NodeId, RunnerId] = {}

     layers_assigned = 0
-    for i, node in enumerate(selected_cycle):
-        if i == len(selected_cycle) - 1:
+    for i, node_id in enumerate(cycle):
+        if i == len(cycle) - 1:
             node_layers = total_layers - layers_assigned
         else:
             node_layers = round(
                 total_layers
                 * (
-                    node.node_profile.memory.ram_available.in_bytes
+                    node_profiles[node_id].memory.ram_available.in_bytes
                     / cycle_memory.in_bytes
                 )
             )
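The pipeline split above is proportional to available RAM: every node except the last gets `round(total_layers * node_ram / cycle_ram)` layers, and the last node absorbs rounding drift so the shards always sum to `n_layers`. The arithmetic in isolation:

```python
def split_layers(total_layers: int, ram_per_node: list[int]) -> list[int]:
    """Proportional layer split; the last node takes the remainder."""
    cycle_ram = sum(ram_per_node)
    layers: list[int] = []
    assigned = 0
    for i, ram in enumerate(ram_per_node):
        if i == len(ram_per_node) - 1:
            n = total_layers - assigned  # absorb rounding drift
        else:
            n = round(total_layers * ram / cycle_ram)
        layers.append(n)
        assigned += n
    return layers


# 12 layers over RAM of [400, 400, 800]: 25% / 25% / remainder
assert split_layers(12, [400, 400, 800]) == [3, 3, 6]
```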
@@ -88,7 +83,7 @@ def get_shard_assignments_for_pipeline_parallel(
         )

         runner_to_shard[runner_id] = shard
-        node_to_runner[node.node_id] = runner_id
+        node_to_runner[node_id] = runner_id
         layers_assigned += node_layers

     shard_assignments = ShardAssignments(
@@ -102,14 +97,14 @@ def get_shard_assignments_for_pipeline_parallel(

 def get_shard_assignments_for_tensor_parallel(
     model_meta: ModelMetadata,
-    selected_cycle: list[NodeWithProfile],
+    cycle: Cycle,
 ):
     total_layers = model_meta.n_layers
-    world_size = len(selected_cycle)
+    world_size = len(cycle)
     runner_to_shard: dict[RunnerId, ShardMetadata] = {}
     node_to_runner: dict[NodeId, RunnerId] = {}

-    for i, node in enumerate(selected_cycle):
+    for i, node_id in enumerate(cycle):
         shard = TensorShardMetadata(
             model_meta=model_meta,
             device_rank=i,
@@ -122,7 +117,7 @@ def get_shard_assignments_for_tensor_parallel(
         runner_id = RunnerId()

         runner_to_shard[runner_id] = shard
-        node_to_runner[node.node_id] = runner_id
+        node_to_runner[node_id] = runner_id

     shard_assignments = ShardAssignments(
         model_id=model_meta.model_id,
@@ -135,21 +130,21 @@ def get_shard_assignments_for_tensor_parallel(

 def get_shard_assignments(
     model_meta: ModelMetadata,
-    selected_cycle: list[NodeInfo],
+    cycle: Cycle,
     sharding: Sharding,
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> ShardAssignments:
-    if not narrow_all_nodes(selected_cycle):
-        raise ValueError("All nodes must have profiles to create shard assignments")
     match sharding:
         case Sharding.Pipeline:
             return get_shard_assignments_for_pipeline_parallel(
                 model_meta=model_meta,
-                selected_cycle=selected_cycle,
+                cycle=cycle,
+                node_profiles=node_profiles,
             )
         case Sharding.Tensor:
             return get_shard_assignments_for_tensor_parallel(
                 model_meta=model_meta,
-                selected_cycle=selected_cycle,
+                cycle=cycle,
             )

@@ -164,38 +159,40 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
         )
         return []

+    cycle = cycles[0]
+
     get_thunderbolt = False
-    if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
+    if cycle_digraph.is_thunderbolt_cycle(cycle):
         get_thunderbolt = True

     logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")

-    cycle = cycles[0]
     hosts: list[Host] = []
     for i in range(len(cycle)):
-        current_node = cycle[i]
-        next_node = cycle[(i + 1) % len(cycle)]
+        current_node = cycle.node_ids[i]
+        next_node = cycle.node_ids[(i + 1) % len(cycle)]

-        for connection in cycle_digraph.list_connections():
-            if (
-                connection.local_node_id == current_node.node_id
-                and connection.send_back_node_id == next_node.node_id
-            ):
-                if get_thunderbolt and not connection.is_thunderbolt():
-                    continue
-                assert connection.send_back_multiaddr is not None
-                host = Host(
-                    ip=connection.send_back_multiaddr.ip_address,
-                    port=connection.send_back_multiaddr.port,
-                )
-                hosts.append(host)
-                break
+        for connection in cycle_digraph.get_all_connections_between(
+            source=current_node, sink=next_node
+        ):
+            if not isinstance(connection, SocketConnection):
+                continue
+
+            if get_thunderbolt and not connection.is_thunderbolt():
+                continue
+
+            host = Host(
+                ip=connection.sink_multiaddr.ip_address,
+                port=connection.sink_multiaddr.port,
+            )
+            hosts.append(host)
+            break

     return hosts


-def get_mlx_ibv_devices_matrix(
-    selected_cycle: list[NodeInfo],
+def get_mlx_jaccl_devices_matrix(
+    selected_cycle: list[NodeId],
     cycle_digraph: Topology,
 ) -> list[list[str | None]]:
     """Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -214,72 +211,37 @@ def get_mlx_ibv_devices_matrix(
             if i == j:
                 continue

-            # Find the IP J uses to talk to I
-            for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
-                # This is a local IP on I, which is attached to an interface: find that interface
-                if interface_name := _find_rdma_interface_name_for_ip(
-                    connection_ip, node_i
-                ):
-                    matrix[i][j] = interface_name
-                    logger.info(
-                        f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
-                    )
+            for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
+                if isinstance(conn, RDMAConnection):
+                    matrix[i][j] = conn.source_rdma_iface
                    break
            else:
                logger.warning(
-                    f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
+                    f"Failed to find interface name between {node_i} and {node_j}"
                )
                raise ValueError(
-                    "Current ibv backend requires all-to-all rdma connections"
+                    "Current jaccl backend requires all-to-all RDMA connections"
                )

     return matrix


 def _find_connection_ip(
-    node_i: NodeInfo,
-    node_j: NodeInfo,
+    node_i: NodeId,
+    node_j: NodeId,
     cycle_digraph: Topology,
 ) -> Generator[tuple[str, bool]]:
-    """Find all IP addresses that connect node i to node j, with thunderbolt flag."""
-    for connection in cycle_digraph.list_connections():
-        if (
-            connection.local_node_id == node_i.node_id
-            and connection.send_back_node_id == node_j.node_id
-        ):
-            yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
-
-
-def _find_rdma_interface_name_for_ip(
-    ip_address: str,
-    node_info: NodeInfo,
-) -> str | None:
-    if node_info.node_profile is None:
-        return None
-
-    logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
-    for interface in node_info.node_profile.network_interfaces:
-        if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
-            continue
-        logger.info(f"  | {interface.name}: {interface.ip_address}")
-        if interface.ip_address != ip_address:
-            continue
-
-        logger.info("Found")
-        return f"rdma_{interface.name}"
-
-    return None
+    """Find all IP addresses that connect node i to node j."""
+    for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
+        if isinstance(connection, SocketConnection):
+            yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()


 def _find_interface_name_for_ip(
-    ip_address: str,
-    node_info: NodeInfo,
+    ip_address: str, node_profile: NodePerformanceProfile
 ) -> str | None:
     """Find the interface name for an IP address on a node (any interface)."""
-    if node_info.node_profile is None:
-        return None
-
-    for interface in node_info.node_profile.network_interfaces:
+    for interface in node_profile.network_interfaces:
         if interface.ip_address == ip_address:
             return interface.name

@@ -287,7 +249,10 @@ def _find_interface_name_for_ip(


 def _find_ip_prioritised(
-    node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
+    node_id: NodeId,
+    other_node_id: NodeId,
+    cycle_digraph: Topology,
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> str | None:
     # TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
     """Find an IP address between nodes with prioritization.
@@ -298,9 +263,12 @@ def _find_ip_prioritised(
     3. Non-Thunderbolt connections
     4. Any other IP address
     """
-    ips = list(_find_connection_ip(node, other_node, cycle_digraph))
+    ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
     # We expect a unique iface -> ip mapping
-    iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
+    iface_map = {
+        _find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
+        for ip, _ in ips
+    }

     en0_ip = iface_map.get("en0")
     if en0_ip:
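The TODO above notes the intended priority is Ethernet > WiFi > non-Thunderbolt > Thunderbolt, while the current code checks `en0` first and falls through the numbered steps. A loose sketch of that intended ordering — the interface-name conventions here (`en0` wired/primary, `en1` WiFi) are macOS-typical assumptions, not guarantees:

```python
def pick_ip(candidates: list[tuple[str | None, str, bool]]) -> str | None:
    """candidates: (interface_name, ip, is_thunderbolt) tuples for one peer."""

    def priority(c: tuple[str | None, str, bool]) -> int:
        name, _ip, is_tb = c
        if name == "en0":  # usually Ethernet / primary interface on macOS
            return 0
        if name == "en1":  # often WiFi (assumption; varies per machine)
            return 1
        return 2 if not is_tb else 3  # non-Thunderbolt before Thunderbolt

    return min(candidates, key=priority)[1] if candidates else None
```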
@@ -324,9 +292,10 @@ def _find_ip_prioritised(


 def get_mlx_ring_hosts_by_node(
-    selected_cycle: list[NodeInfo],
+    selected_cycle: Cycle,
     cycle_digraph: Topology,
     ephemeral_port: int,
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> dict[NodeId, list[Host]]:
     """Generate per-node host lists for MLX ring backend.

@@ -341,14 +310,13 @@ def get_mlx_ring_hosts_by_node(

     hosts_by_node: dict[NodeId, list[Host]] = {}

-    for rank, node in enumerate(selected_cycle):
-        node_id = node.node_id
+    for rank, node_id in enumerate(selected_cycle):
         left_rank = (rank - 1) % world_size
         right_rank = (rank + 1) % world_size

         hosts_for_node: list[Host] = []

-        for idx, other_node in enumerate(selected_cycle):
+        for idx, other_node_id in enumerate(selected_cycle):
             if idx == rank:
                 hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
                 continue
@@ -358,10 +326,12 @@ def get_mlx_ring_hosts_by_node(
                 hosts_for_node.append(Host(ip="198.51.100.1", port=0))
                 continue

-            connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
+            connection_ip = _find_ip_prioritised(
+                node_id, other_node_id, cycle_digraph, node_profiles
+            )
             if connection_ip is None:
                 logger.warning(
-                    f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
+                    f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
                 )
                 raise ValueError(
                     "MLX ring backend requires connectivity between neighbouring nodes"
@@ -375,31 +345,34 @@ def get_mlx_ring_hosts_by_node(


 def get_mlx_jaccl_coordinators(
-    selected_cycle: list[NodeInfo],
+    coordinator: NodeId,
     coordinator_port: int,
     cycle_digraph: Topology,
+    node_profiles: Mapping[NodeId, NodePerformanceProfile],
 ) -> dict[NodeId, str]:
-    """Get the coordinator addresses for MLX Jaccl (rank 0 device).
+    """Get the coordinator addresses for MLX JACCL (rank 0 device).

     Select an IP address that each node can reach for the rank 0 node. Returns
     address in format "X.X.X.X:PORT" per node.
     """
-    rank_0_node = selected_cycle[0]
-    logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
+    logger.info(f"Selecting coordinator: {coordinator}")

-    def get_ip_for_node(n: NodeInfo) -> str:
-        if n.node_id == rank_0_node.node_id:
+    def get_ip_for_node(n: NodeId) -> str:
+        if n == coordinator:
             return "0.0.0.0"

-        ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
-        if ip:
+        ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
+        if ip is not None:
             return ip

         logger.warning(
-            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
+            f"Failed to find directly connected ip between {n} and {coordinator}"
        )
-        raise ValueError("Current ibv backend requires all-to-all rdma connections")
+        raise ValueError(
+            "Current jaccl backend requires all participating devices to be able to communicate"
+        )

     return {
-        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
+        n: f"{get_ip_for_node(n)}:{coordinator_port}"
+        for n in cycle_digraph.list_nodes()
     }

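The coordinator map this returns has a fixed shape: the coordinator node binds `0.0.0.0`, and every other node records an address it can actually reach the coordinator on. A toy sketch with a hypothetical `reach_ip` lookup standing in for `_find_ip_prioritised`:

```python
from typing import Callable, Iterable


def coordinator_addresses(
    coordinator: str,
    port: int,
    nodes: Iterable[str],
    reach_ip: Callable[[str, str], str | None],
) -> dict[str, str]:
    out: dict[str, str] = {}
    for n in nodes:
        if n == coordinator:
            out[n] = f"0.0.0.0:{port}"  # rank 0 binds all interfaces
            continue
        ip = reach_ip(n, coordinator)
        if ip is None:
            raise ValueError("all participating devices must reach the coordinator")
        out[n] = f"{ip}:{port}"
    return out


# two nodes, worker "b" reaching coordinator "a" over a link-local address
assert coordinator_addresses("a", 5000, ["a", "b"], lambda n, c: "169.254.0.5") == {
    "a": "0.0.0.0:5000",
    "b": "169.254.0.5:5000",
}
```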
@@ -1,67 +1,39 @@
-from typing import Callable
-
-import pytest
-
 from exo.shared.types.common import NodeId
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import (
-    MemoryPerformanceProfile,
+    MemoryUsage,
     NetworkInterfaceInfo,
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
-from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
+from exo.shared.types.topology import RDMAConnection, SocketConnection


-@pytest.fixture
-def create_node():
-    def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
-        if node_id is None:
-            node_id = NodeId()
-        return NodeInfo(
-            node_id=node_id,
-            node_profile=NodePerformanceProfile(
-                model_id="test",
-                chip_id="test",
-                friendly_name="test",
-                memory=MemoryPerformanceProfile.from_bytes(
-                    ram_total=1000,
-                    ram_available=memory,
-                    swap_total=1000,
-                    swap_available=1000,
-                ),
-                network_interfaces=[],
-                system=SystemPerformanceProfile(),
-            ),
-        )
-
-    return _create_node
+def create_node_profile(memory: int) -> NodePerformanceProfile:
+    return NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=MemoryUsage.from_bytes(
+            ram_total=1000,
+            ram_available=memory,
+            swap_total=1000,
+            swap_available=1000,
+        ),
+        network_interfaces=[
+            NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
+            for i in range(10)
+        ],
+        system=SystemPerformanceProfile(),
+    )


-# TODO: this is a hack to get the port for the send_back_multiaddr
-@pytest.fixture
-def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
-    port_counter = 1235
-    ip_counter = 1
+def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
+    return SocketConnection(
+        sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
+    )

-    def _create_connection(
-        source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
-    ) -> Connection:
-        nonlocal port_counter
-        nonlocal ip_counter
-        # assign unique ips
-        ip_counter += 1
-        if send_back_port is None:
-            send_back_port = port_counter
-            port_counter += 1
-        return Connection(
-            local_node_id=source_node_id,
-            send_back_node_id=sink_node_id,
-            send_back_multiaddr=Multiaddr(
-                address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
-            ),
-            connection_profile=ConnectionProfile(
-                throughput=1000, latency=1000, jitter=1000
-            ),
-        )
-
-    return _create_connection
+
+def create_rdma_connection(iface: int) -> RDMAConnection:
+    return RDMAConnection(
+        source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
+    )

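The helpers above encode the new edge model used throughout these tests: a `Connection` is a `source`/`sink` pair plus a typed `edge` payload, and consumers dispatch on the payload type instead of inspecting multiaddrs. A reduced sketch of that dispatch pattern with stand-in dataclasses:

```python
from dataclasses import dataclass


@dataclass
class SocketEdge:  # stand-in for SocketConnection
    sink_ip: str


@dataclass
class RDMAEdge:  # stand-in for RDMAConnection
    source_rdma_iface: str


def describe(edge: SocketEdge | RDMAEdge) -> str:
    # Type-based dispatch mirrors how placement_utils selects RDMA edges
    # for the jaccl matrix and socket edges for ring hosts.
    match edge:
        case RDMAEdge(source_rdma_iface=iface):
            return f"rdma:{iface}"
        case SocketEdge(sink_ip=ip):
            return f"tcp:{ip}"


assert describe(RDMAEdge("rdma_en3")) == "rdma:rdma_en3"
assert describe(SocketEdge("169.254.0.1")) == "tcp:169.254.0.1"
```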
@@ -19,15 +19,13 @@ from exo.shared.types.events import (
     ForwarderEvent,
     IndexedEvent,
     InstanceCreated,
-    NodePerformanceMeasured,
+    NodeGatheredInfo,
     TaskCreated,
 )
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import (
-    MemoryPerformanceProfile,
-    NodePerformanceProfile,
-    SystemPerformanceProfile,
+    MemoryUsage,
 )
 from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
 from exo.shared.types.tasks import TaskStatus
@@ -83,21 +81,14 @@ async def test_master():
                 origin=sender_node_id,
                 session=session_id,
                 event=(
-                    NodePerformanceMeasured(
+                    NodeGatheredInfo(
                         when=str(datetime.now(tz=timezone.utc)),
                         node_id=node_id,
-                        node_profile=NodePerformanceProfile(
-                            model_id="maccy",
-                            chip_id="arm",
-                            friendly_name="test",
-                            memory=MemoryPerformanceProfile(
-                                ram_total=Memory.from_bytes(678948 * 1024),
-                                ram_available=Memory.from_bytes(678948 * 1024),
-                                swap_total=Memory.from_bytes(0),
-                                swap_available=Memory.from_bytes(0),
-                            ),
-                            network_interfaces=[],
-                            system=SystemPerformanceProfile(),
+                        info=MemoryUsage(
+                            ram_total=Memory.from_bytes(678948 * 1024),
+                            ram_available=Memory.from_bytes(678948 * 1024),
+                            swap_total=Memory.from_bytes(0),
+                            swap_available=Memory.from_bytes(0),
                         ),
                     )
                 ),
@@ -163,7 +154,7 @@ async def test_master():
     assert events[0].idx == 0
     assert events[1].idx == 1
     assert events[2].idx == 2
-    assert isinstance(events[0].event, NodePerformanceMeasured)
+    assert isinstance(events[0].event, NodeGatheredInfo)
     assert isinstance(events[1].event, InstanceCreated)
     created_instance = events[1].event.instance
     assert isinstance(created_instance, MlxRingInstance)

@@ -1,20 +1,23 @@
-from typing import Callable
-
 import pytest
+from loguru import logger

 from exo.master.placement import (
     get_transition_events,
     place_instance,
 )
+from exo.master.tests.conftest import (
+    create_node_profile,
+    create_rdma_connection,
+    create_socket_connection,
+)
 from exo.shared.topology import Topology
 from exo.shared.types.commands import PlaceInstance
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.events import InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
-from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
-from exo.shared.types.topology import Connection, NodeInfo
+from exo.shared.types.multiaddr import Multiaddr
+from exo.shared.types.profiling import NetworkInterfaceInfo
+from exo.shared.types.topology import Connection, SocketConnection
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -26,11 +29,6 @@ from exo.shared.types.worker.runners import ShardAssignments
 from exo.shared.types.worker.shards import Sharding


-@pytest.fixture
-def topology() -> Topology:
-    return Topology()
-
-
 @pytest.fixture
 def instance() -> Instance:
     return MlxRingInstance(
@@ -77,34 +75,57 @@ def test_get_instance_placements_create_instance(
     available_memory: tuple[int, int, int],
     total_layers: int,
     expected_layers: tuple[int, int, int],
-    topology: Topology,
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     # arrange
     model_meta.n_layers = total_layers
     model_meta.storage_size.in_bytes = sum(
         available_memory
     )  # make it exactly fit across all nodes
+    topology = Topology()

     cic = place_instance_command(model_meta)
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
-    topology.add_node(create_node(available_memory[0], node_id_a))
-    topology.add_node(create_node(available_memory[1], node_id_b))
-    topology.add_node(create_node(available_memory[2], node_id_c))
-    # Add bidirectional connections for ring topology
-    topology.add_connection(create_connection(node_id_a, node_id_b))
-    topology.add_connection(create_connection(node_id_b, node_id_a))
-    topology.add_connection(create_connection(node_id_b, node_id_c))
-    topology.add_connection(create_connection(node_id_c, node_id_b))
-    topology.add_connection(create_connection(node_id_c, node_id_a))
-    topology.add_connection(create_connection(node_id_a, node_id_c))

+    # fully connected (directed) between the 3 nodes
+    conn_a_b = Connection(
+        source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
+    )
+    conn_b_c = Connection(
+        source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
+    )
+    conn_c_a = Connection(
+        source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
+    )
+    conn_c_b = Connection(
+        source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
+    )
+    conn_a_c = Connection(
+        source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
+    )
+    conn_b_a = Connection(
+        source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
+    )
+
+    profiles = {
+        node_id_a: create_node_profile(available_memory[0]),
+        node_id_b: create_node_profile(available_memory[1]),
+        node_id_c: create_node_profile(available_memory[2]),
+    }
+    topology.add_node(node_id_a)
+    topology.add_node(node_id_b)
+    topology.add_node(node_id_c)
+    topology.add_connection(conn_a_b)
+    topology.add_connection(conn_b_c)
+    topology.add_connection(conn_c_a)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_a_c)
+    topology.add_connection(conn_b_a)

     # act
-    placements = place_instance(cic, topology, {})
+    placements = place_instance(cic, topology, {}, profiles)

     # assert
     assert len(placements) == 1
@@ -130,12 +151,11 @@ def test_get_instance_placements_create_instance(
     assert shards_sorted[-1].end_layer == total_layers


-def test_get_instance_placements_one_node_exact_fit(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-) -> None:
+def test_get_instance_placements_one_node_exact_fit() -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(create_node(1000 * 1024, node_id))
+    topology.add_node(node_id)
+    profiles = {node_id: create_node_profile(1000 * 1024)}
     cic = place_instance_command(
         ModelMetadata(
             model_id=ModelId("test-model"),
@@ -146,7 +166,7 @@ def test_get_instance_placements_one_node_exact_fit(
             supports_tensor=True,
         ),
     )
-    placements = place_instance(cic, topology, {})
+    placements = place_instance(cic, topology, {}, profiles)

     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -157,12 +177,11 @@ def test_get_instance_placements_one_node_exact_fit(
     assert len(instance.shard_assignments.runner_to_shard) == 1


-def test_get_instance_placements_one_node_fits_with_extra_memory(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-) -> None:
+def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(create_node(1001 * 1024, node_id))
+    topology.add_node(node_id)
+    profiles = {node_id: create_node_profile(1001 * 1024)}
     cic = place_instance_command(
         ModelMetadata(
             model_id=ModelId("test-model"),
@@ -173,7 +192,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
             supports_tensor=True,
         ),
     )
-    placements = place_instance(cic, topology, {})
+    placements = place_instance(cic, topology, {}, profiles)

     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -184,12 +203,11 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
     assert len(instance.shard_assignments.runner_to_shard) == 1


-def test_get_instance_placements_one_node_not_fit(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-) -> None:
+def test_get_instance_placements_one_node_not_fit() -> None:
     topology = Topology()
     node_id = NodeId()
-    topology.add_node(create_node(1000 * 1024, node_id))
+    topology.add_node(node_id)
+    profiles = {node_id: create_node_profile(1000 * 1024)}
     cic = place_instance_command(
         model_meta=ModelMetadata(
             model_id=ModelId("test-model"),
@@ -202,7 +220,7 @@ def test_get_instance_placements_one_node_not_fit(
     )

     with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
-        place_instance(cic, topology, {})
+        place_instance(cic, topology, {}, profiles)


 def test_get_transition_events_no_change(instance: Instance):
@@ -247,179 +265,130 @@ def test_get_transition_events_delete_instance(instance: Instance):
     assert events[0].instance_id == instance_id


-def test_placement_selects_cycle_with_most_memory(
-    topology: Topology,
+def test_placement_selects_leaf_nodes(
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
-    # Arrange two 3-node cycles with different total memory.
-    # With bidirectional connections for ring topology, both cycles have non-leaf nodes.
-    # The algorithm should select the cycle with the most available memory.
+    # arrange
+    topology = Topology()

-    # Model requires more than any single node but fits within a 3-node cycle
-    model_meta.storage_size.in_bytes = 1500
+    model_meta.n_layers = 12
+    model_meta.storage_size = Memory.from_bytes(1000)

-    # Create node ids
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
     node_id_d = NodeId()
-    node_id_e = NodeId()
-    node_id_f = NodeId()

-    # A-B-C cycle total memory = 1600 (< D-E-F total)
-    topology.add_node(create_node(400, node_id_a))
-    topology.add_node(create_node(400, node_id_b))
-    topology.add_node(create_node(800, node_id_c))
+    profiles = {
+        node_id_a: create_node_profile(500),
+        node_id_b: create_node_profile(600),
+        node_id_c: create_node_profile(600),
+        node_id_d: create_node_profile(500),
+    }

-    # D-E-F cycle total memory = 1800 (> A-B-C total)
-    topology.add_node(create_node(600, node_id_d))
-    topology.add_node(create_node(600, node_id_e))
-    topology.add_node(create_node(600, node_id_f))
+    topology.add_node(node_id_a)
+    topology.add_node(node_id_b)
+    topology.add_node(node_id_c)
+    topology.add_node(node_id_d)

-    # Build bidirectional cycles for ring topology
-    topology.add_connection(create_connection(node_id_a, node_id_b))
-    topology.add_connection(create_connection(node_id_b, node_id_a))
-    topology.add_connection(create_connection(node_id_b, node_id_c))
-    topology.add_connection(create_connection(node_id_c, node_id_b))
-    topology.add_connection(create_connection(node_id_c, node_id_a))
-    topology.add_connection(create_connection(node_id_a, node_id_c))
-
-    topology.add_connection(create_connection(node_id_d, node_id_e))
-    topology.add_connection(create_connection(node_id_e, node_id_d))
-    topology.add_connection(create_connection(node_id_e, node_id_f))
-    topology.add_connection(create_connection(node_id_f, node_id_e))
-    topology.add_connection(create_connection(node_id_f, node_id_d))
-    topology.add_connection(create_connection(node_id_d, node_id_f))
-
-    cic = place_instance_command(
-        model_meta=model_meta,
+    # Daisy chain topology (directed)
+    topology.add_connection(
+        Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
+    )
+    topology.add_connection(
+        Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
+    )
+    topology.add_connection(
+        Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
+    )
+    topology.add_connection(
+        Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
+    )
+    topology.add_connection(
+        Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
+    )
+    topology.add_connection(
+        Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
    )

-    # Act
-    placements = place_instance(cic, topology, {})
+    cic = place_instance_command(model_meta=model_meta)

-    # Assert: D-E-F cycle should be selected as it has more total memory
+    # act
+    placements = place_instance(cic, topology, {}, profiles)
+
+    # assert
     assert len(placements) == 1
-    instance_id = list(placements.keys())[0]
-    instance = placements[instance_id]
+    instance = list(placements.values())[0]

     assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
-    less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
-    more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
-
-    assert more_memory_cycle_nodes.issubset(assigned_nodes)
-    assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
+    assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
+        (
+            node_id_c,
+            node_id_d,
+        )
+    )


 def test_tensor_rdma_backend_connectivity_matrix(
-    topology: Topology,
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     # arrange
+    topology = Topology()
     model_meta.n_layers = 12
     model_meta.storage_size.in_bytes = 1500

-    node_id_a = NodeId()
-    node_id_b = NodeId()
-    node_id_c = NodeId()
+    node_a = NodeId()
+    node_b = NodeId()
+    node_c = NodeId()

-    node_a = create_node(500, node_id_a)
-    node_b = create_node(500, node_id_b)
-    node_c = create_node(500, node_id_c)
+    profiles = {
+        node_a: create_node_profile(500),
+        node_b: create_node_profile(500),
+        node_c: create_node_profile(500),
+    }

     ethernet_interface = NetworkInterfaceInfo(
         name="en0",
-        ip_address="192.168.1.100",
+        ip_address="10.0.0.1",
     )
+    ethernet_conn = SocketConnection(
+        sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
+    )

-    assert node_a.node_profile is not None
-    assert node_b.node_profile is not None
-    assert node_c.node_profile is not None
-
-    conn_a_b = create_connection(node_id_a, node_id_b)
-    conn_b_c = create_connection(node_id_b, node_id_c)
-    conn_c_a = create_connection(node_id_c, node_id_a)
-
-    conn_b_a = create_connection(node_id_b, node_id_a)
-    conn_c_b = create_connection(node_id_c, node_id_b)
-    conn_a_c = create_connection(node_id_a, node_id_c)
-
-    assert conn_a_b.send_back_multiaddr is not None
-    assert conn_b_c.send_back_multiaddr is not None
-    assert conn_c_a.send_back_multiaddr is not None
-
-    assert conn_b_a.send_back_multiaddr is not None
-    assert conn_c_b.send_back_multiaddr is not None
-    assert conn_a_c.send_back_multiaddr is not None
-
-    node_a.node_profile = NodePerformanceProfile(
-        model_id="test",
-        chip_id="test",
-        friendly_name="test",
-        memory=node_a.node_profile.memory,
-        network_interfaces=[
-            NetworkInterfaceInfo(
-                name="en3",
-                ip_address=conn_c_a.send_back_multiaddr.ip_address,
-            ),
-            NetworkInterfaceInfo(
-                name="en4",
-                ip_address=conn_b_a.send_back_multiaddr.ip_address,
-            ),
-            ethernet_interface,
-        ],
-        system=node_a.node_profile.system,
-    )
-    node_b.node_profile = NodePerformanceProfile(
-        model_id="test",
-        chip_id="test",
-        friendly_name="test",
-        memory=node_b.node_profile.memory,
-        network_interfaces=[
-            NetworkInterfaceInfo(
-                name="en3",
-                ip_address=conn_c_b.send_back_multiaddr.ip_address,
-            ),
-            NetworkInterfaceInfo(
-                name="en4",
-                ip_address=conn_a_b.send_back_multiaddr.ip_address,
-            ),
-            ethernet_interface,
-        ],
-        system=node_b.node_profile.system,
-    )
-    node_c.node_profile = NodePerformanceProfile(
-        model_id="test",
-        chip_id="test",
-        friendly_name="test",
-        memory=node_c.node_profile.memory,
-        network_interfaces=[
-            NetworkInterfaceInfo(
-                name="en3",
-                ip_address=conn_a_c.send_back_multiaddr.ip_address,
-            ),
-            NetworkInterfaceInfo(
-                name="en4",
-                ip_address=conn_b_c.send_back_multiaddr.ip_address,
-            ),
-            ethernet_interface,
-        ],
-        system=node_c.node_profile.system,
-    )
+    profiles[node_a].network_interfaces = [ethernet_interface]
+    profiles[node_b].network_interfaces = [ethernet_interface]
+    profiles[node_c].network_interfaces = [ethernet_interface]

     topology.add_node(node_a)
     topology.add_node(node_b)
     topology.add_node(node_c)
-    topology.add_connection(conn_a_b)
-    topology.add_connection(conn_b_c)
-    topology.add_connection(conn_c_a)
-    topology.add_connection(conn_b_a)
-    topology.add_connection(conn_c_b)
-    topology.add_connection(conn_a_c)
+
+    # RDMA connections (directed)
+    topology.add_connection(
+        Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
+    )
+    topology.add_connection(
+        Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
+    )
+    topology.add_connection(
+        Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
+    )
+    topology.add_connection(
+        Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
+    )
+    topology.add_connection(
+        Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
+    )
+    topology.add_connection(
+        Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
+    )
+
+    # Ethernet connections (directed)
+    topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
+    topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
+    topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
+    topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
+    topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
+    topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))

     cic = PlaceInstance(
         sharding=Sharding.Tensor,
@@ -429,35 +398,34 @@ def test_tensor_rdma_backend_connectivity_matrix(
         min_nodes=1,
     )

-    placements = place_instance(cic, topology, {})
+    # act
+    placements = place_instance(cic, topology, {}, profiles)

+    # assert
     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
     instance = placements[instance_id]

     assert isinstance(instance, MlxJacclInstance)

-    assert instance.ibv_devices is not None
+    assert instance.jaccl_devices is not None
     assert instance.jaccl_coordinators is not None

-    matrix = instance.ibv_devices
+    matrix = instance.jaccl_devices
     assert len(matrix) == 3

     for i in range(3):
         assert matrix[i][i] is None

     assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
     node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}

-    idx_a = node_to_idx[node_id_a]
-    idx_b = node_to_idx[node_id_b]
-    idx_c = node_to_idx[node_id_c]
+    idx_a = node_to_idx[node_a]
+    idx_b = node_to_idx[node_b]
+    idx_c = node_to_idx[node_c]

+    logger.info(matrix)
+
-    assert matrix[idx_a][idx_b] == "rdma_en4"
-    assert matrix[idx_b][idx_c] == "rdma_en3"
-    assert matrix[idx_c][idx_a] == "rdma_en3"
+    assert matrix[idx_a][idx_b] == "rdma_en3"
+    assert matrix[idx_b][idx_c] == "rdma_en4"
+    assert matrix[idx_c][idx_a] == "rdma_en5"

     # Verify coordinators are set for all nodes
     assert len(instance.jaccl_coordinators) == 3
@@ -469,7 +437,5 @@ def test_tensor_rdma_backend_connectivity_matrix(
         if node_id == assigned_nodes[0]:
             assert coordinator.startswith("0.0.0.0:")
         else:
-            # Non-rank-0 nodes should have valid IP addresses (can be link-local)
             ip_part = coordinator.split(":")[0]
-            # Just verify it's a valid IP format
             assert len(ip_part.split(".")) == 4

@@ -1,4 +1,4 @@
|
||||
from typing import Callable
|
||||
from copy import copy
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,154 +9,178 @@ from exo.master.placement_utils import (
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import create_node_profile, create_socket_connection
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
topology = Topology()
|
||||
return topology
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_filter_cycles_by_memory():
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
connection1 = Connection(
|
||||
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
|
||||
)
|
||||
connection2 = Connection(
|
||||
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
|
||||
)
|
||||
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
|
||||
topology = Topology()
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
cycles = [c for c in topology.get_cycles() if len(c) != 1]
|
||||
assert len(cycles) == 1
|
||||
assert len(cycles[0]) == 2
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
cycles, node_profiles, Memory.from_bytes(1)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
assert len(filtered_cycles[0]) == 2
|
||||
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
|
||||
assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
|
||||
|
||||
|
||||
def test_filter_cycles_by_insufficient_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
def test_filter_cycles_by_insufficient_memory():
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
connection1 = Connection(
|
||||
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
|
||||
)
|
||||
connection2 = Connection(
|
||||
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
|
||||
)
|
||||
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
|
||||
topology = Topology()
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
topology.get_cycles(), Memory.from_kb(2001)
|
||||
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 0
|
||||
|
||||
|
||||
def test_filter_multiple_cycles_by_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_multiple_cycles_by_memory():
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )

    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)
    node_a = create_node_profile(500 * 1024)
    node_b = create_node_profile(500 * 1024)
    node_c = create_node_profile(1000 * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    topology.add_connection(create_connection(node_a_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_b_id))
    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)

    cycles = topology.get_cycles()

    # act
    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
    filtered_cycles = filter_cycles_by_memory(
        cycles, node_profiles, Memory.from_kb(1500)
    )

    # assert
    assert len(filtered_cycles) == 1
    assert len(filtered_cycles[0]) == 3
    assert set(n.node_id for n in filtered_cycles[0]) == {
    assert set(n for n in filtered_cycles[0]) == {
        node_a_id,
        node_b_id,
        node_c_id,
    }

def test_get_smallest_cycles(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_get_smallest_cycles():
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)
    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)

    cycles = [c for c in topology.get_cycles() if len(c) != 1]  # ignore singletons

    # act
    smallest_cycles = get_smallest_cycles(topology.get_cycles())
    smallest_cycles = get_smallest_cycles(cycles)

    # assert
    assert len(smallest_cycles) == 1
    assert len(smallest_cycles[0]) == 2
    assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
    assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}

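get_smallest_cycles is a pure selection over the already-computed cycle list; a sketch matching the assertions above (assuming cycles are plain sized iterables):

    def get_smallest_cycles(cycles):
        # Return every cycle tied for the minimum node count.
        if not cycles:
            return []
        min_len = min(len(cycle) for cycle in cycles)
        return [cycle for cycle in cycles if len(cycle) == min_len]
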
@pytest.mark.parametrize(
@@ -168,9 +192,6 @@ def test_get_smallest_cycles(
    ],
)
def test_get_shard_assignments(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
    available_memory: tuple[int, int, int],
    total_layers: int,
    expected_layers: tuple[int, int, int],
@@ -180,18 +201,37 @@ def test_get_shard_assignments(
    node_b_id = NodeId()
    node_c_id = NodeId()

    node_a = create_node(available_memory[0] * 1024, node_a_id)
    node_b = create_node(available_memory[1] * 1024, node_b_id)
    node_c = create_node(available_memory[2] * 1024, node_c_id)
    # create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
    )

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)
    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))
    node_a = create_node_profile(available_memory[0] * 1024)
    node_b = create_node_profile(available_memory[1] * 1024)
    node_c = create_node_profile(available_memory[2] * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }

    model_meta = ModelMetadata(
        model_id=ModelId("test-model"),
@@ -201,23 +241,22 @@ def test_get_shard_assignments(
        hidden_size=1000,
        supports_tensor=True,
    )

    cycles = topology.get_cycles()
    selected_cycle = cycles[0]

    # pick the 3-node cycle deterministically (cycle ordering can vary)
    selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)

    # act
    shard_assignments = get_shard_assignments(
        model_meta, selected_cycle, Sharding.Pipeline
        model_meta, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
    )

    # assert
    runner_id_a = shard_assignments.node_to_runner[node_a_id]
    runner_id_b = shard_assignments.node_to_runner[node_b_id]
    runner_id_c = shard_assignments.node_to_runner[node_c_id]
    assert (
        shard_assignments.runner_to_shard[runner_id_c].end_layer
        - shard_assignments.runner_to_shard[runner_id_c].start_layer
        == expected_layers[2]
    )

    assert (
        shard_assignments.runner_to_shard[runner_id_a].end_layer
        - shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -228,30 +267,37 @@ def test_get_shard_assignments(
        - shard_assignments.runner_to_shard[runner_id_b].start_layer
        == expected_layers[1]
    )
    assert (
        shard_assignments.runner_to_shard[runner_id_c].end_layer
        - shard_assignments.runner_to_shard[runner_id_c].start_layer
        == expected_layers[2]
    )

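The expected_layers parametrization follows a memory-proportional pipeline split: each node receives roughly total_layers * node_memory / total_memory layers, with rounding slack pushed onto one shard so the counts still sum to total_layers. A sketch of that arithmetic (the exact rounding policy inside get_shard_assignments may differ):

    def split_layers_by_memory(total_layers: int, memories: list[int]) -> list[int]:
        # Floor each proportional share, then give the remainder to the last shard.
        total_memory = sum(memories)
        counts = [total_layers * m // total_memory for m in memories]
        counts[-1] += total_layers - sum(counts)
        return counts

    # e.g. split_layers_by_memory(10, [500, 500, 1000]) -> [2, 2, 6]
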
def test_get_hosts_from_subgraph(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_hosts_from_subgraph():
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    node_a = create_node(500, node_a_id)
    node_b = create_node(500, node_b_id)
    node_c = create_node(1000, node_c_id)
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
    )

    topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
    topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
    topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
    topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)

    # act
    hosts = get_hosts_from_subgraph(topology)
@@ -259,95 +305,78 @@ def test_get_hosts_from_subgraph(
    # assert
    assert len(hosts) == 3
    expected_hosts = [
        Host(ip=("169.254.0.2"), port=5001),
        Host(ip=("169.254.0.3"), port=5002),
        Host(ip=("169.254.0.4"), port=5003),
        Host(ip="169.254.0.1", port=1234),
        Host(ip="169.254.0.2", port=1234),
        Host(ip="169.254.0.3", port=1234),
    ]
    for expected_host in expected_hosts:
        assert expected_host in hosts

def test_get_mlx_jaccl_coordinators(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_mlx_jaccl_coordinators():
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)
    # fully connected (directed) between the 3 nodes
    conn_a_b = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    conn_b_a = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    conn_b_c = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    conn_c_b = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )
    conn_c_a = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
    )
    conn_a_c = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
    )

    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
    conn_a_c = create_connection(node_a_id, node_c_id, 5006)

    # Update node profiles with network interfaces before adding to topology
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    node_a.node_profile = NodePerformanceProfile(
    npp = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_a.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_a_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_a_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_a.node_profile.system,
    )
    node_b.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_b.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_b_a.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_b.node_profile.system,
    )
    node_c.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_c.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_c_a.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_c.node_profile.system,
        memory=MemoryUsage.from_bytes(
            ram_total=0,
            ram_available=0,
            swap_total=0,
            swap_available=0,
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
    )
    npp_a = copy(npp)
    npp_a.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
    ]
    npp_b = copy(npp)
    npp_b.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
    ]
    npp_c = copy(npp)
    npp_c.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
    ]
    node_profiles = {
        node_a_id: npp_a,
        node_b_id: npp_b,
        node_c_id: npp_c,
    }

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)
    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)

    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_a)
@@ -356,11 +385,12 @@ def test_get_mlx_jaccl_coordinators(
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_a_c)

    cycle = [node_a, node_b, node_c]

    # act
    coordinators = get_mlx_jaccl_coordinators(
        cycle, coordinator_port=5000, cycle_digraph=topology
        node_a_id,
        coordinator_port=5000,
        cycle_digraph=topology,
        node_profiles=node_profiles,
    )

    # assert
@@ -381,19 +411,20 @@ def test_get_mlx_jaccl_coordinators(
            f"Coordinator for {node_id} should use port 5000"
        )

    # Rank 0 (node_a) treats this as the listen socket so should listen on all
    # IPs
    # Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
    assert coordinators[node_a_id].startswith("0.0.0.0:"), (
        "Rank 0 node should use localhost as coordinator"
        "Rank 0 node should use 0.0.0.0 as coordinator listen address"
    )

    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
    # node_b uses the IP from conn_b_a (node_b -> node_a)
    assert coordinators[node_b_id] == (
        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
    assert isinstance(conn_b_a.edge, SocketConnection)
    assert (
        coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
    ), "node_b should use the IP from conn_b_a"

    # node_c uses the IP from conn_c_a (node_c -> node_a)
    assert coordinators[node_c_id] == (
        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
    assert isinstance(conn_c_a.edge, SocketConnection)
    assert (
        coordinators[node_c_id] == f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
    ), "node_c should use the IP from conn_c_a"

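The assertions encode the rank-0 convention used by the MLX/JACCL coordinator: the rank-0 node listens on 0.0.0.0:<port>, and every other node dials the IP of its own outgoing edge toward rank 0. A sketch of that derivation against the Topology API above (a simplification; the real function also consults node_profiles for interface selection):

    def derive_coordinators(rank0, cycle_digraph, port):
        coordinators = {rank0: f"0.0.0.0:{port}"}
        for conn in cycle_digraph.list_connections():
            # each non-root node reaches rank 0 via its socket edge whose sink is the root
            if conn.sink == rank0 and isinstance(conn.edge, SocketConnection):
                coordinators.setdefault(
                    conn.source, f"{conn.edge.sink_multiaddr.ip_address}:{port}"
                )
        return coordinators
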
@@ -1,13 +1,14 @@
import pytest

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    MemoryUsage,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import Connection, SocketConnection


@pytest.fixture
@@ -16,20 +17,15 @@ def topology() -> Topology:


@pytest.fixture
def connection() -> Connection:
    return Connection(
        local_node_id=NodeId(),
        send_back_node_id=NodeId(),
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
        connection_profile=ConnectionProfile(
            throughput=1000, latency=1000, jitter=1000
        ),
def socket_connection() -> SocketConnection:
    return SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
    )


@pytest.fixture
def node_profile() -> NodePerformanceProfile:
    memory_profile = MemoryPerformanceProfile.from_bytes(
    memory_profile = MemoryUsage.from_bytes(
        ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
    )
    system_profile = SystemPerformanceProfile()
@@ -43,162 +39,91 @@ def node_profile() -> NodePerformanceProfile:
    )


@pytest.fixture
def connection_profile() -> ConnectionProfile:
    return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)

def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
    # arrange
    node_id = NodeId()

    # act
    topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
    topology.add_node(node_id)

    # assert
    data = topology.get_node_profile(node_id)
    assert data == node_profile
    assert topology.node_is_leaf(node_id)

def test_add_connection(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_add_connection(topology: Topology, socket_connection: SocketConnection):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    node_a = NodeId()
    node_b = NodeId()
    connection = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(connection)

    # act
    data = topology.get_connection_profile(connection)
    data = list(topology.list_connections())

    # assert
    assert data == connection.connection_profile
    assert data == [connection]

def test_update_node_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=MemoryPerformanceProfile.from_bytes(
            ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
    )

    # act
    topology.update_node_profile(
        connection.local_node_id, node_profile=new_node_profile
    )

    # assert
    data = topology.get_node_profile(connection.local_node_id)
    assert data == new_node_profile


def test_update_connection_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_connection_profile = ConnectionProfile(
        throughput=2000, latency=2000, jitter=2000
    )
    connection = Connection(
        local_node_id=connection.local_node_id,
        send_back_node_id=connection.send_back_node_id,
        send_back_multiaddr=connection.send_back_multiaddr,
        connection_profile=new_connection_profile,
    )

    # act
    topology.update_connection_profile(connection)

    # assert
    data = topology.get_connection_profile(connection)
    assert data == new_connection_profile
    assert topology.node_is_leaf(node_a)
    assert topology.node_is_leaf(node_b)

def test_remove_connection_still_connected(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
    topology: Topology, socket_connection: SocketConnection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)
    node_a = NodeId()
    node_b = NodeId()
    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(conn)

    # act
    topology.remove_connection(connection)
    topology.remove_connection(conn)

    # assert
    assert topology.get_connection_profile(connection) is None
    assert list(topology.get_all_connections_between(node_a, node_b)) == []

def test_remove_node_still_connected(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
    topology: Topology, socket_connection: SocketConnection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)
    node_a = NodeId()
    node_b = NodeId()
    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(conn)
    assert list(topology.out_edges(node_a)) == [conn]

    # act
    topology.remove_node(connection.local_node_id)
    topology.remove_node(node_b)

    # assert
    assert topology.get_node_profile(connection.local_node_id) is None
    assert list(topology.out_edges(node_a)) == []

def test_list_nodes(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)
    node_a = NodeId()
    node_b = NodeId()
    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(conn)
    assert list(topology.out_edges(node_a)) == [conn]

    # act
    nodes = list(topology.list_nodes())

    # assert
    assert len(nodes) == 2
    assert all(isinstance(node, NodeInfo) for node in nodes)
    assert {node.node_id for node in nodes} == {
        connection.local_node_id,
        connection.send_back_node_id,
    }
    assert all(isinstance(node, NodeId) for node in nodes)
    assert set(node for node in nodes) == set([node_a, node_b])

@@ -11,10 +11,8 @@ from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    NodeCreated,
    NodeDownloadProgress,
    NodeMemoryMeasured,
    NodePerformanceMeasured,
    NodeGatheredInfo,
    NodeTimedOut,
    RunnerDeleted,
    RunnerStatusUpdated,
@@ -27,13 +25,23 @@ from exo.shared.types.events import (
    TopologyEdgeCreated,
    TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
    MacmonMetrics,
    MacThunderboltConnections,
    MacThunderboltIdentifiers,
    MemoryUsage,
    MiscData,
    NodeConfig,
    NodeNetworkInterfaces,
    StaticNodeInformation,
)

def event_apply(event: Event, state: State) -> State:
@@ -47,16 +55,12 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case NodeCreated():
            return apply_topology_node_created(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodePerformanceMeasured():
            return apply_node_performance_measured(event, state)
        case NodeDownloadProgress():
            return apply_node_download_progress(event, state)
        case NodeMemoryMeasured():
            return apply_node_memory_measured(event, state)
        case NodeGatheredInfo():
            return apply_node_gathered_info(event, state)
        case RunnerDeleted():
            return apply_runner_deleted(event, state)
        case RunnerStatusUpdated():
@@ -188,7 +192,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:


def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    topology = copy.copy(state.topology)
    topology = copy.deepcopy(state.topology)
    state.topology.remove_node(event.node_id)
    node_profiles = {
        key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -196,8 +200,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    last_seen = {
        key: value for key, value in state.last_seen.items() if key != event.node_id
    }
    downloads = {
        key: value for key, value in state.downloads.items() if key != event.node_id
    }
    return state.model_copy(
        update={
            "downloads": downloads,
            "topology": topology,
            "node_profiles": node_profiles,
            "last_seen": last_seen,
@@ -205,103 +213,68 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    )

def apply_node_performance_measured(
    event: NodePerformanceMeasured, state: State
) -> State:
    new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: event.node_profile,
    }
    last_seen: Mapping[NodeId, datetime] = {
        **state.last_seen,
        event.node_id: datetime.fromisoformat(event.when),
    }
    state = state.model_copy(update={"node_profiles": new_profiles})
    topology = copy.copy(state.topology)
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, event.node_profile)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.add_node(event.node_id)
    info = event.info
    profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
    match info:
        case MacmonMetrics():
            profile.system = info.system_profile
            profile.memory = info.memory
        case MemoryUsage():
            profile.memory = info
        case NodeConfig():
            pass
        case MiscData():
            profile.friendly_name = info.friendly_name
        case StaticNodeInformation():
            profile.model_id = info.model
            profile.chip_id = info.chip
        case NodeNetworkInterfaces():
            profile.network_interfaces = info.ifaces
        case MacThunderboltIdentifiers():
            profile.tb_interfaces = info.idents
        case MacThunderboltConnections():
            conn_map = {
                tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
                for nid in state.node_profiles
                for tb_ident in state.node_profiles[nid].tb_interfaces
            }
            as_rdma_conns = [
                Connection(
                    source=event.node_id,
                    sink=conn_map[tb_conn.sink_uuid][0],
                    edge=RDMAConnection(
                        source_rdma_iface=conn_map[tb_conn.source_uuid][1],
                        sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
                    ),
                )
                for tb_conn in info.conns
                if tb_conn.source_uuid in conn_map
                if tb_conn.sink_uuid in conn_map
            ]
            topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)

    last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
    new_profiles = {**state.node_profiles, event.node_id: profile}
    return state.model_copy(
        update={
            "node_profiles": new_profiles,
            "topology": topology,
            "last_seen": last_seen,
            "topology": topology,
        }
    )

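apply_node_gathered_info folds each flavour of gathered data into a single mutable profile, so a node's profile is built up incrementally across events. A usage sketch (the MiscData and StaticNodeInformation constructor arguments are inferred from the fields read above, and any other required BaseEvent fields are omitted for brevity):

    state = State()
    nid = NodeId("node-1")
    state = event_apply(
        NodeGatheredInfo(node_id=nid, when="2025-01-01T00:00:00", info=MiscData(friendly_name="studio")),
        state,
    )
    state = event_apply(
        NodeGatheredInfo(node_id=nid, when="2025-01-01T00:00:01", info=StaticNodeInformation(model="Mac16,10", chip="Apple M4")),
        state,
    )
    assert state.node_profiles[nid].friendly_name == "studio"
    assert state.node_profiles[nid].chip_id == "Apple M4"
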
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
    existing = state.node_profiles.get(event.node_id)
    topology = copy.copy(state.topology)

    if existing is None:
        created = NodePerformanceProfile(
            model_id="unknown",
            chip_id="unknown",
            friendly_name="Unknown",
            memory=event.memory,
            network_interfaces=[],
            system=SystemPerformanceProfile(
                # TODO: flops_fp16=0.0,
                gpu_usage=0.0,
                temp=0.0,
                sys_power=0.0,
                pcpu_usage=0.0,
                ecpu_usage=0.0,
                ane_power=0.0,
            ),
        )
        created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
            **state.node_profiles,
            event.node_id: created,
        }
        last_seen: Mapping[NodeId, datetime] = {
            **state.last_seen,
            event.node_id: datetime.fromisoformat(event.when),
        }
        if not topology.contains_node(event.node_id):
            topology.add_node(NodeInfo(node_id=event.node_id))
        # TODO: NodeCreated
        topology.update_node_profile(event.node_id, created)
        return state.model_copy(
            update={
                "node_profiles": created_profiles,
                "topology": topology,
                "last_seen": last_seen,
            }
        )

    updated = existing.model_copy(update={"memory": event.memory})
    updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: updated,
    }
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, updated)
    return state.model_copy(
        update={"node_profiles": updated_profiles, "topology": topology}
    )

def apply_topology_node_created(event: NodeCreated, state: State) -> State:
    topology = copy.copy(state.topology)
    topology.add_node(NodeInfo(node_id=event.node_id))
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
    topology = copy.copy(state.topology)
    topology.add_connection(event.edge)
    topology = copy.deepcopy(state.topology)
    topology.add_connection(event.conn)
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
    topology = copy.copy(state.topology)
    if not topology.contains_connection(event.edge):
        return state
    topology.remove_connection(event.edge)
    topology = copy.deepcopy(state.topology)
    topology.remove_connection(event.conn)
    # TODO: Clean up removing the reverse connection
    return state.model_copy(update={"topology": topology})

@@ -38,6 +38,7 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"

# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"

# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

@@ -11,9 +11,6 @@ class InterceptLogger(HypercornLogger):
    def __init__(self, config: Config):
        super().__init__(config)
        assert self.error_logger
        # TODO: Decide if we want to provide access logs
        # assert self.access_logger
        # self.access_logger.handlers = [_InterceptHandler()]
        self.error_logger.handlers = [_InterceptHandler()]


@@ -29,6 +26,11 @@ class _InterceptHandler(logging.Handler):

def logger_setup(log_file: Path | None, verbosity: int = 0):
    """Set up logging for this process - formatting, file handles, verbosity and output"""

    logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    logger.remove()

    # replace all stdlib loggers with _InterceptHandlers that log to loguru

@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):

MODEL_CARDS: dict[str, ModelCard] = {
    # deepseek v3
    # "deepseek-v3-0324:4bit": ModelCard(
    #     short_id="deepseek-v3-0324:4bit",
    #     model_id="mlx-community/DeepSeek-V3-0324-4bit",
    #     name="DeepSeek V3 0324 (4-bit)",
    #     description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
    #         pretty_name="DeepSeek V3 0324 (4-bit)",
    #         storage_size=Memory.from_kb(409706307),
    #         n_layers=61,
    #     ),
    # ),
    # "deepseek-v3-0324": ModelCard(
    #     short_id="deepseek-v3-0324",
    #     model_id="mlx-community/DeepSeek-v3-0324-8bit",
    #     name="DeepSeek V3 0324 (8-bit)",
    #     description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
    #         pretty_name="DeepSeek V3 0324 (8-bit)",
    #         storage_size=Memory.from_kb(754706307),
    #         n_layers=61,
    #     ),
    # ),
    "deepseek-v3.1-4bit": ModelCard(
        short_id="deepseek-v3.1-4bit",
        model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -70,63 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
            supports_tensor=True,
        ),
    ),
    # "deepseek-v3.2": ModelCard(
    #     short_id="deepseek-v3.2",
    #     model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
    #     name="DeepSeek V3.2 (8-bit)",
    #     description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
    #         pretty_name="DeepSeek V3.2 (8-bit)",
    #         storage_size=Memory.from_kb(754706307),
    #         n_layers=61,
    #         hidden_size=7168,
    #     ),
    # ),
    # "deepseek-v3.2-4bit": ModelCard(
    #     short_id="deepseek-v3.2-4bit",
    #     model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
    #     name="DeepSeek V3.2 (4-bit)",
    #     description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
    #         pretty_name="DeepSeek V3.2 (4-bit)",
    #         storage_size=Memory.from_kb(754706307 // 2),  # TODO !!!!!
    #         n_layers=61,
    #         hidden_size=7168,
    #     ),
    # ),
    # deepseek r1
    # "deepseek-r1-0528-4bit": ModelCard(
    #     short_id="deepseek-r1-0528-4bit",
    #     model_id="mlx-community/DeepSeek-R1-0528-4bit",
    #     name="DeepSeek-R1-0528 (4-bit)",
    #     description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
    #         pretty_name="DeepSeek R1 671B (4-bit)",
    #         storage_size=Memory.from_kb(409706307),
    #         n_layers=61,
    #         hidden_size=7168,
    #     ),
    # ),
    # "deepseek-r1-0528": ModelCard(
    #     short_id="deepseek-r1-0528",
    #     model_id="mlx-community/DeepSeek-R1-0528-8bit",
    #     name="DeepSeek-R1-0528 (8-bit)",
    #     description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
    #         pretty_name="DeepSeek R1 671B (8-bit)",
    #         storage_size=Memory.from_bytes(754998771712),
    #         n_layers=61,
    #         hidden_size=7168,
    #     ),
    # ),
    # kimi k2
    "kimi-k2-instruct-4bit": ModelCard(
        short_id="kimi-k2-instruct-4bit",
@@ -508,23 +425,24 @@ MODEL_CARDS: dict[str, ModelCard] = {
            supports_tensor=True,
        ),
    ),
    "gpt-oss-20b-4bit": ModelCard(
        short_id="gpt-oss-20b-4bit",
        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
        name="GPT-OSS 20B (MXFP4-Q4, MLX)",
        description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
    "gpt-oss-20b-MXFP4-Q8": ModelCard(
        short_id="gpt-oss-20b-MXFP4-Q8",
        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
        name="GPT-OSS 20B (MXFP4-Q8, MLX)",
        description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is an MXFP4-Q8 MLX conversion for Apple Silicon.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
            pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
            model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
            pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
            storage_size=Memory.from_kb(11_744_051),
            n_layers=24,
            hidden_size=2880,
            supports_tensor=True,
        ),
    ),
    # Needs to be quantized g32 or g16.
    # glm 4.5
    "glm-4.5-air-8bit": ModelCard(
        # Needs to be quantized g32 or g16 to work with tensor parallel
        short_id="glm-4.5-air-8bit",
        model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
        name="GLM 4.5 Air 8bit",
@@ -554,19 +472,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
            supports_tensor=True,
        ),
    ),
    # "devstral-2-123b-instruct-2512-8bit": ModelCard(
    #     short_id="devstral-2-123b-instruct-2512-8bit",
    #     model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
    #     name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
    #     description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
    #     tags=[],
    #     metadata=ModelMetadata(
    #         model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
    #         pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
    #         storage_size=Memory.from_kb(133_000_000),
    #         n_layers=88,
    #         hidden_size=12288,
    #         supports_tensor=True,
    #     ),
    # ),
    # glm 4.7
    "glm-4.7-4bit": ModelCard(
        short_id="glm-4.7-4bit",
        model_id=ModelId("mlx-community/GLM-4.7-4bit"),
        name="GLM 4.7 4bit",
        description="GLM 4.7 4bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-4bit"),
            pretty_name="GLM 4.7 4bit",
            storage_size=Memory.from_bytes(198556925568),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-6bit": ModelCard(
        short_id="glm-4.7-6bit",
        model_id=ModelId("mlx-community/GLM-4.7-6bit"),
        name="GLM 4.7 6bit",
        description="GLM 4.7 6bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-6bit"),
            pretty_name="GLM 4.7 6bit",
            storage_size=Memory.from_bytes(286737579648),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-8bit-gs32": ModelCard(
        short_id="glm-4.7-8bit-gs32",
        model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
        name="GLM 4.7 8bit (gs32)",
        description="GLM 4.7 8bit (gs32)",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
            pretty_name="GLM 4.7 8bit (gs32)",
            storage_size=Memory.from_bytes(396963397248),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    # minimax-m2
    "minimax-m2.1-8bit": ModelCard(
        short_id="minimax-m2.1-8bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
        name="MiniMax M2.1 8bit",
        description="MiniMax M2.1 8bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
            pretty_name="MiniMax M2.1 8bit",
            storage_size=Memory.from_bytes(242986745856),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "minimax-m2.1-3bit": ModelCard(
        short_id="minimax-m2.1-3bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
        name="MiniMax M2.1 3bit",
        description="MiniMax M2.1 3bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
            pretty_name="MiniMax M2.1 3bit",
            storage_size=Memory.from_bytes(100086644736),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
}

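The registry is keyed by short_id, and each card's ModelMetadata carries what the scheduler needs (layer count, hidden size, on-disk footprint). A lookup sketch:

    card = MODEL_CARDS["glm-4.7-4bit"]
    assert card.metadata.n_layers == 91
    assert card.metadata.hidden_size == 5120
    # storage_size was built with Memory.from_bytes(198556925568), i.e. roughly 198.6 GB
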
@@ -43,7 +43,4 @@ def test_apply_two_node_download_progress():
        NodeDownloadProgress(download_progress=event2), state
    )

    # TODO: This test is failing. We should support the following:
    # 1. Downloading multiple models concurrently on the same node (one per runner is fine).
    # 2. Downloading a model, it completes, then downloading a different model on the same node.
    assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import Connection, SocketConnection


def test_state_serialization_roundtrip() -> None:
@@ -12,9 +12,11 @@ def test_state_serialization_roundtrip() -> None:
    node_b = NodeId("node-b")

    connection = Connection(
        local_node_id=node_a,
        send_back_node_id=node_b,
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
        source=node_a,
        sink=node_b,
        edge=SocketConnection(
            sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
        ),
    )

    state = State()
@@ -23,5 +25,11 @@ def test_state_serialization_roundtrip() -> None:
    json_repr = state.model_dump_json()
    restored_state = State.model_validate_json(json_repr)

    assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
    assert (
        state.topology.to_snapshot().nodes
        == restored_state.topology.to_snapshot().nodes
    )
    assert set(state.topology.to_snapshot().connections) == set(
        restored_state.topology.to_snapshot().connections
    )
    assert restored_state.model_dump_json() == json_repr

@@ -1,203 +1,227 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable

import rustworkx as rx
from pydantic import BaseModel, ConfigDict

from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.topology import (
    Connection,
    Cycle,
    RDMAConnection,
    SocketConnection,
)


class TopologySnapshot(BaseModel):
    nodes: list[NodeInfo]
    connections: list[Connection]
    nodes: Sequence[NodeId]
    connections: Mapping[
        NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
    ]

    model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
    model_config = ConfigDict(frozen=True, extra="forbid")


@dataclass
class Topology:
    def __init__(self) -> None:
        self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
        self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
        self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
        self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
    _graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
        init=False, default_factory=rx.PyDiGraph
    )
    _vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)

    def to_snapshot(self) -> TopologySnapshot:
        return TopologySnapshot(
            nodes=list(self.list_nodes()),
            connections=list(self.list_connections()),
            nodes=list(self.list_nodes()), connections=self.map_connections()
        )

    @classmethod
    def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
        topology = cls()

        for node in snapshot.nodes:
        for node_id in snapshot.nodes:
            with contextlib.suppress(ValueError):
                topology.add_node(node)
                topology.add_node(node_id)

        for connection in snapshot.connections:
            topology.add_connection(connection)
        for source in snapshot.connections:
            for sink in snapshot.connections[source]:
                for edge in snapshot.connections[source][sink]:
                    topology.add_connection(
                        Connection(source=source, sink=sink, edge=edge)
                    )

        return topology

    def add_node(self, node: NodeInfo) -> None:
        if node.node_id in self._node_id_to_rx_id_map:
    def add_node(self, node_id: NodeId) -> None:
        if node_id in self._vertex_indices:
            return
        rx_id = self._graph.add_node(node)
        self._node_id_to_rx_id_map[node.node_id] = rx_id
        self._rx_id_to_node_id_map[rx_id] = node.node_id
        rx_id = self._graph.add_node(node_id)
        self._vertex_indices[node_id] = rx_id

    def node_is_leaf(self, node_id: NodeId) -> bool:
        return (
            node_id in self._node_id_to_rx_id_map
            and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
            node_id in self._vertex_indices
            and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
        )

    def neighbours(self, node_id: NodeId) -> list[NodeId]:
        return [
            self._rx_id_to_node_id_map[rx_id]
            for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
            self._graph[rx_id]
            for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
        ]

    def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
        if node_id not in self._node_id_to_rx_id_map:
    def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
        if node_id not in self._vertex_indices:
            return []
        return [
            (self._rx_id_to_node_id_map[nid], conn)
            for _, nid, conn in self._graph.out_edges(
                self._node_id_to_rx_id_map[node_id]
        return (
            Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
            for source, sink, edge in self._graph.out_edges(
                self._vertex_indices[node_id]
            )
        ]
        )

    def contains_node(self, node_id: NodeId) -> bool:
        return node_id in self._node_id_to_rx_id_map
        return node_id in self._vertex_indices

    def contains_connection(self, connection: Connection) -> bool:
        return connection in self._edge_id_to_rx_id_map

    def add_connection(
        self,
        connection: Connection,
    ) -> None:
        if connection.local_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.local_node_id))
        if connection.send_back_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.send_back_node_id))

        if connection in self._edge_id_to_rx_id_map:
    def add_connection(self, conn: Connection) -> None:
        source, sink, edge = conn.source, conn.sink, conn.edge
        del conn
        if edge in self.get_all_connections_between(source, sink):
            return

        src_id = self._node_id_to_rx_id_map[connection.local_node_id]
        sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
        if source not in self._vertex_indices:
            self.add_node(source)
        if sink not in self._vertex_indices:
            self.add_node(sink)

        rx_id = self._graph.add_edge(src_id, sink_id, connection)
        self._edge_id_to_rx_id_map[connection] = rx_id
        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]

    def list_nodes(self) -> Iterable[NodeInfo]:
        return (self._graph[i] for i in self._graph.node_indices())
        _ = self._graph.add_edge(src_id, sink_id, edge)

    def list_connections(self) -> Iterable[Connection]:
        return (connection for _, _, connection in self._graph.weighted_edge_list())
    def get_all_connections_between(
        self, source: NodeId, sink: NodeId
    ) -> Iterable[SocketConnection | RDMAConnection]:
        if source not in self._vertex_indices:
            return []
        if sink not in self._vertex_indices:
            return []

    def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]
        try:
            rx_idx = self._node_id_to_rx_id_map[node_id]
            return self._graph.get_node_data(rx_idx).node_profile
        except KeyError:
            return None
            return self._graph.get_all_edge_data(src_id, sink_id)
        except rx.NoEdgeBetweenNodes:
            return []

    def update_node_profile(
        self, node_id: NodeId, node_profile: NodePerformanceProfile
    ) -> None:
        rx_idx = self._node_id_to_rx_id_map[node_id]
        self._graph[rx_idx].node_profile = node_profile
    def list_nodes(self) -> Iterable[NodeId]:
        return self._graph.nodes()

    def update_connection_profile(self, connection: Connection) -> None:
        rx_idx = self._edge_id_to_rx_id_map[connection]
        self._graph.update_edge_by_index(rx_idx, connection)
    def map_connections(
        self,
    ) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
        base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
        for src_id, sink_id, connection in self._graph.weighted_edge_list():
            source = self._graph[src_id]
            sink = self._graph[sink_id]
            if source not in base:
                base[source] = {}
            if sink not in base[source]:
                base[source][sink] = []
            base[source][sink].append(connection)
        return base

    def get_connection_profile(
        self, connection: Connection
    ) -> ConnectionProfile | None:
        try:
            rx_idx = self._edge_id_to_rx_id_map[connection]
            return self._graph.get_edge_data_by_index(rx_idx).connection_profile
        except KeyError:
            return None
    def list_connections(
        self,
    ) -> Iterable[Connection]:
        return (
            (
                Connection(
                    source=self._graph[src_id],
                    sink=self._graph[sink_id],
                    edge=connection,
                )
            )
            for src_id, sink_id, connection in self._graph.weighted_edge_list()
        )

def remove_node(self, node_id: NodeId) -> None:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
if node_id not in self._vertex_indices:
|
||||
return
|
||||
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_id
|
||||
or connection.send_back_node_id == node_id
|
||||
):
|
||||
self.remove_connection(connection)
|
||||
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
rx_idx = self._vertex_indices[node_id]
|
||||
self._graph.remove_node(rx_idx)
|
||||
|
||||
del self._node_id_to_rx_id_map[node_id]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
del self._vertex_indices[node_id]
|
||||
|
||||
def remove_connection(self, connection: Connection) -> None:
|
||||
if connection not in self._edge_id_to_rx_id_map:
|
||||
def replace_all_out_rdma_connections(
|
||||
self, source: NodeId, new_connections: Sequence[Connection]
|
||||
) -> None:
|
||||
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
|
||||
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
for conn in new_connections:
|
||||
self.add_connection(conn)
|
||||
|
||||
def remove_connection(self, conn: Connection) -> None:
|
||||
if (
|
||||
conn.source not in self._vertex_indices
|
||||
or conn.sink not in self._vertex_indices
|
||||
):
|
||||
return
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.remove_edge_from_index(rx_idx)
|
||||
del self._edge_id_to_rx_id_map[connection]
|
||||
for conn_idx in self._graph.edge_indices_from_endpoints(
|
||||
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
|
||||
):
|
||||
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
|
||||
def get_cycles(self) -> list[Cycle]:
|
||||
"""Get simple cycles in the graph, including singleton cycles"""
|
||||
|
||||
def get_cycles(self) -> list[list[NodeInfo]]:
|
||||
cycle_idxs = rx.simple_cycles(self._graph)
|
||||
cycles: list[list[NodeInfo]] = []
|
||||
cycles: list[Cycle] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = [self._graph[idx] for idx in cycle_idx]
|
||||
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
|
||||
cycles.append(cycle)
|
||||
|
||||
for node_id in self.list_nodes():
|
||||
cycles.append(Cycle(node_ids=[node_id]))
|
||||
return cycles
|
||||
|
||||
def get_cycles_tb(self) -> list[list[NodeInfo]]:
|
||||
def get_cycles_tb(self) -> list[Cycle]:
|
||||
tb_edges = [
|
||||
(u, v, conn)
|
||||
for u, v, conn in self._graph.weighted_edge_list()
|
||||
if conn.is_thunderbolt()
|
||||
]
|
||||
|
||||
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
||||
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
|
||||
tb_graph.add_nodes_from(self._graph.nodes())
|
||||
|
||||
for u, v, conn in tb_edges:
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
if isinstance(conn, SocketConnection):
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
|
||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||
cycles: list[list[NodeInfo]] = []
|
||||
cycles: list[Cycle] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = [tb_graph[idx] for idx in cycle_idx]
|
||||
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
|
||||
cycles.append(cycle)
|
||||
|
||||
return cycles
|
||||
|
||||
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
|
||||
node_idxs = [node.node_id for node in nodes]
|
||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
||||
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
|
||||
topology = Topology()
|
||||
for rx_idx in rx_idxs:
|
||||
topology.add_node(self._graph[rx_idx])
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id in node_idxs
|
||||
and connection.send_back_node_id in node_idxs
|
||||
):
|
||||
if connection.source in node_ids and connection.sink in node_ids:
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
|
||||
node_idxs = [node.node_id for node in cycle]
|
||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
||||
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
|
||||
node_idxs = [node for node in cycle]
|
||||
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
|
||||
for rid in rx_idxs:
|
||||
for neighbor_rid in self._graph.neighbors(rid):
|
||||
if neighbor_rid not in rx_idxs:
|
||||
|
||||
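A minimal sketch of the rustworkx calls the new get_cycles path relies on (payloads here are plain strings, where the real graph stores NodeId):

import rustworkx as rx

g = rx.PyDiGraph()
a, b = g.add_node("A"), g.add_node("B")
g.add_edge(a, b, None)
g.add_edge(b, a, None)
# simple_cycles yields index lists; g[idx] maps back to the stored payload,
# which is how Cycle(node_ids=[...]) is assembled above
print([[g[i] for i in cycle] for cycle in rx.simple_cycles(g)])  # e.g. [['A', 'B']]
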
@@ -2,14 +2,14 @@ from datetime import datetime

from pydantic import Field

from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
    runner_id: RunnerId


# TODO
class NodeCreated(BaseEvent):
    node_id: NodeId


class NodeTimedOut(BaseEvent):
    node_id: NodeId


class NodePerformanceMeasured(BaseEvent):
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
    node_id: NodeId
    when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
    node_profile: NodePerformanceProfile


class NodeMemoryMeasured(BaseEvent):
    node_id: NodeId
    when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
    memory: MemoryPerformanceProfile
    info: GatheredInfo


class NodeDownloadProgress(BaseEvent):
@@ -107,11 +97,11 @@ class ChunkGenerated(BaseEvent):


class TopologyEdgeCreated(BaseEvent):
    edge: Connection
    conn: Connection


class TopologyEdgeDeleted(BaseEvent):
    edge: Connection
    conn: Connection


Event = (
@@ -125,10 +115,8 @@ Event = (
    | InstanceDeleted
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeCreated
    | NodeTimedOut
    | NodePerformanceMeasured
    | NodeMemoryMeasured
    | NodeGatheredInfo
    | NodeDownloadProgress
    | ChunkGenerated
    | TopologyEdgeCreated

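A minimal sketch of emitting the new event shape (the NodeId value and MemoryUsage figures are invented for illustration):

from datetime import datetime, timezone

event = NodeGatheredInfo(
    node_id=NodeId("node-a"),
    when=str(datetime.now(tz=timezone.utc)),
    info=MemoryUsage.from_bytes(
        ram_total=16_000_000_000,
        ram_available=8_000_000_000,
        swap_total=0,
        swap_available=0,
    ),
)
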
@@ -1,10 +1,11 @@
import re
from typing import ClassVar

from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, ConfigDict, computed_field, field_validator


class Multiaddr(BaseModel):
    model_config = ConfigDict(frozen=True)
    address: str

    PATTERNS: ClassVar[list[str]] = [

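With frozen=True the model becomes immutable and hashable, so it can key dicts and live in sets; a quick sketch:

m = Multiaddr(address="/ip4/169.254.0.1/tcp/52415")
seen = {m}  # hashable now that the model is frozen
assert Multiaddr(address=m.address) in seen  # equal fields hash equal
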
@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import Self

import psutil

from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel


class MemoryPerformanceProfile(CamelCaseModel):
class MemoryUsage(CamelCaseModel):
    ram_total: Memory
    ram_available: Memory
    swap_total: Memory
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
    sys_power: float = 0.0
    pcpu_usage: float = 0.0
    ecpu_usage: float = 0.0
    ane_power: float = 0.0


class NetworkInterfaceInfo(CamelCaseModel):
@@ -53,15 +54,12 @@ class NetworkInterfaceInfo(CamelCaseModel):


class NodePerformanceProfile(CamelCaseModel):
    model_id: str
    chip_id: str
    friendly_name: str
    memory: MemoryPerformanceProfile
    network_interfaces: list[NetworkInterfaceInfo] = []
    system: SystemPerformanceProfile


class ConnectionProfile(CamelCaseModel):
    throughput: float
    latency: float
    jitter: float
    model_id: str = "Unknown"
    chip_id: str = "Unknown"
    friendly_name: str = "Unknown"
    memory: MemoryUsage = MemoryUsage.from_bytes(
        ram_total=0, ram_available=0, swap_total=0, swap_available=0
    )
    network_interfaces: Sequence[NetworkInterfaceInfo] = []
    tb_interfaces: Sequence[ThunderboltIdentifier] = []
    system: SystemPerformanceProfile = SystemPerformanceProfile()

81
src/exo/shared/types/thunderbolt.py
Normal file
@@ -0,0 +1,81 @@
import anyio
from pydantic import BaseModel, Field

from exo.utils.pydantic_ext import CamelCaseModel


class ThunderboltConnection(CamelCaseModel):
    source_uuid: str
    sink_uuid: str


class ThunderboltIdentifier(CamelCaseModel):
    rdma_interface: str
    domain_uuid: str


## Intentionally minimal, only collecting the data we care about - there's a lot more


class _ReceptacleTag(BaseModel, extra="ignore"):
    receptacle_id_key: str | None = None


class _ConnectivityItem(BaseModel, extra="ignore"):
    domain_uuid_key: str | None = None


class ThunderboltConnectivityData(BaseModel, extra="ignore"):
    domain_uuid_key: str | None = None
    items: list[_ConnectivityItem] | None = Field(None, alias="_items")
    receptacle_1_tag: _ReceptacleTag | None = None

    def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
        if (
            self.domain_uuid_key is None
            or self.receptacle_1_tag is None
            or self.receptacle_1_tag.receptacle_id_key is None
        ):
            return
        tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
        assert tag in ifaces  # doesn't need to be an assertion but I'm confident
        # if tag not in ifaces: return None
        iface = f"rdma_{ifaces[tag]}"
        return ThunderboltIdentifier(
            rdma_interface=iface, domain_uuid=self.domain_uuid_key
        )

    def conn(self) -> ThunderboltConnection | None:
        if self.domain_uuid_key is None or self.items is None:
            return

        sink_key = next(
            (
                item.domain_uuid_key
                for item in self.items
                if item.domain_uuid_key is not None
            ),
            None,
        )
        if sink_key is None:
            return None

        return ThunderboltConnection(
            source_uuid=self.domain_uuid_key, sink_uuid=sink_key
        )


class ThunderboltConnectivity(BaseModel, extra="ignore"):
    SPThunderboltDataType: list[ThunderboltConnectivityData] = []

    @classmethod
    async def gather(cls) -> list[ThunderboltConnectivityData] | None:
        proc = await anyio.run_process(
            ["system_profiler", "SPThunderboltDataType", "-json"], check=False
        )
        if proc.returncode != 0:
            return None
        # Saving you from PascalCase while avoiding too much pydantic
        return ThunderboltConnectivity.model_validate_json(
            proc.stdout
        ).SPThunderboltDataType
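A sketch of the system_profiler JSON shape these models expect (keys abbreviated and UUIDs invented; real output carries many more fields, which extra="ignore" drops):

sample = '''{"SPThunderboltDataType": [{
    "domain_uuid_key": "AAAA-1111",
    "receptacle_1_tag": {"receptacle_id_key": "1"},
    "_items": [{"domain_uuid_key": "BBBB-2222"}]
}]}'''
data = ThunderboltConnectivity.model_validate_json(sample).SPThunderboltDataType
# conn() pairs this domain with the first connected peer's domain UUID,
# yielding roughly ThunderboltConnection(source_uuid="AAAA-1111", sink_uuid="BBBB-2222")
print(data[0].conn())
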
@@ -1,37 +1,41 @@
from collections.abc import Iterator
from dataclasses import dataclass

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel


class NodeInfo(CamelCaseModel):
    node_id: NodeId
    node_profile: NodePerformanceProfile | None = None
@dataclass(frozen=True)
class Cycle:
    node_ids: list[NodeId]

    def __len__(self) -> int:
        return self.node_ids.__len__()

    def __iter__(self) -> Iterator[NodeId]:
        return self.node_ids.__iter__()


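Cycle is a thin ordered wrapper over NodeIds; __len__ and __iter__ delegate to the list, so it drops into existing loops. A quick sketch:

c = Cycle(node_ids=[NodeId("a"), NodeId("b"), NodeId("c")])
assert len(c) == 3
assert [n for n in c] == c.node_ids  # iteration preserves ring order
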
class Connection(CamelCaseModel):
    local_node_id: NodeId
    send_back_node_id: NodeId
    send_back_multiaddr: Multiaddr
    connection_profile: ConnectionProfile | None = None

    def __hash__(self) -> int:
        return hash(
            (
                self.local_node_id,
                self.send_back_node_id,
                self.send_back_multiaddr.address,
            )
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Connection):
            raise ValueError("Cannot compare Connection with non-Connection")
        return (
            self.local_node_id == other.local_node_id
            and self.send_back_node_id == other.send_back_node_id
            and self.send_back_multiaddr == other.send_back_multiaddr
        )
class RDMAConnection(FrozenModel):
    source_rdma_iface: str
    sink_rdma_iface: str

    def is_thunderbolt(self) -> bool:
        return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
        return True


class SocketConnection(FrozenModel):
    sink_multiaddr: Multiaddr

    def __hash__(self):
        return hash(self.sink_multiaddr.ip_address)

    def is_thunderbolt(self) -> bool:
        return str(self.sink_multiaddr.ipv4_address).startswith("169.254")


class Connection(FrozenModel):
    source: NodeId
    sink: NodeId
    edge: RDMAConnection | SocketConnection

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):


class MlxJacclInstance(BaseInstance):
    ibv_devices: list[list[str | None]]
    jaccl_devices: list[list[str | None]]
    jaccl_coordinators: dict[NodeId, str]


@@ -1,43 +0,0 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable

from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    SystemPerformanceProfile,
)


class ResourceCollector(ABC):
    @abstractmethod
    async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...


class SystemResourceCollector(ResourceCollector):
    async def collect(self) -> SystemPerformanceProfile: ...


class MemoryResourceCollector(ResourceCollector):
    async def collect(self) -> MemoryPerformanceProfile: ...


class ResourceMonitor:
    data_collectors: list[ResourceCollector]
    effect_handlers: set[
        Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
    ]

    async def _collect(
        self,
    ) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
        tasks: list[
            Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
        ] = [collector.collect() for collector in self.data_collectors]
        return await asyncio.gather(*tasks)

    async def collect(self) -> None:
        profiles = await self._collect()
        for profile in profiles:
            for effect_handler in self.effect_handlers:
                effect_handler(profile)
235
src/exo/utils/info_gatherer/info_gatherer.py
Normal file
@@ -0,0 +1,235 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast

import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger

from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
    MemoryUsage,
    NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
    ThunderboltConnection,
    ThunderboltConnectivity,
    ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel

from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces

IS_DARWIN = sys.platform == "darwin"


class StaticNodeInformation(TaggedModel):
    """Node information that should NEVER change, to be gathered once at startup"""

    model: str
    chip: str

    @classmethod
    async def gather(cls) -> Self:
        model, chip = await get_model_and_chip()
        return cls(model=model, chip=chip)


class NodeNetworkInterfaces(TaggedModel):
    ifaces: Sequence[NetworkInterfaceInfo]


class MacThunderboltIdentifiers(TaggedModel):
    idents: Sequence[ThunderboltIdentifier]


class MacThunderboltConnections(TaggedModel):
    conns: Sequence[ThunderboltConnection]


class NodeConfig(TaggedModel):
    """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""

    @classmethod
    async def gather(cls) -> Self | None:
        cfg_file = anyio.Path(EXO_CONFIG_FILE)
        await cfg_file.touch(exist_ok=True)
        async with await cfg_file.open("rb") as f:
            try:
                contents = (await f.read()).decode("utf-8")
                data = tomllib.loads(contents)
                return cls.model_validate(data)
            except (tomllib.TOMLDecodeError, UnicodeDecodeError):
                logger.warning("Invalid config file, skipping...")
                return None


class MiscData(TaggedModel):
    """Node information that may slowly change and doesn't fall into the other categories"""

    friendly_name: str

    @classmethod
    async def gather(cls) -> Self:
        return cls(friendly_name=await get_friendly_name())


async def _gather_iface_map() -> dict[str, str] | None:
    proc = await anyio.run_process(
        ["networksetup", "-listallhardwareports"], check=False
    )
    if proc.returncode != 0:
        return None

    ports: dict[str, str] = {}
    port = ""
    for line in proc.stdout.decode("utf-8").split("\n"):
        if line.startswith("Hardware Port:"):
            port = line.split(": ")[1]
        elif line.startswith("Device:"):
            ports[port] = line.split(": ")[1]
            port = ""
    if "" in ports:
        del ports[""]
    return ports


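The parser above expects the stanza format networksetup prints; a self-contained sketch of the same loop on canned input:

canned = "Hardware Port: Thunderbolt 1\nDevice: en5\n\nHardware Port: Wi-Fi\nDevice: en0\n"
ports: dict[str, str] = {}
port = ""
for line in canned.split("\n"):
    if line.startswith("Hardware Port:"):
        port = line.split(": ")[1]
    elif line.startswith("Device:"):
        ports[port] = line.split(": ")[1]
        port = ""
print(ports)  # {'Thunderbolt 1': 'en5', 'Wi-Fi': 'en0'}
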
GatheredInfo = (
    MacmonMetrics
    | MemoryUsage
    | NodeNetworkInterfaces
    | MacThunderboltIdentifiers
    | MacThunderboltConnections
    | NodeConfig
    | MiscData
    | StaticNodeInformation
)


@dataclass
class InfoGatherer:
    info_sender: Sender[GatheredInfo]
    interface_watcher_interval: float | None = 10
    misc_poll_interval: float | None = 60
    system_profiler_interval: float | None = 5 if IS_DARWIN else None
    memory_poll_rate: float | None = None if IS_DARWIN else 1
    macmon_interval: float | None = 1 if IS_DARWIN else None
    _tg: TaskGroup = field(init=False, default_factory=create_task_group)

    async def run(self):
        async with self._tg as tg:
            if IS_DARWIN:
                if (macmon_path := shutil.which("macmon")) is not None:
                    tg.start_soon(self._monitor_macmon, macmon_path)
                tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
            tg.start_soon(self._watch_system_info)
            tg.start_soon(self._monitor_memory_usage)
            tg.start_soon(self._monitor_misc)

            nc = await NodeConfig.gather()
            if nc is not None:
                await self.info_sender.send(nc)
            sni = await StaticNodeInformation.gather()
            await self.info_sender.send(sni)

    def shutdown(self):
        self._tg.cancel_scope.cancel()

    async def _monitor_misc(self):
        if self.misc_poll_interval is None:
            return
        prev = await MiscData.gather()
        await self.info_sender.send(prev)
        while True:
            curr = await MiscData.gather()
            if prev != curr:
                prev = curr
                await self.info_sender.send(curr)
            await anyio.sleep(self.misc_poll_interval)

    async def _monitor_system_profiler_thunderbolt_data(self):
        if self.system_profiler_interval is None:
            return
        iface_map = await _gather_iface_map()
        if iface_map is None:
            return

        old_idents = []
        while True:
            data = await ThunderboltConnectivity.gather()
            assert data is not None

            idents = [it for i in data if (it := i.ident(iface_map)) is not None]
            if idents != old_idents:
                await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
                old_idents = idents

            conns = [it for i in data if (it := i.conn()) is not None]
            await self.info_sender.send(MacThunderboltConnections(conns=conns))

            await anyio.sleep(self.system_profiler_interval)

    async def _monitor_memory_usage(self):
        override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
        override_memory: int | None = (
            Memory.from_mb(int(override_memory_env)).in_bytes
            if override_memory_env
            else None
        )
        if self.memory_poll_rate is None:
            return
        while True:
            await self.info_sender.send(
                MemoryUsage.from_psutil(override_memory=override_memory)
            )
            await anyio.sleep(self.memory_poll_rate)

    async def _watch_system_info(self):
        if self.interface_watcher_interval is None:
            return
        old_nics = []
        while True:
            nics = get_network_interfaces()
            if nics != old_nics:
                old_nics = nics
                await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
            await anyio.sleep(self.interface_watcher_interval)

    async def _monitor_macmon(self, macmon_path: str):
        if self.macmon_interval is None:
            return
        # macmon pipe --interval [interval in ms]
        try:
            async with await open_process(
                [macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
            ) as p:
                if not p.stdout:
                    logger.critical("MacMon closed stdout")
                    return
                async for text in TextReceiveStream(
                    BufferedByteReceiveStream(p.stdout)
                ):
                    await self.info_sender.send(MacmonMetrics.from_raw_json(text))
        except CalledProcessError as e:
            stderr_msg = "no stderr"
            stderr_output = cast(bytes | str | None, e.stderr)
            if stderr_output is not None:
                stderr_msg = (
                    stderr_output.decode()
                    if isinstance(stderr_output, bytes)
                    else str(stderr_output)
                )
            logger.warning(
                f"MacMon failed with return code {e.returncode}: {stderr_msg}"
            )
70
src/exo/utils/info_gatherer/macmon.py
Normal file
@@ -0,0 +1,70 @@
from typing import Self

from pydantic import BaseModel

from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel


class _TempMetrics(BaseModel, extra="ignore"):
    """Temperature-related metrics returned by macmon."""

    cpu_temp_avg: float
    gpu_temp_avg: float


class _MemoryMetrics(BaseModel, extra="ignore"):
    """Memory-related metrics returned by macmon."""

    ram_total: int
    ram_usage: int
    swap_total: int
    swap_usage: int


class RawMacmonMetrics(BaseModel, extra="ignore"):
    """Complete set of metrics returned by macmon.

    Unknown fields are ignored for forward-compatibility.
    """

    timestamp: str  # ignored
    temp: _TempMetrics
    memory: _MemoryMetrics
    ecpu_usage: tuple[int, float]  # freq mhz, usage %
    pcpu_usage: tuple[int, float]  # freq mhz, usage %
    gpu_usage: tuple[int, float]  # freq mhz, usage %
    all_power: float
    ane_power: float
    cpu_power: float
    gpu_power: float
    gpu_ram_power: float
    ram_power: float
    sys_power: float


class MacmonMetrics(TaggedModel):
    system_profile: SystemPerformanceProfile
    memory: MemoryUsage

    @classmethod
    def from_raw(cls, raw: RawMacmonMetrics) -> Self:
        return cls(
            system_profile=SystemPerformanceProfile(
                gpu_usage=raw.gpu_usage[1],
                temp=raw.temp.gpu_temp_avg,
                sys_power=raw.sys_power,
                pcpu_usage=raw.pcpu_usage[1],
                ecpu_usage=raw.ecpu_usage[1],
            ),
            memory=MemoryUsage.from_bytes(
                ram_total=raw.memory.ram_total,
                ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
                swap_total=raw.memory.swap_total,
                swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
            ),
        )

    @classmethod
    def from_raw_json(cls, json: str) -> Self:
        return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
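A sketch of the macmon pipe payload RawMacmonMetrics accepts (values invented; unknown keys are ignored):

raw = (
    '{"timestamp": "2025-01-01T00:00:00Z",'
    ' "temp": {"cpu_temp_avg": 45.0, "gpu_temp_avg": 41.0},'
    ' "memory": {"ram_total": 17179869184, "ram_usage": 9000000000,'
    ' "swap_total": 0, "swap_usage": 0},'
    ' "ecpu_usage": [1020, 0.12], "pcpu_usage": [3204, 0.31],'
    ' "gpu_usage": [1398, 0.05], "all_power": 8.2, "ane_power": 0.0,'
    ' "cpu_power": 4.1, "gpu_power": 1.3, "gpu_ram_power": 0.2,'
    ' "ram_power": 0.6, "sys_power": 11.4}'
)
metrics = MacmonMetrics.from_raw_json(raw)
print(metrics.system_profile.gpu_usage)  # 0.05 - the usage %, not the frequency
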
113
src/exo/utils/info_gatherer/net_profile.py
Normal file
@@ -0,0 +1,113 @@
from collections.abc import Mapping

import anyio
import httpx
from anyio import create_task_group
from loguru import logger

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile

REACHABILITY_ATTEMPTS = 3


async def check_reachability(
    target_ip: str,
    expected_node_id: NodeId,
    out: dict[NodeId, set[str]],
    client: httpx.AsyncClient,
) -> None:
    """Check if a node is reachable at the given IP and verify its identity."""
    if ":" in target_ip:
        # TODO: use real IpAddress types
        target_ip = f"[{target_ip}]"
    url = f"http://{target_ip}:52415/node_id"

    remote_node_id = None
    last_error = None

    for _ in range(REACHABILITY_ATTEMPTS):
        try:
            r = await client.get(url)
            if r.status_code != 200:
                await anyio.sleep(1)
                continue

            body = r.text.strip().strip('"')
            if not body:
                await anyio.sleep(1)
                continue

            remote_node_id = NodeId(body)
            break

        # expected failure cases
        except (
            httpx.TimeoutException,
            httpx.NetworkError,
        ):
            await anyio.sleep(1)

        # other failures should be logged on last attempt
        except httpx.HTTPError as e:
            last_error = e
            await anyio.sleep(1)

    if last_error is not None:
        logger.warning(
            f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
        )

    if remote_node_id is None:
        return

    if remote_node_id != expected_node_id:
        logger.warning(
            f"Discovered node with unexpected node_id; "
            f"ip={target_ip}, expected_node_id={expected_node_id}, "
            f"remote_node_id={remote_node_id}"
        )
        return

    if remote_node_id not in out:
        out[remote_node_id] = set()
    out[remote_node_id].add(target_ip)


async def check_reachable(
    topology: Topology,
    self_node_id: NodeId,
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, set[str]]:
    """Check which nodes are reachable and return their IPs."""

    reachable: dict[NodeId, set[str]] = {}

    # these are intentionally httpx's defaults so we can tune them later
    timeout = httpx.Timeout(timeout=5.0)
    limits = httpx.Limits(
        max_connections=100,
        max_keepalive_connections=20,
        keepalive_expiry=5,
    )

    async with (
        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
        create_task_group() as tg,
    ):
        for node_id in topology.list_nodes():
            if node_id not in node_profiles:
                continue
            if node_id == self_node_id:
                continue
            for iface in node_profiles[node_id].network_interfaces:
                tg.start_soon(
                    check_reachability,
                    iface.ip_address,
                    node_id,
                    reachable,
                    client,
                )

    return reachable
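A hedged usage sketch (topology, self_node_id and node_profiles are assumed to come from worker state):

async def poll_once(
    topology: Topology,
    self_node_id: NodeId,
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> None:
    reachable = await check_reachable(topology, self_node_id, node_profiles)
    for node_id, ips in reachable.items():
        # every IP in ips answered /node_id with the expected identity
        print(node_id, sorted(ips))
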
24
src/exo/utils/info_gatherer/tests/test_tb_parsing.py
Normal file
@@ -0,0 +1,24 @@
import sys

import pytest

from exo.shared.types.thunderbolt import (
    ThunderboltConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
    _gather_iface_map,  # pyright: ignore[reportPrivateUsage]
)


@pytest.mark.anyio
@pytest.mark.skipif(
    sys.platform != "darwin", reason="Thunderbolt info can only be gathered on macOS"
)
async def test_tb_parsing():
    data = await ThunderboltConnectivity.gather()
    ifaces = await _gather_iface_map()
    assert ifaces
    assert data
    for datum in data:
        datum.ident(ifaces)
        datum.conn()
@@ -19,11 +19,20 @@ class CamelCaseModel(BaseModel):
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        # I want to reenable this ASAP, but it's causing an issue with TaskStatus
        strict=True,
    )


class FrozenModel(BaseModel):
    model_config = ConfigDict(
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        strict=True,
        frozen=True,
    )


class TaggedModel(CamelCaseModel):
    @model_serializer(mode="wrap")
    def _serialize(self, handler: SerializerFunctionWrapHandler):

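FrozenModel combines the camel-case config with frozen=True, so subclasses hash and compare by value; a quick sketch assuming the class as defined above:

class Edge(FrozenModel):
    src: str
    dst: str

assert Edge(src="a", dst="b") == Edge(src="a", dst="b")
assert hash(Edge(src="a", dst="b")) == hash(Edge(src="a", dst="b"))
# which is what lets Connection/SocketConnection sit in sets and dict keys
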
@@ -28,9 +28,8 @@ def bar(send: MpSender[str]):
    send.close()


# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_setup():
async def test_channel_ipc():
    with fail_after(0.5):
        s, r = mp_channel[str]()
        p1 = mp.Process(target=foo, args=(r,))
@@ -10,18 +10,24 @@ from mlx.nn.layers.distributed import (
    shard_linear,
    sum_gradients,
)
from mlx_lm.models.cache import (
    _BaseCache,  # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock

from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata


class _LayerCallable(Protocol):
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
            x, *args, **kwargs
        ).arguments.get("cache", None)

        assert cache is None or issubclass(type(cache), _BaseCache)  # type: ignore

        output: mx.array = self.original_layer(x, *args, **kwargs)

        if self.r != self.s - 1:
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
                output, (self.r + 1) % self.s, group=self.group
            )
            if cache is not None:
                # This change happened upstream - see the mlx repository on GitHub
                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

        output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
    return layers


def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
    inner_model_instance = _inner_model(model)
    if hasattr(inner_model_instance, "layers"):
        inner_model_instance.layers = layers

    # Update DeepSeek V3 specific parameters when layers are shrunk
    if isinstance(model, DeepseekV3Model) and hasattr(
        inner_model_instance, "num_layers"
    ):
        inner_model_instance.start_idx = 0
        inner_model_instance.end_idx = len(layers)
        inner_model_instance.num_layers = len(layers)
    elif hasattr(inner_model_instance, "h"):
        inner_model_instance.h = layers
    else:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")


def pipeline_auto_parallel(
    model: nn.Module,
    group: mx.distributed.Group,
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
    """
    inner_model_instance: nn.Module = _inner_model(model)

    # Handle both model.layers and model.h cases
    layers: list[_LayerCallable] = _get_layers(inner_model_instance)
    layers = _get_layers(inner_model_instance)

    start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
        group=group,
    )

    if isinstance(inner_model_instance, GptOssMoeModel):
        inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore
            start_layer:end_layer
        ]
        inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
            "sliding_attention"
        )
        inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
            "full_attention"
        )

    _set_layers(model, layers)

    assert isinstance(layers, list), (
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
        group=group,
    )

    segments: int = 1

    def _all_to_sharded(path: str, weight: mx.array):
        if path.endswith("bias"):
            logger.info(f"Sharding bias for {path} - all to sharded")
            return weight.ndim - 1, segments
        return max(weight.ndim - 2, 0), segments

    all_to_sharded_linear_in_place = partial(
        shard_inplace,
        sharding="all-to-sharded",
        group=group,
    )
    sharded_to_all_linear_in_place = partial(
        shard_inplace,
        sharding="sharded-to-all",
        sharding=_all_to_sharded,  # type: ignore
        group=group,
    )

    if isinstance(model, LlamaModel):
    n = group.size()

    def _sharded_to_all(path: str, weight: mx.array):
        if path.endswith("bias"):
            logger.info(f"Sharding bias for {path} - sharded to all")
            weight /= n
            return None
        return -1, segments

    sharded_to_all_linear_in_place = partial(
        shard_inplace,
        sharding=_sharded_to_all,  # type: ignore
        group=group,
    )

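The two sharding callbacks pick opposite axes of a Linear's [out_features, in_features] weight; a quick sketch of the axis arithmetic (segments fixed at 1 as above):

import mlx.core as mx

w = mx.zeros((8, 4))       # a Linear weight: [out_features, in_features]
b = mx.zeros((8,))         # its bias
print(max(w.ndim - 2, 0))  # 0 -> all-to-sharded splits the output rows
print(b.ndim - 1)          # 0 -> an all-to-sharded bias splits along its only axis
# sharded-to-all instead returns -1 (split the input columns) for weights, and
# for biases divides by group size and returns None (replicate), so the
# all_sum that follows the layer reconstructs the full output
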
    if hasattr(model, "shard"):
        try:
            model.shard(group)  # type: ignore
            return model
        except (AttributeError, TypeError, NameError):
            pass

    if isinstance(model, (LlamaModel, Ministral3Model)):
        logger.warning("shouldn't be hit - upstream sharding exists")
        tensor_parallel_sharding_strategy = LlamaShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, DeepseekV3Model):
    elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
        logger.warning("shouldn't be hit - upstream sharding exists")
        tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -231,7 +253,15 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, Qwen3MoeModel):
    elif isinstance(model, MiniMaxModel):
        tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
        tensor_parallel_sharding_strategy = QwenShardingStrategy(
            group,
            all_to_sharded_linear,
@@ -239,6 +269,15 @@ def tensor_auto_parallel(
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )
    elif isinstance(model, GptOssModel):
        tensor_parallel_sharding_strategy = GptOssShardingStrategy(
            group,
            all_to_sharded_linear,
            sharded_to_all_linear,
            all_to_sharded_linear_in_place,
            sharded_to_all_linear_in_place,
        )

    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

@@ -284,6 +323,32 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
        return model


def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
    inner_model_instance = _inner_model(model)
    if hasattr(inner_model_instance, "layers"):
        inner_model_instance.layers = layers

    # Update DeepSeek V3 specific parameters when layers are shrunk
    if isinstance(
        model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
    ) and hasattr(inner_model_instance, "num_layers"):
        logger.info(
            f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
        )
        inner_model_instance.start_idx = 0
        inner_model_instance.end_idx = len(layers)
        inner_model_instance.num_layers = len(layers)
    elif isinstance(model, Qwen3MoeModel):
        logger.info(
            f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
        )
        inner_model_instance.num_hidden_layers = len(layers)
    elif hasattr(inner_model_instance, "h"):
        inner_model_instance.h = layers
    else:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")


class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(DeepseekV3Model, model)
@@ -304,7 +369,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
            layer.self_attn.num_heads //= self.N

            # Shard the MLP
            if isinstance(layer.mlp, DeepseekV3MLP):
            if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -338,6 +403,35 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
        return y


class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(MiniMaxModel, model)
        for layer in model.layers:
            # Shard the self attention
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
            layer.self_attn.num_attention_heads //= self.N
            layer.self_attn.num_key_value_heads //= self.N

            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
            self.all_to_sharded_linear_in_place(
                layer.block_sparse_moe.switch_mlp.gate_proj
            )
            self.sharded_to_all_linear_in_place(
                layer.block_sparse_moe.switch_mlp.down_proj
            )
            self.all_to_sharded_linear_in_place(
                layer.block_sparse_moe.switch_mlp.up_proj
            )
            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
            layer.block_sparse_moe.sharding_group = self.group

        return model


class QwenShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(Qwen3MoeModel, model)
@@ -352,11 +446,13 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):

            # Shard the MoE. Shard in place since the MoE should be responsible
            # for aggregating the results.
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
            if isinstance(
                layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
            ):
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
                layer.mlp = ShardedQwenMoE(layer.mlp)  # type: ignore
                layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
                layer.mlp.sharding_group = self.group

            # Shard the MLP
@@ -380,3 +476,50 @@ class ShardedQwenMoE(CustomMlxLayer):
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y


class GptOssShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(GptOssMoeModel, model)

        for layer in model.layers:
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)

            layer.self_attn.num_attention_heads //= self.N
            layer.self_attn.num_key_value_heads //= self.N
            layer.self_attn.num_key_value_groups = (
                layer.self_attn.num_attention_heads
                // layer.self_attn.num_key_value_heads
            )

            layer.self_attn.sinks = layer.self_attn.sinks[
                layer.self_attn.num_attention_heads
                * self.group.rank() : layer.self_attn.num_attention_heads
                * (self.group.rank() + 1)
            ]

            self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
            self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
            self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)

            layer.mlp = ShardedGptOssMoE(layer.mlp)  # type: ignore
            layer.mlp.sharding_group = self.group

        return model


class ShardedGptOssMoE(CustomMlxLayer):
    def __init__(self, layer: nn.Module):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y

@@ -20,6 +20,7 @@ except ImportError:

from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.worker.engines.mlx.constants import (
@@ -144,20 +145,26 @@ def mlx_distributed_init(
            group = mx.distributed.init(backend="ring", strict=True)

        case MlxJacclInstance(
            ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
            jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
        ):
            assert all(
                jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
            )
            # Use RDMA connectivity matrix
            coordination_file = (
                f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
            )
            ibv_devices_json = json.dumps(ibv_devices)
            jaccl_devices_json = json.dumps(jaccl_devices)

            with open(coordination_file, "w") as f:
                _ = f.write(ibv_devices_json)
                _ = f.write(jaccl_devices_json)

            jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

            logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
            # TODO: update once upstream fixes
            logger.info(
                f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
            )
            logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
            os.environ["MLX_IBV_DEVICES"] = coordination_file
            os.environ["MLX_RANK"] = str(rank)
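Net effect on the runner environment, sketched for one rank (names taken from the logging above):

# MLX_IBV_DEVICES       -> ./hosts_<instance_id>_<rank>.json, the RDMA
#                          connectivity matrix written just above (diagonal
#                          entries are None, per the assertion)
# MLX_RANK              -> this node's rank within the jaccl group
# MLX_JACCL_COORDINATOR -> this node's coordinator address from jaccl_coordinators
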
@@ -365,6 +372,8 @@ def apply_chat_template(
        tools=chat_task_data.tools,
    )

    logger.info(prompt)

    return prompt


@@ -396,6 +405,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
    assert hasattr(model, "layers")

    # TODO: Do this for all models
    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
        logger.info("Using MLX LM's make cache")
        return model.make_cache()  # type: ignore

    if max_kv_size is None:
        if KV_CACHE_BITS is None:
            logger.info("Using default KV cache")

@@ -16,8 +16,7 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    NodeDownloadProgress,
    NodeMemoryMeasured,
    NodePerformanceMeasured,
    NodeGatheredInfo,
    TaskCreated,
    TaskStatusUpdated,
    TopologyEdgeCreated,
@@ -25,7 +24,6 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
    CreateRunner,
@@ -34,7 +32,7 @@ from exo.shared.types.tasks import (
    Task,
    TaskStatus,
)
from exo.shared.types.topology import Connection
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
    DownloadCompleted,
    DownloadOngoing,
@@ -45,14 +43,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
    map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable


class Worker:
@@ -86,7 +84,7 @@ class Worker:
        self.state: State = State()
        self.download_status: dict[ModelId, DownloadProgress] = {}
        self.runners: dict[RunnerId, RunnerSupervisor] = {}
        self._tg: TaskGroup | None = None
        self._tg: TaskGroup = create_task_group()

        self._nack_cancel_scope: CancelScope | None = None
        self._nack_attempts: int = 0
@@ -98,37 +96,13 @@ class Worker:
    async def run(self):
        logger.info("Starting Worker")

        # TODO: CLEANUP HEADER
        async def resource_monitor_callback(
            node_performance_profile: NodePerformanceProfile,
        ) -> None:
            await self.event_sender.send(
                NodePerformanceMeasured(
                    node_id=self.node_id,
                    node_profile=node_performance_profile,
                    when=str(datetime.now(tz=timezone.utc)),
                ),
            )
        info_send, info_recv = channel[GatheredInfo]()
        info_gatherer: InfoGatherer = InfoGatherer(info_send)

        async def memory_monitor_callback(
            memory_profile: MemoryPerformanceProfile,
        ) -> None:
            await self.event_sender.send(
                NodeMemoryMeasured(
                    node_id=self.node_id,
                    memory=memory_profile,
                    when=str(datetime.now(tz=timezone.utc)),
                )
            )

        # END CLEANUP

        async with create_task_group() as tg:
            self._tg = tg
        async with self._tg as tg:
            tg.start_soon(info_gatherer.run)
            tg.start_soon(self._forward_info, info_recv)
            tg.start_soon(self.plan_step)
            tg.start_soon(start_polling_node_metrics, resource_monitor_callback)

            tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
            tg.start_soon(self._emit_existing_download_progress)
            tg.start_soon(self._connection_message_event_writer)
            tg.start_soon(self._resend_out_for_delivery)
@@ -142,6 +116,17 @@ class Worker:
        for runner in self.runners.values():
            runner.shutdown()

    async def _forward_info(self, recv: Receiver[GatheredInfo]):
        with recv as info_stream:
            async for info in info_stream:
                await self.event_sender.send(
                    NodeGatheredInfo(
                        node_id=self.node_id,
                        when=str(datetime.now(tz=timezone.utc)),
                        info=info,
                    )
                )

    async def _event_applier(self):
        with self.global_event_receiver as events:
            async for f_event in events:
@@ -161,7 +146,6 @@ class Worker:
                    self._nack_cancel_scope is None
                    or self._nack_cancel_scope.cancel_called
                ):
                    assert self._tg
                    # Request the next index.
                    self._tg.start_soon(
                        self._nack_request, self.state.last_event_applied_idx + 1
@@ -252,8 +236,7 @@ class Worker:
        await self.runners[self._task_to_runner_id(task)].start_task(task)

    def shutdown(self):
        if self._tg:
            self._tg.cancel_scope.cancel()
        self._tg.cancel_scope.cancel()

    def _task_to_runner_id(self, task: Task):
        instance = self.state.instances[task.instance_id]
@@ -270,24 +253,28 @@ class Worker:
        match msg.connection_type:
            case ConnectionMessageType.Connected:
                return TopologyEdgeCreated(
                    edge=Connection(
                        local_node_id=self.node_id,
                        send_back_node_id=msg.node_id,
                        send_back_multiaddr=Multiaddr(
                            address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                    conn=Connection(
                        source=self.node_id,
                        sink=msg.node_id,
                        edge=SocketConnection(
                            sink_multiaddr=Multiaddr(
                                address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                            ),
                        ),
                    )
                    ),
                )

            case ConnectionMessageType.Disconnected:
                return TopologyEdgeDeleted(
                    edge=Connection(
                        local_node_id=self.node_id,
                        send_back_node_id=msg.node_id,
                        send_back_multiaddr=Multiaddr(
                            address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                    conn=Connection(
                        source=self.node_id,
                        sink=msg.node_id,
                        edge=SocketConnection(
                            sink_multiaddr=Multiaddr(
                                address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                            ),
                        ),
                    )
                    ),
                )

    async def _nack_request(self, since_idx: int) -> None:
@@ -336,7 +323,6 @@ class Worker:
            event_sender=self.event_sender.clone(),
        )
        self.runners[task.bound_instance.bound_runner_id] = runner
        assert self._tg
        self._tg.start_soon(runner.run)
        return runner

@@ -399,7 +385,6 @@ class Worker:
                    last_progress_time = current_time()

        self.shard_downloader.on_progress(download_progress_callback)
        assert self._tg
        self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)

    async def _forward_events(self) -> None:
@@ -420,9 +405,14 @@ class Worker:

    async def _poll_connection_updates(self):
        while True:
            # TODO: EdgeDeleted
            edges = set(self.state.topology.list_connections())
            conns = await check_reachable(self.state.topology, self.node_id)
            edges = set(
                conn.edge for conn in self.state.topology.out_edges(self.node_id)
            )
            conns = await check_reachable(
                self.state.topology,
                self.node_id,
                self.state.node_profiles,
            )
            for nid in conns:
                for ip in conns[nid]:
                    if "127.0.0.1" in ip or "localhost" in ip:
@@ -430,26 +420,33 @@ class Worker:
                            f"Loopback connection should not happen: {ip=} for {nid=}"
                        )

                    edge = Connection(
                        local_node_id=self.node_id,
                        send_back_node_id=nid,
                    edge = SocketConnection(
                        # nonsense multiaddr
                        send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
                        sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
                        if "." in ip
                        # nonsense multiaddr
                        else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
                    )
                    if edge not in edges:
                        logger.debug(f"ping discovered {edge=}")
                        await self.event_sender.send(TopologyEdgeCreated(edge=edge))
                        await self.event_sender.send(
                            TopologyEdgeCreated(
                                conn=Connection(
                                    source=self.node_id, sink=nid, edge=edge
                                )
                            )
                        )

            for nid, conn in self.state.topology.out_edges(self.node_id):
            for conn in self.state.topology.out_edges(self.node_id):
                if not isinstance(conn.edge, SocketConnection):
                    continue
                if (
                    nid not in conns
                    or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
                    conn.sink not in conns
                    or conn.edge.sink_multiaddr.ip_address
                    not in conns.get(conn.source, set())
                ):
                    logger.debug(f"ping failed to discover {conn=}")
                    await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
                    await self.event_sender.send(TopologyEdgeDeleted(conn=conn))

            await anyio.sleep(10)


@@ -19,7 +19,7 @@ def entrypoint(
) -> None:
    if (
        isinstance(bound_instance.instance, MlxJacclInstance)
        and len(bound_instance.instance.ibv_devices) >= 2
        and len(bound_instance.instance.jaccl_devices) >= 2
    ):
        os.environ["MLX_METAL_FAST_SYNCH"] = "1"


@@ -1,6 +1,15 @@
import time
from collections.abc import Generator
from functools import cache

import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
    HarmonyEncodingName,
    Role,
    StreamableParser,
    load_harmony_encoding,
)

from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -153,11 +162,19 @@ def main(
    _check_for_debug_prompts(task_params.messages[0].content)

    # Generate responses using the actual MLX generation
    for response in mlx_generate(
    mlx_generator = mlx_generate(
        model=model,
        tokenizer=tokenizer,
        task=task_params,
    ):
    )

    # GPT-OSS specific parsing to match other model formats.
    if isinstance(model, GptOssModel):
        mlx_generator = parse_gpt_oss(mlx_generator)

    # TODO: Add tool call parser here

    for response in mlx_generator:
        match response:
            case GenerationResponse():
                if shard_metadata.device_rank == 0:
@@ -207,6 +224,43 @@ def main(
                break


@cache
def get_gpt_oss_encoding():
    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return encoding


def parse_gpt_oss(
    responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
    encoding = get_gpt_oss_encoding()
    stream = StreamableParser(encoding, role=Role.ASSISTANT)
    thinking = False

    for response in responses:
        stream.process(response.token)

        delta = stream.last_content_delta
        ch = stream.current_channel

        if ch == "analysis" and not thinking:
            thinking = True
            yield response.model_copy(update={"text": "<think>"})

        if ch != "analysis" and thinking:
            thinking = False
            yield response.model_copy(update={"text": "</think>"})

        if delta:
            yield response.model_copy(update={"text": delta})

        if response.finish_reason is not None:
            if thinking:
                yield response.model_copy(update={"text": "</think>"})
            yield response
            break


|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from .profile import start_polling_memory_metrics, start_polling_node_metrics
|
||||
|
||||
__all__ = [
|
||||
"start_polling_node_metrics",
|
||||
"start_polling_memory_metrics",
|
||||
]
|
||||
@@ -1,103 +0,0 @@
|
||||
import platform
|
||||
import shutil
|
||||
from subprocess import CalledProcessError
|
||||
from typing import cast
|
||||
|
||||
from anyio import run_process
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
|
||||
class MacMonError(Exception):
|
||||
"""Exception raised for errors in the MacMon functions."""
|
||||
|
||||
|
||||
def _get_binary_path() -> str:
|
||||
"""
|
||||
Get the path to the macmon binary.
|
||||
|
||||
Raises:
|
||||
MacMonError: If the binary doesn't exist or can't be made executable.
|
||||
"""
|
||||
# Check for macOS with ARM chip
|
||||
system = platform.system().lower()
|
||||
machine = platform.machine().lower()
|
||||
|
||||
if system != "darwin" or not (
|
||||
"arm" in machine or "m1" in machine or "m2" in machine
|
||||
):
|
||||
raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")
|
||||
|
||||
path = shutil.which("macmon")
|
||||
|
||||
if path is None:
|
||||
raise MacMonError("MacMon not found in PATH")
|
||||
|
||||
return path
|
||||
|
||||
|
||||
class TempMetrics(BaseModel):
|
||||
"""Temperature-related metrics returned by macmon."""
|
||||
|
||||
cpu_temp_avg: float
|
||||
gpu_temp_avg: float
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class Metrics(BaseModel):
|
||||
"""Complete set of metrics returned by macmon.
|
||||
|
||||
Unknown fields are ignored for forward-compatibility.
|
||||
"""
|
||||
|
||||
all_power: float
|
||||
ane_power: float
|
||||
cpu_power: float
|
||||
ecpu_usage: tuple[int, float]
|
||||
gpu_power: float
|
||||
gpu_ram_power: float
|
||||
gpu_usage: tuple[int, float]
|
||||
pcpu_usage: tuple[int, float]
|
||||
ram_power: float
|
||||
sys_power: float
|
||||
temp: TempMetrics
|
||||
timestamp: str
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
async def get_metrics_async() -> Metrics:
|
||||
"""
|
||||
Asynchronously run the binary and return the metrics as a Python dictionary.
|
||||
|
||||
Args:
|
||||
binary_path: Optional path to the binary. If not provided, will use the bundled binary.
|
||||
|
||||
Returns:
|
||||
A mapping containing system metrics.
|
||||
|
||||
Raises:
|
||||
MacMonError: If there's an error running the binary.
|
||||
"""
|
||||
path = _get_binary_path()
|
||||
|
||||
try:
|
||||
# TODO: Keep Macmon running in the background?
|
||||
result = await run_process([path, "pipe", "-s", "1"])
|
||||
|
||||
return Metrics.model_validate_json(result.stdout.decode().strip())
|
||||
|
||||
except ValidationError as e:
|
||||
raise MacMonError(f"Error parsing JSON output: {e}") from e
|
||||
except CalledProcessError as e:
|
||||
stderr_msg = "no stderr"
|
||||
stderr_output = cast(bytes | str | None, e.stderr)
|
||||
if stderr_output is not None:
|
||||
stderr_msg = (
|
||||
stderr_output.decode()
|
||||
if isinstance(stderr_output, bytes)
|
||||
else str(stderr_output)
|
||||
)
|
||||
raise MacMonError(
|
||||
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
|
||||
) from e
|
||||
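
Reviewer note: the deleted module leaned on pydantic's extra="ignore" for forward compatibility with newer macmon builds. A self-contained sketch of that behaviour; the JSON payload is illustrative, not real macmon output.

from pydantic import BaseModel, ConfigDict

class TempMetrics(BaseModel):
    cpu_temp_avg: float
    gpu_temp_avg: float
    model_config = ConfigDict(extra="ignore")

# A field the model does not declare is silently dropped instead of failing
# validation, so output from newer macmon builds keeps parsing.
sample = '{"cpu_temp_avg": 48.2, "gpu_temp_avg": 41.7, "new_field": 1}'
print(TempMetrics.model_validate_json(sample))
# -> cpu_temp_avg=48.2 gpu_temp_avg=41.7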
@@ -1,78 +0,0 @@
-import http.client
-
-from anyio import create_task_group, to_thread
-from loguru import logger
-
-from exo.shared.topology import Topology
-from exo.shared.types.common import NodeId
-
-
-async def check_reachability(
-    target_ip: str,
-    expected_node_id: NodeId,
-    self_node_id: NodeId,
-    out: dict[NodeId, set[str]],
-) -> None:
-    """Check if a node is reachable at the given IP and verify its identity."""
-
-    def _fetch_remote_node_id() -> NodeId | None:
-        connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
-        try:
-            connection.request("GET", "/node_id")
-            response = connection.getresponse()
-            if response.status != 200:
-                return None
-
-            body = response.read().decode("utf-8").strip()
-
-            # Strip quotes if present (JSON string response)
-            if body.startswith('"') and body.endswith('"') and len(body) >= 2:
-                body = body[1:-1]
-
-            return NodeId(body) or None
-        except OSError:
-            return None
-        except http.client.HTTPException:
-            return None
-        finally:
-            connection.close()
-
-    remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
-    if remote_node_id is None:
-        return
-
-    if remote_node_id == self_node_id:
-        return
-
-    if remote_node_id != expected_node_id:
-        logger.warning(
-            f"Discovered node with unexpected node_id; "
-            f"ip={target_ip}, expected_node_id={expected_node_id}, "
-            f"remote_node_id={remote_node_id}"
-        )
-        return
-
-    if remote_node_id not in out:
-        out[remote_node_id] = set()
-    out[remote_node_id].add(target_ip)
-
-
-async def check_reachable(
-    topology: Topology, self_node_id: NodeId
-) -> dict[NodeId, set[str]]:
-    """Check which nodes are reachable and return their IPs."""
-    reachable: dict[NodeId, set[str]] = {}
-    async with create_task_group() as tg:
-        for node in topology.list_nodes():
-            if not node.node_profile:
-                continue
-            for iface in node.node_profile.network_interfaces:
-                tg.start_soon(
-                    check_reachability,
-                    iface.ip_address,
-                    node.node_id,
-                    self_node_id,
-                    reachable,
-                )
-
-    return reachable
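
Reviewer note: one detail of the deleted checker worth keeping in mind if it is ever reintroduced: the /node_id body may arrive bare or JSON-quoted, so the handshake strips surrounding quotes. The same logic extracted as a pure function:

def normalize_node_id(body: str) -> str:
    # Mirrors the deleted quote-stripping: /node_id may answer with a bare
    # string or a JSON-encoded string, so surrounding quotes are dropped.
    body = body.strip()
    if body.startswith('"') and body.endswith('"') and len(body) >= 2:
        body = body[1:-1]
    return body

assert normalize_node_id('"node-abc"\n') == "node-abc"
assert normalize_node_id("node-abc") == "node-abc"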
@@ -1,114 +0,0 @@
-import asyncio
-import os
-import platform
-from typing import Any, Callable, Coroutine
-
-import anyio
-from loguru import logger
-
-from exo.shared.types.memory import Memory
-from exo.shared.types.profiling import (
-    MemoryPerformanceProfile,
-    NodePerformanceProfile,
-    SystemPerformanceProfile,
-)
-
-from .macmon import (
-    MacMonError,
-    Metrics,
-)
-from .macmon import (
-    get_metrics_async as macmon_get_metrics_async,
-)
-from .system_info import (
-    get_friendly_name,
-    get_model_and_chip,
-    get_network_interfaces,
-)
-
-
-async def get_metrics_async() -> Metrics | None:
-    """Return detailed Metrics on macOS or a minimal fallback elsewhere."""
-
-    if platform.system().lower() == "darwin":
-        return await macmon_get_metrics_async()
-
-
-def get_memory_profile() -> MemoryPerformanceProfile:
-    """Construct a MemoryPerformanceProfile using psutil"""
-    override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
-    override_memory: int | None = (
-        Memory.from_mb(int(override_memory_env)).in_bytes
-        if override_memory_env
-        else None
-    )
-
-    return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
-
-
-async def start_polling_memory_metrics(
-    callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
-    *,
-    poll_interval_s: float = 0.5,
-) -> None:
-    """Continuously poll and emit memory-only metrics at a faster cadence.
-
-    Parameters
-    - callback: coroutine called with a fresh MemoryPerformanceProfile each tick
-    - poll_interval_s: interval between polls
-    """
-    while True:
-        try:
-            mem = get_memory_profile()
-            await callback(mem)
-        except MacMonError as e:
-            logger.opt(exception=e).error("Memory Monitor encountered error")
-        finally:
-            await anyio.sleep(poll_interval_s)
-
-
-async def start_polling_node_metrics(
-    callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
-):
-    poll_interval_s = 1.0
-    while True:
-        try:
-            metrics = await get_metrics_async()
-            if metrics is None:
-                return
-
-            network_interfaces = get_network_interfaces()
-            # these awaits could be joined but realistically they should be cached
-            model_id, chip_id = await get_model_and_chip()
-            friendly_name = await get_friendly_name()
-
-            # do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
-            memory_profile = get_memory_profile()
-
-            await callback(
-                NodePerformanceProfile(
-                    model_id=model_id,
-                    chip_id=chip_id,
-                    friendly_name=friendly_name,
-                    network_interfaces=network_interfaces,
-                    memory=memory_profile,
-                    system=SystemPerformanceProfile(
-                        gpu_usage=metrics.gpu_usage[1],
-                        temp=metrics.temp.gpu_temp_avg,
-                        sys_power=metrics.sys_power,
-                        pcpu_usage=metrics.pcpu_usage[1],
-                        ecpu_usage=metrics.ecpu_usage[1],
-                        ane_power=metrics.ane_power,
-                    ),
-                )
-            )
-
-        except asyncio.TimeoutError:
-            logger.warning(
-                "[resource_monitor] Operation timed out after 30s, skipping this cycle."
-            )
-        except MacMonError as e:
-            logger.opt(exception=e).error("Resource Monitor encountered error")
-            return
-        finally:
-            await anyio.sleep(poll_interval_s)
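
Reviewer note: both deleted polling loops share one shape worth preserving in whatever replaces them: the sleep sits in finally so the cadence holds on success and error paths alike, and a failed tick is logged and skipped rather than killing the loop. A generic sketch of that pattern, with illustrative names:

from collections.abc import Awaitable, Callable

import anyio

async def poll_forever(
    read_metric: Callable[[], int],
    publish: Callable[[int], Awaitable[None]],
    interval_s: float = 1.0,
) -> None:
    while True:
        try:
            await publish(read_metric())
        except Exception as exc:  # the deleted code caught MacMonError only
            print(f"poll failed, skipping this tick: {exc}")
        finally:
            # sleep in finally: the next tick is scheduled even after an error
            await anyio.sleep(interval_s)

async def demo() -> None:
    async def publish(value: int) -> None:
        print("metric:", value)

    with anyio.move_on_after(2.5):  # bound the demo to a few ticks
        await poll_forever(lambda: 42, publish)

anyio.run(demo)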
@@ -1,77 +0,0 @@
-"""Tests for macmon error handling.
-
-These tests verify that MacMon errors are handled gracefully without
-crashing the application or spamming logs.
-"""
-
-import platform
-from subprocess import CalledProcessError
-from unittest.mock import AsyncMock, patch
-
-import pytest
-
-from exo.worker.utils.macmon import MacMonError, get_metrics_async
-
-
-@pytest.mark.skipif(
-    platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
-    reason="MacMon only supports macOS with Apple Silicon",
-)
-class TestMacMonErrorHandling:
-    """Test MacMon error handling."""
-
-    async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
-        """CalledProcessError should be wrapped as MacMonError."""
-        mock_error = CalledProcessError(
-            returncode=1,
-            cmd=["macmon", "pipe", "-s", "1"],
-            stderr=b"some error message",
-        )
-
-        with (
-            patch(
-                "exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
-            ),
-            patch(
-                "exo.worker.utils.macmon.run_process", new_callable=AsyncMock
-            ) as mock_run,
-        ):
-            mock_run.side_effect = mock_error
-
-            with pytest.raises(MacMonError) as exc_info:
-                await get_metrics_async()
-
-            assert "MacMon failed with return code 1" in str(exc_info.value)
-            assert "some error message" in str(exc_info.value)
-
-    async def test_called_process_error_with_no_stderr(self) -> None:
-        """CalledProcessError with no stderr should be handled gracefully."""
-        mock_error = CalledProcessError(
-            returncode=1,
-            cmd=["macmon", "pipe", "-s", "1"],
-            stderr=None,
-        )
-
-        with (
-            patch(
-                "exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
-            ),
-            patch(
-                "exo.worker.utils.macmon.run_process", new_callable=AsyncMock
-            ) as mock_run,
-        ):
-            mock_run.side_effect = mock_error
-
-            with pytest.raises(MacMonError) as exc_info:
-                await get_metrics_async()
-
-            assert "MacMon failed with return code 1" in str(exc_info.value)
-            assert "no stderr" in str(exc_info.value)
-
-    async def test_macmon_not_found_raises_macmon_error(self) -> None:
-        """When macmon is not found in PATH, MacMonError should be raised."""
-        with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
-            with pytest.raises(MacMonError) as exc_info:
-                await get_metrics_async()
-
-            assert "MacMon not found in PATH" in str(exc_info.value)
@@ -34,7 +34,8 @@ from exo.shared.types.worker.instances import (
 )
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
-from exo.utils.channels import MpReceiver, MpSender, mp_channel
+from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
+from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
 from exo.worker.download.impl_shard_downloader import (
     build_full_shard,
     exo_shard_downloader,
@@ -65,6 +66,7 @@ async def main():
     app = FastAPI()
     app.post("/ring")(ring_backend)
    app.post("/jaccl")(jaccl_backend)
+    app.post("/tb_detection")(tb_detection)
     shutdown = anyio.Event()
     await serve(
         app,  # type: ignore
@@ -76,6 +78,15 @@ async def main():
     shutdown.set()


+async def tb_detection():
+    send, recv = channel[GatheredInfo]()
+    ig = InfoGatherer(send)
+    with anyio.move_on_after(1):
+        await ig._monitor_system_profiler()  # pyright: ignore[reportPrivateUsage]
+    with recv:
+        return recv.collect()
+
+
 async def assert_downloads():
     sd = exo_shard_downloader()
     # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
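
Reviewer note: channel and collect here are exo's own helpers, but the shape of tb_detection (run a producer under move_on_after, then drain whatever arrived) maps onto plain anyio memory object streams. A hedged equivalent using only anyio's public API:

import anyio

async def bounded_gather() -> list[int]:
    # Subscripted form requires anyio >= 4; older anyio takes item_type=int.
    send, recv = anyio.create_memory_object_stream[int](100)

    async def producer() -> None:
        for i in range(3):
            await send.send(i)
        await anyio.sleep(60)  # stay alive past the deadline, like a monitor

    items: list[int] = []
    async with anyio.create_task_group() as tg:
        tg.start_soon(producer)
        with anyio.move_on_after(1):  # collect for at most one second
            async for item in recv:
                items.append(item)
        tg.cancel_scope.cancel()
    return items

print(anyio.run(bounded_gather))  # -> [0, 1, 2]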
@@ -89,6 +100,12 @@ async def assert_downloads():
     await sd.ensure_shard(
         await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
     )
+    await sd.ensure_shard(
+        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
+    )
+    await sd.ensure_shard(
+        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
+    )


 async def ring_backend(test: Tests):
@@ -203,16 +220,16 @@ async def jaccl_backend(test: Tests):
             break
     else:
         raise ValueError(f"{weird_hn} not in {test.devs}")
-    return await execute_test(test, jaccl_instance(test, iid, hn), hn)
+    return await execute_test(test, jaccl_instance(test, iid), hn)


-def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
+def jaccl_instance(test: Tests, iid: InstanceId):
     meta = MODEL_CARDS[test.model_id].metadata
     world_size = len(test.devs)

     return MlxJacclInstance(
         instance_id=iid,
-        ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
+        jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
         # rank 0 is always coordinator
         jaccl_coordinators={
             NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs