Compare commits

..

9 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
8d6e52bdb5 Try wrong shardings 2026-01-16 19:39:44 +00:00
Ryuichi Leo Takashige
a412ec6d04 Comment out custom moe layer 2026-01-16 19:33:22 +00:00
Ryuichi Leo Takashige
8d00c6ad44 Comment out sharding group only 2026-01-16 19:28:31 +00:00
Ryuichi Leo Takashige
9bf1bb3025 remove moe sharding 2026-01-16 19:24:06 +00:00
Ryuichi Leo Takashige
c613df4d8d formatting 2026-01-16 18:48:57 +00:00
Ryuichi Leo Takashige
f5d1532245 revert api change 2026-01-16 18:46:54 +00:00
Ryuichi Leo Takashige
659fbdf7ea Handle model timeouts
- Add eval with a timeout.
- Add fast synch flag

Timeout mx barrier too
2026-01-16 18:40:48 +00:00
Ryuichi Leo Takashige
313b24fe04 Handle broken resource error gracefully
# Conflicts:
#	src/exo/master/api.py
2026-01-16 18:32:07 +00:00
Ryuichi Leo Takashige
b5bd0ffc98 Return error responses for Chat Completions
- Error chunks
- Use error handling in exo_bench.py
2026-01-16 18:30:55 +00:00
78 changed files with 2385 additions and 3249 deletions

View File

@@ -1,16 +1,5 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -22,10 +11,8 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -363,28 +304,6 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -417,26 +336,3 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -40,31 +40,6 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition

View File

@@ -27,22 +27,13 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -50,7 +41,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -58,7 +49,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -163,24 +154,6 @@ This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
**Configuration Options:**
- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
```bash
uv run exo --no-worker
```
**File Locations (Linux):**
exo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
You can override these locations by setting the corresponding XDG environment variables.
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -193,19 +166,6 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
**Custom Namespace for Cluster Isolation:**
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
- **Use cases**:
- Running multiple separate exo clusters on the same network
- Isolating development/testing clusters from production clusters
- Preventing accidental cluster joining
- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
The namespace is logged on startup for debugging purposes.
#### Uninstalling the macOS App
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
@@ -352,52 +312,6 @@ For further details, see:
---
## Benchmarking
The `exo-bench` tool measures model prefill and token generation speed across different placement configurations. This helps you optimize model performance and validate improvements.
**Prerequisites:**
- Nodes should be running with `uv run exo` before benchmarking
- The tool uses the `/bench/chat/completions` endpoint
**Basic usage:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,256,512 \
--tg 128,256
```
**Key parameters:**
- `--model`: Model to benchmark (short ID or HuggingFace ID)
- `--pp`: Prompt size hints (comma-separated integers)
- `--tg`: Generation lengths (comma-separated integers)
- `--max-nodes`: Limit placements to N nodes (default: 4)
- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)
- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)
- `--repeat`: Number of repetitions per configuration (default: 1)
- `--warmup`: Warmup runs per placement (default: 0)
- `--json-out`: Output file for results (default: bench/results.json)
**Example with filters:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \
--sharding tensor \
--repeat 3 \
--json-out my-results.json
```
The tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.
---
## Hardware Accelerator Support
On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
@@ -406,4 +320,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.9.0-beta.1;
minimumVersion = 2.8.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
}
}
],

View File

@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
@@ -112,13 +141,6 @@ enum NetworkSetupHelper {
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(

View File

@@ -496,9 +496,9 @@ def main() -> int:
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.storage_size > Memory.from_gb(10):
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.storage_size} on {n_nodes=}"
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
for tg in tg_list:

View File

@@ -53,285 +53,62 @@
marked.use({ renderer });
/**
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Storage for HTML snippets that need protection from markdown
const htmlSnippets: Map<string, string> = new Map();
let htmlCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Reset storage
mathExpressions.clear();
mathCounter = 0;
htmlSnippets.clear();
htmlCounter = 0;
// Protect code blocks first
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
return `<<CODE_${codeBlocks.length - 1}>>`;
});
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
htmlCounter++;
return placeholder;
});
// Strip center environment (layout only, no content change)
processed = processed.replace(/\\begin\{center\}/g, '');
processed = processed.replace(/\\end\{center\}/g, '');
// Strip other layout environments
processed = processed.replace(/\\begin\{flushleft\}/g, '');
processed = processed.replace(/\\end\{flushleft\}/g, '');
processed = processed.replace(/\\begin\{flushright\}/g, '');
processed = processed.replace(/\\end\{flushright\}/g, '');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
for (const env of mathEnvs) {
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
processed = processed.replace(wrappedRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
// Handle bare \begin{env}...\end{env} (without dollar signs)
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(bareRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
(_, content) => {
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
}
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(envRegex, (_, content) => {
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
});
}
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<u>${content}</u>`);
htmlCounter++;
return placeholder;
});
// Handle LaTeX line breaks and spacing
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
processed = processed.replace(/\\newline/g, '\n');
processed = processed.replace(/\\par\b/g, '\n\n');
processed = processed.replace(/\\quad/g, ' ');
processed = processed.replace(/\\qquad/g, ' ');
processed = processed.replace(/~~/g, ' '); // non-breaking space
// Remove other common LaTeX commands that don't render
processed = processed.replace(/\\centering/g, '');
processed = processed.replace(/\\noindent/g, '');
processed = processed.replace(/\\hfill/g, '');
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Restore code blocks
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
// Clean up any remaining stray backslashes from unrecognized commands
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX and restore HTML placeholders
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
return html;
}
@@ -377,50 +154,16 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -681,290 +424,28 @@
color: #60a5fa;
}
/* KaTeX math styling - Base */
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
.markdown-content :global(.katex-display) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.25rem 0.5rem;
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}
.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}
.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

@@ -99,36 +99,20 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}
// New connection edge types from Python SocketConnection/RDMAConnection
interface RawSocketConnection {
sinkMultiaddr?: {
address?: string;
// Multiaddr uses snake_case (no camelCase alias)
ip_address?: string;
ipAddress?: string; // fallback in case it changes
address_type?: string;
port?: number;
};
interface RawTopologyConnection {
localNodeId: string;
sendBackNodeId: string;
sendBackMultiaddr?:
| { multiaddr?: string; address?: string; ip_address?: string }
| string;
}
interface RawRDMAConnection {
sourceRdmaIface?: string;
sinkRdmaIface?: string;
}
type RawConnectionEdge = RawSocketConnection | RawRDMAConnection;
// New nested mapping format: { source: { sink: [edge1, edge2, ...] } }
type RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
// New nested mapping format
connections?: RawConnectionsMap;
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -229,18 +213,9 @@ function transformTopology(
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -289,7 +264,7 @@ function transformTopology(
}
}
nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
@@ -317,34 +292,29 @@ function transformTopology(
};
}
// Handle connections - nested mapping format { source: { sink: [edges] } }
const connections = raw.connections;
if (connections && typeof connections === "object") {
for (const [source, sinks] of Object.entries(connections)) {
if (!sinks || typeof sinks !== "object") continue;
for (const [sink, edgeList] of Object.entries(sinks)) {
if (!Array.isArray(edgeList)) continue;
for (const edge of edgeList) {
// Extract IP from SocketConnection (uses snake_case: ip_address)
let sendBackIp: string | undefined;
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
const multiaddr = edge.sinkMultiaddr;
if (multiaddr) {
// Try both snake_case (actual) and camelCase (in case it changes)
sendBackIp =
multiaddr.ip_address ||
multiaddr.ipAddress ||
extractIpFromMultiaddr(multiaddr.address);
}
}
// RDMAConnection (sourceRdmaIface/sinkRdmaIface) has no IP - edge just shows connection exists
for (const conn of raw.connections || []) {
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
if (nodes[source] && nodes[sink] && source !== sink) {
edges.push({ source, target: sink, sendBackIp });
}
}
let sendBackIp: string | undefined;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === "string") {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
sendBackIp =
multi.ip_address ||
extractIpFromMultiaddr(multi.multiaddr) ||
extractIpFromMultiaddr(multi.address);
}
}
edges.push({
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp,
});
}
return { nodes, edges };

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -915,7 +915,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 187 KiB

View File

@@ -126,6 +126,3 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -19,8 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,6 +59,7 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -86,12 +87,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
return model_card.metadata
else:
return await get_model_card(model_id)
return await get_model_meta(model_id)
class API:
@@ -154,19 +155,18 @@ class API:
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
@@ -196,7 +196,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_meta=await resolve_model_meta(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -206,15 +206,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=command.model_card,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -231,7 +231,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=model_card,
model_meta=model_meta,
)
async def get_placement(
@@ -241,17 +241,16 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_meta = await resolve_model_meta(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -278,7 +277,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -296,32 +295,32 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for card in cards:
model_meta = card.metadata
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -332,17 +331,17 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -351,7 +350,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_card.storage_size.in_bytes
total_bytes = model_meta.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -359,14 +358,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -374,7 +373,7 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -549,8 +548,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -576,8 +575,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -600,8 +599,9 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -610,13 +610,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.model_id,
id=card.short_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]

View File

@@ -27,7 +27,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -159,7 +158,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -202,7 +200,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
@@ -237,8 +237,6 @@ class Master:
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -6,25 +6,23 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,32 +52,37 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_card.storage_size
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_card.supports_tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -89,38 +92,44 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
if any(topology.node_is_leaf(node.node_id) for node in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_card, selected_cycle, command.sharding, node_profiles
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -128,20 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle.node_ids[0],
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
@@ -150,7 +158,6 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -1,13 +1,15 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -17,113 +19,67 @@ from exo.shared.types.worker.shards import (
)
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[Cycle],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
start=Memory(),
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cycle)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_profiles[node_id].memory.ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_profiles[node_id].memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -132,11 +88,11 @@ def get_shard_assignments_for_pipeline_parallel(
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -145,17 +101,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
for i, node in enumerate(selected_cycle):
shard = TensorShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -166,10 +122,10 @@ def get_shard_assignments_for_tensor_parallel(
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -178,22 +134,22 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
sharding: Sharding,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_card=model_card,
cycle=cycle,
node_profiles=node_profiles,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_card=model_card,
cycle=cycle,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
@@ -208,40 +164,38 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
get_thunderbolt = True
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -260,37 +214,72 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str, node_profile: NodePerformanceProfile
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_profile.network_interfaces:
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -298,10 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -312,12 +298,9 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
for ip, _ in ips
}
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
@@ -341,10 +324,9 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -359,13 +341,14 @@ def get_mlx_ring_hosts_by_node(
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node_id in enumerate(selected_cycle):
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node_id in enumerate(selected_cycle):
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
@@ -375,12 +358,10 @@ def get_mlx_ring_hosts_by_node(
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_profiles
)
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -394,34 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
if ip is not None:
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n: f"{get_ip_for_node(n)}:{coordinator_port}"
for n in cycle_digraph.list_nodes()
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}

View File

@@ -1,39 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -20,12 +19,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -109,8 +118,9 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -153,7 +163,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
@@ -166,8 +176,9 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -1,23 +1,20 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_node_profile,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -29,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -42,20 +44,21 @@ def instance() -> Instance:
@pytest.fixture
def model_card() -> ModelCard:
return ModelCard(
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -67,70 +70,47 @@ def place_instance_command(model_card: ModelCard) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 468, 1092), 12, (2, 3, 7)),
((312, 518, 1024), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_card)
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
)
conn_a_c = Connection(
source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
)
conn_b_a = Connection(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(conn_b_a)
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_card.model_id
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -150,21 +130,23 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -175,21 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -200,15 +184,17 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -216,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -261,169 +247,217 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# Model requires more than any single node but fits within a 3-node cycle
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology (directed)
topology.add_connection(
Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
cic = place_instance_command(model_card=model_card)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(
node_id_c,
node_id_d,
)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="10.0.0.1",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
ip_address="192.168.1.100",
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
# RDMA connections (directed)
topology.add_connection(
Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
)
topology.add_connection(
Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
)
# Ethernet connections (directed)
topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
min_nodes=1,
)
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
@@ -435,5 +469,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
if node_id == assigned_nodes[0]:
assert coordinator.startswith("0.0.0.0:")
else:
# Non-rank-0 nodes should have valid IP addresses (can be link-local)
ip_part = coordinator.split(":")[0]
# Just verify it's a valid IP format
assert len(ip_part.split(".")) == 4

View File

@@ -1,187 +1,162 @@
from copy import copy
from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = [c for c in topology.get_cycles() if len(c) != 1]
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n for n in filtered_cycles[0]) == {
assert set(n.node_id for n in filtered_cycles[0]) == {
node_a_id,
node_b_id,
node_c_id,
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(
@@ -190,12 +165,12 @@ def test_get_smallest_cycles():
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -205,61 +180,44 @@ def test_get_shard_assignments(
node_b_id = NodeId()
node_c_id = NodeId()
# create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
)
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# pick the 3-node cycle deterministically (cycle ordering can vary)
selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
assert (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -270,37 +228,30 @@ def test_get_shard_assignments(
- shard_assignments.runner_to_shard[runner_id_b].start_layer
== expected_layers[1]
)
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -308,78 +259,95 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
)
conn_a_c = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
npp = NodePerformanceProfile(
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=0,
ram_available=0,
swap_total=0,
swap_available=0,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
npp_a = copy(npp)
npp_a.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
npp_b = copy(npp)
npp_b.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
npp_c = copy(npp)
npp_c.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
node_profiles = {
node_a_id: npp_a,
node_b_id: npp_b,
node_c_id: npp_c,
}
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
@@ -388,12 +356,11 @@ def test_get_mlx_jaccl_coordinators():
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_profiles=node_profiles,
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -414,127 +381,19 @@ def test_get_mlx_jaccl_coordinators():
f"Coordinator for {node_id} should use port 5000"
)
# Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
# Rank 0 (node_a) treats this as the listen socket so should listen on all
# IPs
assert coordinators[node_a_id].startswith("0.0.0.0:"), (
"Rank 0 node should use 0.0.0.0 as coordinator listen address"
"Rank 0 node should use localhost as coordinator"
)
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert isinstance(conn_b_a.edge, SocketConnection)
assert (
coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert isinstance(conn_c_a.edge, SocketConnection)
assert coordinators[node_c_id] == (
f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises():
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node_profile(900 * 1024)
node_b = create_node_profile(50 * 1024)
node_c = create_node_profile(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
model_card = ModelCard(
model_id=ModelId("test-model"),
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_card, selected_cycle, Sharding.Pipeline, profiles)

View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def socket_connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,91 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, socket_connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
connection = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(conn)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert set(node for node in nodes) == set([node_a, node_b])
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -11,8 +11,10 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacThunderboltConnections,
MacThunderboltIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,68 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacThunderboltIdentifiers():
profile.tb_interfaces = info.idents
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
Connection(
source=event.node_id,
sink=conn_map[tb_conn.sink_uuid][0],
edge=RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.conn)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.conn)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

View File

@@ -11,6 +11,9 @@ class InterceptLogger(HypercornLogger):
def __init__(self, config: Config):
super().__init__(config)
assert self.error_logger
# TODO: Decide if we want to provide access logs
# assert self.access_logger
# self.access_logger.handlers = [_InterceptHandler()]
self.error_logger.handlers = [_InterceptHandler()]

View File

@@ -1,281 +1,552 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
name: str
description: str
tags: list[str]
metadata: ModelMetadata
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
None,
)
return ModelCard(
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -43,4 +43,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -12,11 +12,9 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
source=node_a,
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
),
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
@@ -25,11 +23,5 @@ def test_state_serialization_roundtrip() -> None:
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert (
state.topology.to_snapshot().nodes
== restored_state.topology.to_snapshot().nodes
)
assert set(state.topology.to_snapshot().connections) == set(
restored_state.topology.to_snapshot().connections
)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,227 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import (
Connection,
Cycle,
RDMAConnection,
SocketConnection,
)
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.map_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for edge in snapshot.connections[source][sink]:
topology.add_connection(
Connection(source=source, sink=sink, edge=edge)
)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
for source, sink, edge in self._graph.out_edges(
self._vertex_indices[node_id]
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def add_connection(self, conn: Connection) -> None:
source, sink, edge = conn.source, conn.sink, conn.edge
del conn
if edge in self.get_all_connections_between(source, sink):
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, edge)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[Connection]:
return (
(
Connection(
source=self._graph[src_id],
sink=self._graph[sink_id],
edge=connection,
)
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_rdma_connections(
self, source: NodeId, new_connections: Sequence[Connection]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for conn in new_connections:
self.add_connection(conn)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices
or conn.sink not in self._vertex_indices
):
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[Cycle]:
"""Get simple cycles in the graph, including singleton cycles"""
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
for node_id in self.list_nodes():
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if connection.source in node_ids and connection.sink in node_ids:
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_card: ModelCard
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_card: ModelCard
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int

View File

@@ -16,9 +16,7 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
return core_schema.str_schema()
class NodeId(Id):

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -97,11 +107,11 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
conn: Connection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
conn: Connection
edge: Connection
Event = (
@@ -115,8 +125,10 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -0,0 +1,18 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,12 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[ThunderboltIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float

View File

@@ -1,81 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class ThunderboltConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class ThunderboltConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[_ConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: _ReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces # doesn't need to be an assertion but im confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
)
def conn(self) -> ThunderboltConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return ThunderboltConnection(
source_uuid=self.domain_uuid_key, sink_uuid=sink_key
)
class ThunderboltConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[ThunderboltConnectivityData] = []
@classmethod
async def gather(cls) -> list[ThunderboltConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return ThunderboltConnectivity.model_validate_json(
proc.stdout
).SPThunderboltDataType

View File

@@ -1,41 +1,37 @@
from collections.abc import Iterator
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
@dataclass(frozen=True)
class Cycle:
node_ids: list[NodeId]
def __len__(self) -> int:
return self.node_ids.__len__()
def __iter__(self) -> Iterator[NodeId]:
return self.node_ids.__iter__()
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId
sink: NodeId
edge: RDMAConnection | SocketConnection
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_card: ModelCard
model_meta: ModelMetadata
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.model_meta.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -1,235 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
ThunderboltConnectivity,
ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacThunderboltIdentifiers(TaggedModel):
idents: Sequence[ThunderboltIdentifier]
class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if IS_DARWIN:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

View File

@@ -1,24 +0,0 @@
import sys
import pytest
from exo.shared.types.thunderbolt import (
ThunderboltConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
_gather_iface_map, # pyright: ignore[reportPrivateUsage]
)
@pytest.mark.anyio
@pytest.mark.skipif(
sys.platform != "darwin", reason="Thunderbolt info can only be gathered on macos"
)
async def test_tb_parsing():
data = await ThunderboltConnectivity.gather()
ifaces = await _gather_iface_map()
assert ifaces
assert data
for datum in data:
datum.ident(ifaces)
datum.conn()

View File

@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -5,7 +5,6 @@ import shutil
import ssl
import time
import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
@@ -246,15 +245,12 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
@@ -460,10 +456,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_card.model_id))
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -526,24 +522,24 @@ async def download_progress_for_local_path(
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], None],
max_parallel_downloads: int = 8,
skip_download: bool = False,
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
str(shard.model_meta.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +548,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
str(shard.model_meta.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -567,9 +563,9 @@ async def download_shard(
)
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(
def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
):
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
@@ -592,7 +588,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -605,11 +601,11 @@ async def download_shard(
else "in_progress",
start_time=start_time,
)
await on_progress(
on_progress(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
@@ -619,7 +615,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -633,21 +629,14 @@ async def download_shard(
semaphore = asyncio.Semaphore(max_parallel_downloads)
def schedule_progress(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
asyncio.create_task(
on_progress_wrapper(file, curr_bytes, total_bytes, is_renamed)
)
async def download_with_semaphore(file: FileListEntry) -> None:
async def download_with_semaphore(file: FileListEntry):
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file.path,
target_dir,
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
lambda curr_bytes, total_bytes, is_renamed: on_progress_wrapper(
file, curr_bytes, total_bytes, is_renamed
),
)
@@ -657,9 +646,9 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
return target_dir / gguf.path, final_repo_progress
else:

View File

@@ -1,10 +1,9 @@
import asyncio
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -20,21 +19,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
model_meta = await get_model_meta(model_id)
return PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
model_meta=base_shard.model_meta,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -49,8 +48,7 @@ class SingletonShardDownloader(ShardDownloader):
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.shard_downloader.on_progress(callback)
@@ -85,19 +83,18 @@ class CachedShardDownloader(ShardDownloader):
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.shard_downloader.on_progress(callback)
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_card.model_id, shard)] = target_dir
self.cache[(shard.model_meta.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(
@@ -116,18 +113,17 @@ class ResumableShardDownloader(ShardDownloader):
def __init__(self, max_parallel_downloads: int = 8):
self.max_parallel_downloads = max_parallel_downloads
self.on_progress_callbacks: list[
Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]
Callable[[ShardMetadata, RepoDownloadProgress], None]
] = []
async def on_progress_wrapper(
def on_progress_wrapper(
self, shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
for callback in self.on_progress_callbacks:
await callback(shard, progress)
callback(shard, progress)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.on_progress_callbacks.append(callback)

View File

@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -32,8 +31,7 @@ class ShardDownloader(ABC):
@abstractmethod
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
pass
@@ -61,8 +59,7 @@ class NoopShardDownloader(ShardDownloader):
return Path("/tmp/noop_shard")
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
pass
@@ -86,8 +83,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -46,11 +46,9 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
self.original_layer: _LayerCallable = original_layer
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -60,7 +58,7 @@ class CustomMlxLayer(nn.Module):
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return getattr(original_layer, name)
return object.__getattribute__(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
@@ -170,21 +168,11 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
# We can assume the model has at least one layer thanks to placement.
# If a layer type doesn't exist, we can set it to 0.
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = (
0
if "full_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)
@@ -214,9 +202,9 @@ def tensor_auto_parallel(
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
# if path.endswith("bias"):
# logger.info(f"Sharding bias for {path} - all to sharded")
# return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
@@ -228,10 +216,10 @@ def tensor_auto_parallel(
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
# if path.endswith("bias"):
# logger.info(f"Sharding bias for {path} - sharded to all")
# weight /= n
# return None
return -1, segments
sharded_to_all_linear_in_place = partial(

View File

@@ -75,7 +75,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_card.storage_size.in_kb
* model_shard_meta.model_meta.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -186,26 +186,20 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)
with open(coordination_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
@@ -241,7 +235,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -269,7 +263,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -305,9 +299,8 @@ def shard_and_load(
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
timeout_seconds = 60 + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
@@ -328,7 +321,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:

View File

@@ -8,7 +8,6 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -17,13 +16,16 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -32,7 +34,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -43,14 +45,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -84,7 +86,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -96,13 +98,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +142,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -146,6 +161,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -186,11 +202,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +221,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -236,7 +252,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -253,28 +270,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
async def _nack_request(self, since_idx: int) -> None:
@@ -323,6 +336,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -339,13 +353,14 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
# TODO: i hate callbacks
def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
@@ -356,11 +371,12 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
self.event_sender.send_nowait(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
@@ -376,13 +392,14 @@ class Worker:
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -403,14 +420,9 @@ class Worker:
async def _poll_connection_updates(self):
while True:
edges = set(
conn.edge for conn in self.state.topology.out_edges(self.node_id)
)
conns = await check_reachable(
self.state.topology,
self.node_id,
self.state.node_profiles,
)
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
@@ -418,33 +430,26 @@ class Worker:
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
conn=Connection(
source=self.node_id, sink=nid, edge=edge
)
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
not in conns.get(conn.sink, set())
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
@@ -478,7 +483,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -114,7 +114,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -22,7 +22,7 @@ def entrypoint(
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

View File

@@ -1,6 +1,8 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,6 +15,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -20,6 +23,7 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -48,6 +52,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -57,6 +62,33 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -67,7 +99,6 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -149,7 +180,7 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -170,16 +201,20 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model
assert tokenizer
assert task_params.messages[0].content is not None
try:
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
)
@@ -193,13 +228,13 @@ def main(
for response in mlx_generator:
match response:
case GenerationResponse():
if device_rank == 0:
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_card.model_id,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -208,24 +243,6 @@ def main(
)
)
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,8 +32,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -1,202 +0,0 @@
# type: ignore
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
class MockLayer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
self.use_sliding = True
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x * 2
@dataclass(frozen=True)
class PipelineTestConfig:
model_path: Path
total_layers: int
base_port: int
max_tokens: int
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
import json
import tempfile
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
return hostfile_path, hosts
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
total_layers=24,
base_port=29600,
max_tokens=200,
)
def run_gpt_oss_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
layer_splits: list[tuple[int, int]],
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 200,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
# Generate a prompt of exact token length
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
# Build prompt with approximate target length
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
# Truncate to exact target length
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
supports_tensor=False,
),
device_rank=rank,
world_size=world_size,
start_layer=start_layer,
end_layer=end_layer,
n_layers=24,
)
model = pipeline_auto_parallel(model, group, shard_meta)
# Barrier before generation
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
def run_gpt_oss_tensor_parallel_device(
rank: int,
world_size: int,
hostfile_path: str,
model_path: Path,
prompt_tokens: int,
prefill_step_size: int,
result_queue: Any, # pyright: ignore[reportAny]
max_tokens: int = 10,
) -> None:
import os
import traceback
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
from mlx_lm import load, stream_generate
from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
model, tokenizer = load(str(model_path))
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
base_len = len(base_tokens)
repeats = (prompt_tokens // base_len) + 2
long_text = base_text * repeats
tokens = tokenizer.encode(long_text)
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
model = tensor_auto_parallel(model, group)
barrier = mlx_core.distributed.all_sum(mlx_core.array([1.0]), group=group)
mlx_core.eval(barrier)
generated_text = ""
for response in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=formatted_prompt,
max_tokens=max_tokens,
prefill_step_size=prefill_step_size,
):
generated_text += response.text
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]

View File

@@ -1,137 +0,0 @@
import multiprocessing as mp
from typing import Any
import mlx.core as mx
import pytest
from exo.worker.engines.mlx.auto_parallel import (
CustomMlxLayer,
PipelineFirstLayer,
PipelineLastLayer,
)
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
def run_pipeline_device(
rank: int,
world_size: int,
hostfile_path: str,
result_queue: Any, # pyright: ignore[reportAny]
) -> None:
import os
os.environ["MLX_HOSTFILE"] = hostfile_path
os.environ["MLX_RANK"] = str(rank)
import mlx.core as mlx_core
import mlx.nn as mlx_nn
class MockLayerInner(mlx_nn.Module):
def __init__(self) -> None:
super().__init__()
self.custom_attr = "test_value"
def __call__(
self, x: mlx_core.array, *args: object, **kwargs: object
) -> mlx_core.array:
return x * 2
try:
group = mlx_core.distributed.init(backend="ring", strict=True)
mock = MockLayerInner()
first = PipelineFirstLayer(mock, r=rank, group=group)
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
x = mlx_core.ones((1, 4))
result = composed(x)
mlx_core.eval(result)
success = result.shape == x.shape
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
except Exception as e:
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
def test_single_wrapper_delegates_attributes() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
assert wrapped.use_sliding is True # type: ignore[attr-defined]
def test_composed_wrappers_delegate_attributes() -> None:
mock = MockLayer()
group = mx.distributed.init()
first = PipelineFirstLayer(mock, r=0, group=group)
composed = PipelineLastLayer(first, r=0, s=1, group=group)
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
assert composed.use_sliding is True # type: ignore[attr-defined]
def test_missing_attribute_raises() -> None:
mock = MockLayer()
wrapped = CustomMlxLayer(mock)
with pytest.raises(AttributeError):
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
def test_composed_call_works() -> None:
import json
import os
import tempfile
ctx = mp.get_context("spawn")
world_size = 2
base_port = 29500
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(hosts, f)
hostfile_path = f.name
try:
result_queue: Any = ctx.Queue()
processes: list[Any] = []
for rank in range(world_size):
p = ctx.Process(
target=run_pipeline_device,
args=(rank, world_size, hostfile_path, result_queue),
)
p.start()
processes.append(p)
for p in processes: # pyright: ignore[reportAny]
p.join(timeout=10) # pyright: ignore[reportAny]
results: dict[int, Any] = {}
errors: dict[int, str] = {}
while not result_queue.empty(): # pyright: ignore[reportAny]
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
if success:
results[rank] = value
else:
errors[rank] = value
assert len(results) == world_size, (
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
)
for rank in range(world_size):
assert rank in results, (
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
)
result_array = results[rank]
# Both devices see the final result (4.0) after all_gather
assert (result_array == 4.0).all(), (
f"Device {rank}: expected 4.0, got {result_array}"
)
finally:
os.unlink(hostfile_path)

View File

@@ -76,13 +76,13 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = (short_id, card)
return list(families.values())

View File

@@ -1,7 +1,7 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -0,0 +1,50 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -121,21 +121,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
class EventCollector:
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -145,20 +130,22 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
event_sender = EventCollector()
event_sender, event_receiver = mp_channel[Event]()
with task_sender:
with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_sender.events
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):

View File

@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -0,0 +1,103 @@
import platform
import shutil
from subprocess import CalledProcessError
from typing import cast
from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError
class MacMonError(Exception):
"""Exception raised for errors in the MacMon functions."""
def _get_binary_path() -> str:
"""
Get the path to the macmon binary.
Raises:
MacMonError: If the binary doesn't exist or can't be made executable.
"""
# Check for macOS with ARM chip
system = platform.system().lower()
machine = platform.machine().lower()
if system != "darwin" or not (
"arm" in machine or "m1" in machine or "m2" in machine
):
raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")
path = shutil.which("macmon")
if path is None:
raise MacMonError("MacMon not found in PATH")
return path
class TempMetrics(BaseModel):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
model_config = ConfigDict(extra="ignore")
class Metrics(BaseModel):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
all_power: float
ane_power: float
cpu_power: float
ecpu_usage: tuple[int, float]
gpu_power: float
gpu_ram_power: float
gpu_usage: tuple[int, float]
pcpu_usage: tuple[int, float]
ram_power: float
sys_power: float
temp: TempMetrics
timestamp: str
model_config = ConfigDict(extra="ignore")
async def get_metrics_async() -> Metrics:
"""
Asynchronously run the binary and return the metrics as a Python dictionary.
Args:
binary_path: Optional path to the binary. If not provided, will use the bundled binary.
Returns:
A mapping containing system metrics.
Raises:
MacMonError: If there's an error running the binary.
"""
path = _get_binary_path()
try:
# TODO: Keep Macmon running in the background?
result = await run_process([path, "pipe", "-s", "1"])
return Metrics.model_validate_json(result.stdout.decode().strip())
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
raise MacMonError(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
) from e

View File

@@ -1,5 +1,3 @@
from collections.abc import Mapping
import anyio
import httpx
from anyio import create_task_group
@@ -7,7 +5,6 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
REACHABILITY_ATTEMPTS = 3
@@ -21,9 +18,8 @@ async def check_reachability(
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
url = f"http://[{target_ip}]:52415/node_id"
else:
url = f"http://{target_ip}:52415/node_id"
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
@@ -77,9 +73,7 @@ async def check_reachability(
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
@@ -97,16 +91,16 @@ async def check_reachable(
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_profiles:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node_id == self_node_id:
if node.node_id == self_node_id:
continue
for iface in node_profiles[node_id].network_interfaces:
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node_id,
node.node_id,
reachable,
client,
)

View File

@@ -0,0 +1,114 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)

View File

@@ -0,0 +1,77 @@
"""Tests for macmon error handling.
These tests verify that MacMon errors are handled gracefully without
crashing the application or spamming logs.
"""
import platform
from subprocess import CalledProcessError
from unittest.mock import AsyncMock, patch
import pytest
from exo.worker.utils.macmon import MacMonError, get_metrics_async
@pytest.mark.skipif(
platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
reason="MacMon only supports macOS with Apple Silicon",
)
class TestMacMonErrorHandling:
"""Test MacMon error handling."""
async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
"""CalledProcessError should be wrapped as MacMonError."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=b"some error message",
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "some error message" in str(exc_info.value)
async def test_called_process_error_with_no_stderr(self) -> None:
"""CalledProcessError with no stderr should be handled gracefully."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=None,
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "no stderr" in str(exc_info.value)
async def test_macmon_not_found_raises_macmon_error(self) -> None:
"""When macmon is not found in PATH, MacMonError should be raised."""
with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon not found in PATH" in str(exc_info.value)

View File

@@ -34,8 +34,7 @@ from exo.shared.types.worker.instances import (
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.channels import MpReceiver, MpSender, mp_channel
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
@@ -66,7 +65,6 @@ async def main():
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
@@ -78,15 +76,6 @@ async def main():
shutdown.set()
async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
@@ -135,7 +124,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +134,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=(card.n_layers // world_size) * i,
start_layer=(meta.n_layers // world_size) * i,
end_layer=min(
card.n_layers, (card.n_layers // world_size) * (i + 1)
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
),
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -220,16 +209,16 @@ async def jaccl_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
@@ -239,12 +228,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
)
for i in range(world_size)
},