3 Commits

Author SHA1 Message Date
Evan
ee218d1f4c format 2025-12-18 20:45:13 +00:00
Evan
39a4857189 starting downloads rewrite 2025-12-18 20:45:13 +00:00
Evan
4fe2e31fb9 migrate to anyio in most places 2025-12-18 20:45:13 +00:00
72 changed files with 1577 additions and 3174 deletions

View File

@@ -1,298 +0,0 @@
name: Build EXO macOS DMG
on:
push:
tags:
- "v*"
branches:
- "test-app"
jobs:
build-macos-app:
runs-on: "macos-26"
env:
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
AWS_REGION: ${{ secrets.AWS_REGION }}
EXO_BUILD_NUMBER: ${{ github.run_number }}
EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
steps:
# ============================================================
# Checkout and tag validation
# ============================================================
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Derive release version from tag
run: |
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
VERSION="0.0.0-alpha.0"
echo "IS_ALPHA=true" >> $GITHUB_ENV
else
VERSION="${GITHUB_REF_NAME#v}"
if [[ "$VERSION" == *-alpha* ]]; then
echo "IS_ALPHA=true" >> $GITHUB_ENV
else
echo "IS_ALPHA=false" >> $GITHUB_ENV
fi
fi
echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
- name: Ensure tag commit is on main
if: github.ref_type == 'tag'
run: |
git fetch origin main
# Alpha tags can be on any branch, production tags must be on main
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Alpha tag detected, skipping main branch check"
elif ! git merge-base --is-ancestor origin/main HEAD; then
echo "Production tag must point to a commit on main"
exit 1
fi
# ============================================================
# Install dependencies
# ============================================================
- name: Select Xcode 26.2
run: |
sudo xcode-select -s /Applications/Xcode_26.2.app
if ! xcrun -f metal >/dev/null 2>&1; then
echo "Metal toolchain is not installed."
exit 1
fi
- name: Install Homebrew packages
run: brew install just awscli macmon
- name: Install UV
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: uv.lock
- name: Setup Python
run: |
uv python install
uv sync --locked
- name: Build dashboard
run: |
cd dashboard
npm ci
npm run build
- name: Install Sparkle CLI
run: |
CLI_URL="${SPARKLE_CLI_URL:-https://github.com/sparkle-project/Sparkle/releases/download/${SPARKLE_VERSION}/Sparkle-${SPARKLE_VERSION}.tar.xz}"
echo "Downloading Sparkle CLI from: $CLI_URL"
mkdir -p /tmp/sparkle
curl --fail --location --output /tmp/sparkle.tar.xz "$CLI_URL"
tar -xJf /tmp/sparkle.tar.xz -C /tmp/sparkle --strip-components=1
echo "SPARKLE_BIN=/tmp/sparkle/bin" >> $GITHUB_ENV
- name: Prepare code-signing keychain
env:
MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}
MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
PROVISIONING_PROFILE: ${{ secrets.PROVISIONING_PROFILE }}
run: |
KEYCHAIN_PATH="$HOME/Library/Keychains/build.keychain-db"
# Create fresh keychain
security create-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
# Disable auto-lock (no timeout, no lock-on-sleep)
security set-keychain-settings "$KEYCHAIN_PATH"
# Add to search list while preserving existing keychains
security list-keychains -d user -s "$KEYCHAIN_PATH" $(security list-keychains -d user | tr -d '"')
# Set as default and unlock
security default-keychain -s "$KEYCHAIN_PATH"
security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
# Import certificate with full access for codesign
echo "$MACOS_CERTIFICATE" | base64 --decode > /tmp/cert.p12
security import /tmp/cert.p12 -k "$KEYCHAIN_PATH" -P "$MACOS_CERTIFICATE_PASSWORD" \
-T /usr/bin/codesign -T /usr/bin/security -T /usr/bin/productbuild
rm /tmp/cert.p12
# Allow codesign to access the key without prompting
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
# Verify keychain is unlocked and identity is available
echo "Verifying signing identity..."
security find-identity -v -p codesigning "$KEYCHAIN_PATH"
# Setup provisioning profile
mkdir -p "$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles"
echo "$PROVISIONING_PROFILE" | base64 --decode > "$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles/EXO.provisionprofile"
# Export keychain path for other steps
echo "BUILD_KEYCHAIN_PATH=$KEYCHAIN_PATH" >> $GITHUB_ENV
# ============================================================
# Build the bundle
# ============================================================
- name: Build PyInstaller bundle
run: uv run pyinstaller packaging/pyinstaller/exo.spec
- name: Build Swift app
env:
MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
run: |
cd app/EXO
security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
xcodebuild clean build \
-scheme EXO \
-configuration Release \
-derivedDataPath build \
MARKETING_VERSION="$RELEASE_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
EXO_BUILD_TAG="$RELEASE_VERSION" \
EXO_BUILD_COMMIT="$GITHUB_SHA" \
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
mkdir -p ../../output
cp -R build/Build/Products/Release/EXO.app ../../output/EXO.app
- name: Inject PyInstaller runtime
run: |
rm -rf output/EXO.app/Contents/Resources/exo
mkdir -p output/EXO.app/Contents/Resources
cp -R dist/exo output/EXO.app/Contents/Resources/exo
- name: Codesign PyInstaller runtime
env:
MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
run: |
cd output
security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
RUNTIME_DIR="EXO.app/Contents/Resources/exo"
find "$RUNTIME_DIR" -type f \( -perm -111 -o -name "*.dylib" -o -name "*.so" \) -print0 |
while IFS= read -r -d '' file; do
/usr/bin/codesign --force --timestamp --options runtime \
--sign "$SIGNING_IDENTITY" "$file"
done
- name: Sign, notarize, and create DMG
env:
MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
APPLE_NOTARIZATION_USERNAME: ${{ secrets.APPLE_NOTARIZATION_USERNAME }}
APPLE_NOTARIZATION_PASSWORD: ${{ secrets.APPLE_NOTARIZATION_PASSWORD }}
APPLE_NOTARIZATION_TEAM: ${{ secrets.APPLE_NOTARIZATION_TEAM }}
run: |
cd output
security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
/usr/bin/codesign --deep --force --timestamp --options runtime \
--sign "$SIGNING_IDENTITY" EXO.app
mkdir -p dmg-root
cp -R EXO.app dmg-root/
ln -s /Applications dmg-root/Applications
DMG_NAME="EXO-${RELEASE_VERSION}.dmg"
hdiutil create -volname "EXO" -srcfolder dmg-root -ov -format UDZO "$DMG_NAME"
/usr/bin/codesign --force --timestamp --options runtime \
--sign "$SIGNING_IDENTITY" "$DMG_NAME"
if [[ -n "$APPLE_NOTARIZATION_USERNAME" ]]; then
SUBMISSION_OUTPUT=$(xcrun notarytool submit "$DMG_NAME" \
--apple-id "$APPLE_NOTARIZATION_USERNAME" \
--password "$APPLE_NOTARIZATION_PASSWORD" \
--team-id "$APPLE_NOTARIZATION_TEAM" \
--wait --timeout 15m 2>&1)
echo "$SUBMISSION_OUTPUT"
SUBMISSION_ID=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="id:" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')
STATUS=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="status:" {print $2; exit}')
if [[ -n "$SUBMISSION_ID" ]]; then
xcrun notarytool log "$SUBMISSION_ID" \
--apple-id "$APPLE_NOTARIZATION_USERNAME" \
--password "$APPLE_NOTARIZATION_PASSWORD" \
--team-id "$APPLE_NOTARIZATION_TEAM" > notarization-log.txt || true
echo "===== Notarization Log ====="
cat notarization-log.txt
echo "============================"
fi
if [[ "$STATUS" != "Accepted" ]]; then
echo "Notarization failed with status: ${STATUS:-Unknown}"
exit 1
fi
xcrun stapler staple "$DMG_NAME"
fi
- name: Generate Sparkle appcast
env:
SPARKLE_DOWNLOAD_PREFIX: ${{ env.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
IS_ALPHA: ${{ env.IS_ALPHA }}
run: |
set -euo pipefail
cd output
DOWNLOAD_PREFIX="${SPARKLE_DOWNLOAD_PREFIX:-https://assets.exolabs.net}"
echo "$SPARKLE_ED25519_PRIVATE" > sparkle_ed25519.key
chmod 600 sparkle_ed25519.key
CHANNEL_FLAG=""
if [[ "$IS_ALPHA" == "true" ]]; then
CHANNEL_FLAG="--channel alpha"
echo "Generating appcast for alpha channel"
fi
$SPARKLE_BIN/generate_appcast \
--ed-key-file sparkle_ed25519.key \
--download-url-prefix "$DOWNLOAD_PREFIX" \
$CHANNEL_FLAG \
.
# ============================================================
# Upload artifacts
# ============================================================
- name: Upload DMG
uses: actions/upload-artifact@v4
with:
name: EXO-dmg-${{ env.RELEASE_VERSION }}
path: output/EXO-${{ env.RELEASE_VERSION }}.dmg
- name: Upload to S3
if: env.SPARKLE_S3_BUCKET != '' && github.ref_type == 'tag'
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION: ${{ env.AWS_REGION }}
SPARKLE_S3_BUCKET: ${{ env.SPARKLE_S3_BUCKET }}
SPARKLE_S3_PREFIX: ${{ env.SPARKLE_S3_PREFIX }}
IS_ALPHA: ${{ env.IS_ALPHA }}
run: |
set -euo pipefail
cd output
PREFIX="${SPARKLE_S3_PREFIX:-}"
if [[ -n "$PREFIX" && "${PREFIX: -1}" != "/" ]]; then
PREFIX="${PREFIX}/"
fi
DMG_NAME="EXO-${RELEASE_VERSION}.dmg"
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
if [[ "$IS_ALPHA" != "true" ]]; then
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
fi
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

View File

@@ -5,21 +5,10 @@ Thank you for your interest in contributing to EXO!
## Getting Started
To run EXO from source:
**Prerequisites:**
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
```bash
brew install uv
```
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
```bash
brew install macmon
```
```bash
git clone https://github.com/exo-explore/exo.git
cd exo/dashboard
npm install && npm run build && cd ..
npm install && npm run build
uv run exo
```

214
README.md
View File

@@ -1,223 +1,55 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: light)" srcset="/docs/imgs/exo-logo-black-bg.jpg">
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
<source media="(prefers-color-scheme: light)" srcset="/docs/exo-logo-black-bg.jpg">
<img alt="exo logo" src="/docs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
[![GitHub Repo stars](https://img.shields.io/github/stars/exo-explore/exo)](https://github.com/exo-explore/exo/stargazers)
[![License: Apache-2.0](https://img.shields.io/badge/License-Apache2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---
exo connects all your devices into an AI cluster. Not only does exo enable running models larger than would fit on a single device, but with [day-0 support for RDMA over Thunderbolt](https://x.com/exolabs/status/2001817749744476256?s=20), makes models run faster as you add more devices.
## Features
- **Automatic Device Discovery**: Devices running exo automatically discover each other - no manual configuration.
- **RDMA over Thunderbolt**: exo ships with [day-0 support for RDMA over Thunderbolt 5](https://x.com/exolabs/status/2001817749744476256?s=20), enabling 99% reduction in latency between devices.
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
<details>
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
<details>
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
- **Automatic Device Discovery**: Devices running EXO automatically discover each other on your local network - no manual configuration.
- **RDMA over Thunderbolt**: EXO ships with Day-0 support for RDMA over Thunderbolt 5, enabling 99% reduction in latency between devices.
- **Auto Parallel**: EXO automatically splits up models to run distributed across devices.
- **Tensor Parallelism**: EXO supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: EXO uses [ml-explore/mlx](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
---
## Quick Start
Devices running exo automatically discover each other, without needing any manual configuration. Each device provides an API and a dashboard for interacting with your cluster (runs at `http://localhost:52415`).
You need at least one Mac device running macOS Tahoe 26.2 (released December 12th 2025).
There are two ways to run exo:
You can download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-latest.dmg). It will ask for permission to modify system settings and install a new Network profile. We hope to make this smoother in the future!
### Run from Source (Mac & Linux)
To run from source, clone the repo, build the dashboard with `cd dashboard && npm install && npm run build` and run `uv run exo`.
**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
Clone the repo, build the dashboard, and run exo:
```bash
# Clone exo
git clone https://github.com/exo-explore/exo
# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..
# Run exo
uv run exo
```
This starts the exo dashboard and API at http://localhost:52415/
### macOS App
exo ships a macOS app that runs in the background on your Mac.
<img src="docs/imgs/macos-app-one-macbook.png" alt="exo macOS App - running on a MacBook" width="35%" />
The macOS app requires macOS Tahoe 26.2 or later.
Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-latest.dmg).
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
After starting with either of these methods go to `http://localhost:52415` in your browser, and you'll have EXO.
---
### Using the API
## Requirements
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
- Mac devices with Apple Silicon (M-series chips)
- macOS Tahoe 26.2 or later (released December 12th 2025)
- Older macOS versions may work without RDMA, but only 26.2+ is officially supported
- For RDMA over Thunderbolt: a high quality Thunderbolt 5 cable
---
**1. Preview instance placements**
The `/instance/previews` endpoint will preview all valid placements for your model.
```bash
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
```
Sample response:
```json
{
"previews": [
{
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"sharding": "Pipeline",
"instance_meta": "MlxRing",
"instance": {...},
"memory_delta_by_node": {"local": 729808896},
"error": null
}
// ...possibly more placements...
]
}
```
This will return all valid placements for this model. Pick a placement that you like.
To pick the first one, pipe into `jq`:
```bash
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1
```
---
**2. Create a model instance**
Send a POST to `/instance` with your desired placement in the `instance` field (the full payload must match types as in `CreateInstanceParams`), which you can copy from step 1:
```bash
curl -X POST http://localhost:52415/instance \
-H 'Content-Type: application/json' \
-d '{
"instance": {...}
}'
```
Sample response:
```json
{
"message": "Command received.",
"command_id": "e9d1a8ab-...."
}
```
---
**3. Send a chat completion**
Now, make a POST to `/v1/chat/completions` (the same format as OpenAI's API):
```bash
curl -N -X POST http://localhost:52415/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"messages": [
{"role": "user", "content": "What is Llama 3.2 1B?"}
],
"stream": true
}'
```
---
**4. Delete the instance**
When you're done, delete the instance by its ID (find it via `/state` or `/instance` endpoints):
```bash
curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
```
**Other useful API endpoints*:**
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---
## Hardware Accelerator Support
On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
We intend to add support for other hardware platforms [like the DGX Spark](https://x.com/exolabs/status/1978525767739883736) in the future, but they are not currently supported. If you'd like support for a new hardware platform, please search for an existing feature request and add a thumbs up so we know what hardware is important to the community.
---
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to EXO.

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -212,7 +212,7 @@ struct ContentView: View {
private var dashboardButton: some View {
Button {
guard let url = URL(string: "http://localhost:52415/") else { return }
guard let url = URL(string: "http://localhost:8000/") else { return }
NSWorkspace.shared.open(url)
} label: {
HStack {

View File

@@ -35,7 +35,7 @@ struct BugReportService {
}
func sendReport(
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
baseURL: URL = URL(string: "http://127.0.0.1:8000")!,
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {

View File

@@ -15,7 +15,7 @@ final class ClusterStateService: ObservableObject {
private let endpoint: URL
init(
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
baseURL: URL = URL(string: "http://127.0.0.1:8000")!,
session: URLSession = .shared
) {
self.baseURL = baseURL

View File

@@ -861,7 +861,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -901,7 +900,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,7 +1516,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1528,7 +1525,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,7 +1937,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2612,7 +2607,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2800,7 +2794,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2945,7 +2938,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2967,7 +2959,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -96,7 +96,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -105,13 +105,9 @@ interface RawTopologyConnection {
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -202,17 +198,9 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === 'string' ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -250,7 +238,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
}
nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? 'Unknown',
chip: profile?.chipId,
@@ -272,34 +260,14 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === 'string') {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -308,8 +276,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
edges.push({
source: localNodeId,
target: sendBackNodeId,
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp
});
}

View File

@@ -1,64 +0,0 @@
# EXO Architecture overview
EXO uses an _Event Sourcing_ architecture, and Erlang-style _message passing_. To facilitate this, we've written a channel library extending anyio channels with inspiration from tokio::sync::mpsc.
Each logical module - designed to be functional independently of the others - communicates with the rest of the system by sending messages on topics.
## Systems
There are currently 5 major systems:
- Master
Executes placement and orders events through a single writer
- Worker
Schedules work on a node, gathers system information, etc.#
- Runner
Executes inference jobs (for now) in an isolated process from the worker for fault-tolerance.
- API
Runs a python webserver for exposing state and commands to client applications
- Election
Implements a distributed algorithm for master election in unstable networking conditions
## Topics
There are currently 5 topics:
- Commands
The API and Worker instruct the master when the event log isn't sufficient. Namely placement and catchup requests go through Commands atm.
- Local Events
All nodes write events here, the master reads those events and orders them
- Global Events
The master writes events here, all nodes read from this topic and fold the produced events into their `State`
- Election Messages
Before establishing a cluster, nodes communicate here to negotiate a master node.
- Connection Messages
The networking system write mdns-discovered hardware connections here.
## Event Sourcing
Lots has been written about event sourcing, but it lets us centralize faulty connections and message ACKing with the following model.
Whenever a device produces side effects, it captures those side effects in an `Event`. `Event`s are then "applied" to their model of `State`, which is globally distributed across the cluster. Whenever a command is received, it is combined with state to produce side effects, captured in yet more events. The rule of thumb is "`Event`s are past tense, `Command`s are imperative". Telling a node to perform some action like "place this model" or "Give me a copy of the event log" is represented by a command (The worker's `Task`s are also commands), while "this node is using 300GB of ram" is an event. Notably, `Event`s SHOULD never cause side effects on their own. There are a few exceptions to this, we're working out the specifics of generalizing the distributed event sourcing model to make it better suit our needs
## Purity
A significant goal of the current design is to make data flow explicit. Classes should either represent simple data (`CamelCaseModel`s typically, and `TaggedModel`s for unions) or active `System`s (Erlang `Actor`s), with all transformations of that data being "referentially transparent" - destructure and construct new data, don't mutate in place. We have had varying degrees of success with this, and are still exploring where purity makes sense.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 514 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 519 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 486 KiB

View File

Before

Width:  |  Height:  |  Size: 7.9 KiB

After

Width:  |  Height:  |  Size: 7.9 KiB

View File

Before

Width:  |  Height:  |  Size: 295 KiB

After

Width:  |  Height:  |  Size: 295 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 131 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 171 KiB

View File

@@ -91,10 +91,6 @@
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
@@ -104,11 +100,6 @@
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list

View File

@@ -1,118 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
import importlib.util
import shutil
from pathlib import Path
from PyInstaller.utils.hooks import collect_submodules
PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
raise SystemExit(f"Unable to locate Exo entrypoint: {ENTRYPOINT}")
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
block_cipher = None
def _module_directory(module_name: str) -> Path:
spec = importlib.util.find_spec(module_name)
if spec is None:
raise SystemExit(f"Module '{module_name}' is not available in the current environment.")
if spec.submodule_search_locations:
return Path(next(iter(spec.submodule_search_locations))).resolve()
if spec.origin:
return Path(spec.origin).resolve().parent
raise SystemExit(f"Unable to determine installation directory for '{module_name}'.")
MLX_PACKAGE_DIR = _module_directory("mlx")
MLX_LIB_DIR = MLX_PACKAGE_DIR / "lib"
if not MLX_LIB_DIR.is_dir():
raise SystemExit(f"mlx Metal libraries are missing: {MLX_LIB_DIR}")
def _safe_collect(package_name: str) -> list[str]:
try:
return collect_submodules(package_name)
except ImportError:
return []
HIDDEN_IMPORTS = sorted(
set(
collect_submodules("mlx")
+ _safe_collect("mlx_lm")
+ _safe_collect("transformers")
)
)
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]
MACMON_PATH = shutil.which("macmon")
if MACMON_PATH is None:
raise SystemExit(
"macmon binary not found in PATH. "
"Install it via: brew install macmon"
)
BINARIES: list[tuple[str, str]] = [
(MACMON_PATH, "."),
]
a = Analysis(
[str(ENTRYPOINT)],
pathex=[str(SOURCE_ROOT)],
binaries=BINARIES,
datas=DATAS,
hiddenimports=HIDDEN_IMPORTS,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name="exo",
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=False,
upx_exclude=[],
name="exo",
)

View File

@@ -5,31 +5,17 @@ description = "Exo"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"typeguard>=4.4.4",
"pydantic>=2.11.7",
"base58>=2.1.1",
"cryptography>=45.0.5",
"httpx>=0.28.1",
"fastapi>=0.116.1",
"filelock>=3.18.0",
"aiosqlite>=0.21.0",
"networkx>=3.5",
"protobuf>=6.32.0",
"rich>=14.1.0",
"rustworkx>=0.17.1",
"sqlmodel>=0.0.24",
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"psutil>=7.0.0",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx>=0.29.3",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
@@ -43,11 +29,9 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
"pytest-env",
"ruff>=0.11.13",
"trio>=0.32.0",
]
# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
@@ -124,12 +108,4 @@ extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindi
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
[tool.pytest.ini_options]
pythonpath = "."
asyncio_mode = "auto"
markers = [
"slow: marks tests as slow (deselected by default)"
]
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
anyio_mode = "auto"

View File

@@ -1,39 +1,4 @@
from __future__ import annotations
import sys
from collections.abc import Sequence
from multiprocessing import freeze_support
from typing import Final
from exo.main import main
INLINE_CODE_FLAG: Final[str] = "-c"
def _maybe_run_inline_code(argv: Sequence[str]) -> bool:
"""
Reproduce the bare minimum of Python's `-c` flag so multiprocessing
helper processes (for example the resource tracker) can execute.
"""
try:
flag_index = argv.index(INLINE_CODE_FLAG)
except ValueError:
return False
code_index = flag_index + 1
if code_index >= len(argv):
return False
inline_code = argv[code_index]
sys.argv = ["-c", *argv[code_index + 1 :]]
namespace: dict[str, object] = {"__name__": "__main__"}
exec(inline_code, namespace, namespace)
return True
if __name__ == "__main__":
if _maybe_run_inline_code(sys.argv):
sys.exit(0)
freeze_support()
main()

View File

@@ -207,7 +207,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -263,7 +262,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -428,8 +426,9 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available

View File

@@ -158,7 +158,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -201,7 +200,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,11 +6,10 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_shard_assignments,
get_smallest_cycles,
)
@@ -20,10 +19,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -52,16 +51,19 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
cycles = topology.get_cycles() + [[node] for node in all_nodes]
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
@@ -69,15 +71,13 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -86,7 +86,11 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node.node_profile.memory.ram_available for node in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
@@ -95,16 +99,14 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -112,19 +114,20 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
selected_cycle: list[NodeInfo],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
@@ -195,7 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_jaccl_devices_matrix(
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,7 +204,6 @@ def get_mlx_jaccl_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -215,55 +214,86 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def get_mlx_ibv_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX IBV (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}

View File

@@ -1,36 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
# TODO: this is a hack to get the port for the send_back_multiaddr
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -2,7 +2,6 @@ from datetime import datetime, timezone
from typing import Sequence
import anyio
import pytest
from loguru import logger
from exo.master.main import Master
@@ -19,13 +18,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -38,7 +39,6 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
@@ -81,14 +81,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -152,7 +159,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()

View File

@@ -1,3 +1,5 @@
from typing import Callable
import pytest
from loguru import logger
@@ -5,20 +7,14 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -73,33 +74,30 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
@@ -125,11 +123,12 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -138,7 +137,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
n_layers=10,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -149,11 +148,12 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -162,7 +162,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
n_layers=10,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -173,11 +173,12 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -188,7 +189,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -234,102 +235,190 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
model_meta.storage_size = Memory.from_bytes(1000)
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
logger.info(list(topology.list_connections()))
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_d))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
cic = place_instance_command(
model_meta=model_meta,
)
# act
placements = place_instance(cic, topology, {}, profiles)
# Act
placements = place_instance(cic, topology, {})
# assert
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -339,7 +428,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -347,10 +436,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
@@ -359,21 +448,21 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
assert len(instance.ibv_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -1,48 +1,56 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_mlx_ibv_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -176,25 +179,19 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -202,11 +199,7 @@ def test_get_shard_assignments(
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
)
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
cycles = topology.get_cycles()
selected_cycle = cycles[0]
# act
@@ -235,21 +228,28 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -257,47 +257,108 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_ibv_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
coordinators = get_mlx_ibv_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -326,11 +387,11 @@ def test_get_mlx_jaccl_coordinators():
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"

View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(conn for _, _, conn in topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(node_a, node_b, connection)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -15,7 +15,6 @@ def buffer() -> OrderedBuffer[Event]:
return OrderedBuffer[Event]()
@pytest.mark.asyncio
async def test_initial_state(buffer: OrderedBuffer[Event]):
"""Tests that a new buffer is empty and starts at index 1."""
assert buffer.next_idx_to_release == 0
@@ -23,7 +22,6 @@ async def test_initial_state(buffer: OrderedBuffer[Event]):
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
"""Tests ingesting and draining a simple, ordered sequence of events."""
events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)]
@@ -35,7 +33,6 @@ async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
assert not buffer.store
@pytest.mark.asyncio
async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
"""Tests that out-of-order events are buffered and drained in the correct sequence."""
event1 = make_indexed_event(0)
@@ -51,7 +48,6 @@ async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
"""Tests that draining stops when there is a gap in the event indices."""
event1 = make_indexed_event(0)
@@ -68,7 +64,6 @@ async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
assert 2 in buffer.store
@pytest.mark.asyncio
async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
"""Tests that once a gap is filled, the rest of the sequence is drained."""
event0 = make_indexed_event(0)
@@ -87,7 +82,6 @@ async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
"""Tests that if multiple events for the same index are ingested, the first one wins."""
event2_first = make_indexed_event(1)
@@ -106,7 +100,6 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
assert drained[1][1].event_id != event2_second[1].event_id
@pytest.mark.asyncio
async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
"""Tests that events with an index lower than next_idx_to_release are dropped."""
buffer.ingest(*make_indexed_event(0))
@@ -124,7 +117,6 @@ async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
"""Tests reusing the buffer after it has been fully drained."""
buffer.ingest(*make_indexed_event(0))

View File

@@ -11,8 +11,10 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacTBConnections,
MacTBIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacTBIdentifiers():
profile.tb_interfaces = info.idents
case MacTBConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -1,47 +1,35 @@
import os
import sys
from pathlib import Path
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
EXO_HOME_RELATIVE_PATH = os.environ.get("EXO_HOME", ".exo")
EXO_HOME = Path.home() / EXO_HOME_RELATIVE_PATH
EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR")
EXO_MODELS_DIR = Path(EXO_MODELS_DIR_ENV) if EXO_MODELS_DIR_ENV else EXO_HOME / "models"
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
EXO_GLOBAL_EVENT_DB = EXO_HOME / "global_events.db"
EXO_WORKER_EVENT_DB = EXO_HOME / "worker_events.db"
EXO_MASTER_STATE = EXO_HOME / "master_state.json"
EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
EXO_MASTER_LOG = EXO_HOME / "master.log"
EXO_WORKER_LOG = EXO_HOME / "worker.log"
EXO_LOG = EXO_HOME / "exo.log"
EXO_TEST_LOG = EXO_HOME / "exo_test.log"
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair"
if sys.platform != "linux":
return Path.home() / ".exo"
EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
xdg_value = os.environ.get(env_var, None)
if xdg_value is not None:
return Path(xdg_value) / "exo"
return Path.home() / fallback / "exo"
EXO_CONFIG_HOME = _get_xdg_dir("XDG_CONFIG_HOME", ".config")
EXO_DATA_HOME = _get_xdg_dir("XDG_DATA_HOME", ".local/share")
EXO_CACHE_HOME = _get_xdg_dir("XDG_CACHE_HOME", ".cache")
# Models directory (data)
_EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
EXO_MODELS_DIR = (
EXO_DATA_HOME / "models"
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
EXO_IPC_DIR = EXO_HOME / "ipc"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3
LB_MEMBW_GBPS = 68
LB_DISK_GBPS = 1.5

View File

@@ -24,8 +24,6 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -1,7 +1,5 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
@@ -50,7 +48,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
await target_dir.mkdir(parents=True, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
@@ -60,14 +58,14 @@ async def get_config_data(model_id: str) -> ConfigData:
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
async with await config_path.open("r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
await target_dir.mkdir(parents=True, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
@@ -77,7 +75,7 @@ async def get_safetensors_size(model_id: str) -> Memory:
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
async with await index_path.open("r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata

View File

@@ -1,8 +1,5 @@
"""Pytest configuration and shared fixtures for shared package tests."""
import asyncio
from typing import Generator
import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
@@ -12,21 +9,6 @@ from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""Create an event loop for the test session."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture(autouse=True)
def reset_event_loop():
"""Reset the event loop for each test to ensure clean state."""
# This ensures each test gets a fresh event loop state
def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
@@ -52,7 +34,7 @@ def caplog(caplog: LogCaptureFixture):
format="{message}",
level=0,
filter=lambda record: record["level"].no >= caplog.handler.level,
enqueue=True, # Set to 'True' if your test is spawning child processes.
enqueue=True,
)
yield caplog
logger.remove(handler_id)

View File

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)
assert new_state.downloads == {NodeId("node-1"): [event]}
assert new_state == State(downloads={NodeId("node-1"): [event]})
def test_apply_two_node_download_progress():
@@ -39,4 +39,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})

View File

@@ -46,7 +46,6 @@ def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
@pytest.mark.anyio
async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:
"""
Start a round by injecting an ElectionMessage with higher clock.
@@ -102,7 +101,6 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> No
assert election.seniority == 2
@pytest.mark.anyio
async def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:
"""
If a peer with clearly higher seniority participates in the round, they should win.
@@ -156,7 +154,6 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:
assert election.seniority == 0
@pytest.mark.anyio
async def test_ignores_older_messages() -> None:
"""
Messages with a lower clock than the current round are ignored by the receiver.
@@ -205,7 +202,6 @@ async def test_ignores_older_messages() -> None:
# Not asserting on the result; focus is on ignore behavior.
@pytest.mark.anyio
async def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:
"""
Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.
@@ -251,7 +247,6 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:
# Not asserting on who won; just that both rounds were broadcast.
@pytest.mark.anyio
async def test_promotion_new_seniority_counts_participants() -> None:
"""
When we win against two peers in the same round, our seniority becomes
@@ -300,7 +295,6 @@ async def test_promotion_new_seniority_counts_participants() -> None:
assert election.seniority == 3
@pytest.mark.anyio
async def test_connection_message_triggers_new_round_broadcast() -> None:
"""
A connection message increments the clock and starts a new campaign.
@@ -352,7 +346,6 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
# After cancellation (before election finishes), no seniority changes asserted here.
@pytest.mark.anyio
async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
"""
With equal seniority, the node that has seen more commands should win the election.

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -11,12 +11,14 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(node_a, node_b, connection)
state.topology.add_connection(connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)

View File

@@ -1,118 +0,0 @@
"""Tests for XDG Base Directory Specification compliance."""
import os
import sys
from pathlib import Path
from unittest import mock
def test_xdg_paths_on_linux():
"""Test that XDG paths are used on Linux when XDG env vars are set."""
with (
mock.patch.dict(
os.environ,
{
"XDG_CONFIG_HOME": "/tmp/test-config",
"XDG_DATA_HOME": "/tmp/test-data",
"XDG_CACHE_HOME": "/tmp/test-cache",
},
clear=False,
),
mock.patch.object(sys, "platform", "linux"),
):
# Re-import to pick up mocked values
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert Path("/tmp/test-config/exo") == constants.EXO_CONFIG_HOME
assert Path("/tmp/test-data/exo") == constants.EXO_DATA_HOME
assert Path("/tmp/test-cache/exo") == constants.EXO_CACHE_HOME
def test_xdg_default_paths_on_linux():
"""Test that XDG default paths are used on Linux when env vars are not set."""
# Remove XDG env vars and EXO_HOME
env = {
k: v
for k, v in os.environ.items()
if not k.startswith("XDG_") and k != "EXO_HOME"
}
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "linux"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".config" / "exo" == constants.EXO_CONFIG_HOME
assert home / ".local/share" / "exo" == constants.EXO_DATA_HOME
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
def test_legacy_exo_home_takes_precedence():
"""Test that EXO_HOME environment variable takes precedence for backward compatibility."""
with mock.patch.dict(
os.environ,
{
"EXO_HOME": ".custom-exo",
"XDG_CONFIG_HOME": "/tmp/test-config",
},
clear=False,
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".custom-exo" == constants.EXO_CONFIG_HOME
assert home / ".custom-exo" == constants.EXO_DATA_HOME
def test_macos_uses_traditional_paths():
"""Test that macOS uses traditional ~/.exo directory."""
# Remove EXO_HOME to ensure we test the default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "darwin"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".exo" == constants.EXO_CONFIG_HOME
assert home / ".exo" == constants.EXO_DATA_HOME
assert home / ".exo" == constants.EXO_CACHE_HOME
def test_node_id_in_config_dir():
"""Test that node ID keypair is in the config directory."""
import exo.shared.constants as constants
assert constants.EXO_NODE_ID_KEYPAIR.parent == constants.EXO_CONFIG_HOME
def test_models_in_data_dir():
"""Test that models directory is in the data directory."""
# Clear EXO_MODELS_DIR to test default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_MODELS_DIR"}
with mock.patch.dict(os.environ, env, clear=True):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert constants.EXO_MODELS_DIR.parent == constants.EXO_DATA_HOME

View File

@@ -1,219 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
# the _graph can be used as a int -> NodeId map.
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.map_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for conn in snapshot.connections[source][sink]:
topology.add_connection(source, sink, conn)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
connection: Connection,
) -> None:
if connection in self.get_all_connections_between(source, sink):
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeId]]:
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeId]]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import SocketConnection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this naem
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
Event = (
@@ -119,8 +125,10 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
pass
throughput: float
latency: float
jitter: float

View File

@@ -1,64 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
# Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None
device_name_key: str
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if self.domain_uuid_key is None:
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
)
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel):
SPThunderboltDataType: list[TBConnectivityData]
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType

View File

@@ -1,32 +1,37 @@
from enum import Enum
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -36,6 +36,9 @@ class Sender[T](AnyioSender[T]):
raise ClosedResourceError
return Receiver(_state=self._state)
def __enter__(self) -> Self:
return self
class Receiver[T](AnyioReceiver[T]):
def clone(self) -> "Receiver[T]":

View File

View File

@@ -1,231 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]
class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

View File

@@ -1,56 +0,0 @@
import socket
from collections.abc import Mapping
from ipaddress import ip_address
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(
our_node_id: NodeId,
topology: Topology,
profiles: Mapping[NodeId, NodePerformanceProfile],
) -> Mapping[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = profiles.get(our_node_id, None)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node_id in topology.list_nodes():
if node_id not in profiles or node_id == our_node_id:
continue
for iface in profiles[node_id].network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(check_reachability, iface.ip_address, node_id, reachable)
return reachable

View File

@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -1,77 +0,0 @@
"""Tests for macmon error handling.
These tests verify that MacMon errors are handled gracefully without
crashing the application or spamming logs.
"""
import platform
from subprocess import CalledProcessError
from unittest.mock import AsyncMock, patch
import pytest
from exo.worker.utils.macmon import MacMonError, get_metrics_async
@pytest.mark.skipif(
platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
reason="MacMon only supports macOS with Apple Silicon",
)
class TestMacMonErrorHandling:
"""Test MacMon error handling."""
async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
"""CalledProcessError should be wrapped as MacMonError."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=b"some error message",
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "some error message" in str(exc_info.value)
async def test_called_process_error_with_no_stderr(self) -> None:
"""CalledProcessError with no stderr should be handled gracefully."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=None,
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "no stderr" in str(exc_info.value)
async def test_macmon_not_found_raises_macmon_error(self) -> None:
"""When macmon is not found in PATH, MacMonError should be raised."""
with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon not found in PATH" in str(exc_info.value)

View File

@@ -235,7 +235,6 @@ class L1C(TaggedModel):
L1 = L1A | L1B | L1C
@pytest.mark.anyio
async def test_tagged_union_is_fast():
# payload along the "C" path (worst case for DFS if branches are tried A->B->C)
payload = {"L1C": {"child": {"L2C": {"child": {"L3C": {"x": 123}}}}}}

View File

@@ -1,7 +1,6 @@
import multiprocessing as mp
import time
import pytest
from anyio import fail_after
from loguru import logger
@@ -28,8 +27,8 @@ def bar(send: MpSender[str]):
send.close()
@pytest.mark.anyio
async def test_channel_ipc():
# not async, just want the fail_after
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -1,19 +1,15 @@
import asyncio
import hashlib
import os
import shutil
import ssl
import time
import traceback
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from typing import Literal, cast
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import anyio
import httpx
from anyio import Path, to_thread
from loguru import logger
from pydantic import (
BaseModel,
@@ -24,10 +20,11 @@ from pydantic import (
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
@@ -132,16 +129,37 @@ async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def ensure_exo_home() -> Path:
home = Path(EXO_HOME)
await home.mkdir(parents=True, exist_ok=True)
return home
async def has_exo_home_read_access() -> bool:
try:
return await to_thread.run_sync(os.access, EXO_HOME, os.R_OK)
except OSError:
return False
async def has_exo_home_write_access() -> bool:
try:
return await to_thread.run_sync(os.access, EXO_HOME, os.W_OK)
except OSError:
return False
async def ensure_models_dir() -> Path:
await aios.makedirs(EXO_MODELS_DIR, exist_ok=True)
return EXO_MODELS_DIR
models_dir = Path(EXO_MODELS_DIR)
await models_dir.mkdir(parents=True, exist_ok=True)
return models_dir
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
if not await aios.path.exists(model_dir):
if not await model_dir.exists():
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
await to_thread.run_sync(shutil.rmtree, model_dir)
return True
@@ -149,14 +167,14 @@ async def seed_models(seed_dir: str | Path):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = await ensure_models_dir()
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
async for path in source_dir.iterdir():
if await path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await aios.path.exists(dest_path):
if await dest_path.exists():
logger.info("Skipping moving model to .cache directory")
else:
try:
await aios.rename(str(path), str(dest_path))
await path.rename(dest_path)
except Exception:
logger.error(f"Error seeding model {path} to {dest_path}")
logger.error(traceback.format_exc())
@@ -168,16 +186,16 @@ async def fetch_file_list_with_cache(
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
await aios.makedirs(target_dir, exist_ok=True)
await target_dir.mkdir(parents=True, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
if await cache_file.exists():
async with await cache_file.open("r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await cache_file.parent.mkdir(parents=True, exist_ok=True)
async with await cache_file.open("w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
@@ -192,7 +210,7 @@ async def fetch_file_list_with_retry(
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await anyio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
)
@@ -206,66 +224,55 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
create_http_client(timeout_profile="short") as client,
):
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
response = await client.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
async def get_download_headers() -> dict[str, str]:
return {**(await get_auth_headers()), "Accept-Encoding": "identity"}
def create_http_session(
auto_decompress: bool = False,
def create_http_client(
timeout_profile: Literal["short", "long"] = "long",
) -> aiohttp.ClientSession:
) -> httpx.AsyncClient:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
sock_read_timeout = 30
sock_connect_timeout = 10
timeout = httpx.Timeout(
connect=10,
read=30,
write=30,
pool=30,
)
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
timeout = httpx.Timeout(
connect=60,
read=1800,
write=1800,
pool=1800,
)
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
return httpx.AsyncClient(timeout=timeout)
async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -> str:
hasher = hashlib.sha1() if hash_type == "sha1" else hashlib.sha256()
if hash_type == "sha1":
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
header = f"blob {(await path.stat()).st_size}\0".encode()
hasher.update(header)
async with aiofiles.open(path, "rb") as f:
async with await path.open("rb") as f:
while chunk := await f.read(8 * 1024 * 1024):
hasher.update(chunk)
return hasher.hexdigest()
@@ -281,24 +288,28 @@ async def file_meta(
)
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
create_http_client(timeout_profile="short") as client,
):
if r.status == 307:
r = await client.head(url, headers=headers)
if r.status_code == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
if "x-linked-size" in r.headers and "x-linked-etag" in r.headers:
content_length = int(r.headers["x-linked-size"])
etag = trim_etag(r.headers["x-linked-etag"])
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
redirected_location = r.headers["location"]
return await file_meta(repo_id, revision, path, redirected_location)
# this can totally fail in weird ways if the HF endpoint behaves weirdly
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
r.headers.get("x-linked-size", None)
or r.headers.get("content-length", None)
or 0
)
etag = cast(
str | None,
r.headers.get("x-linked-etag", None) or r.headers.get("etag", None),
)
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -310,13 +321,13 @@ async def download_file_with_retry(
revision: str,
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
progress_sender: Sender[tuple[int, int, bool]] | None,
) -> Path:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
repo_id, revision, path, target_dir, progress_sender
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
@@ -325,7 +336,7 @@ async def download_file_with_retry(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
await anyio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
)
@@ -336,18 +347,16 @@ async def _download_file(
revision: str,
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
progress_sender: Sender[tuple[int, int, bool]] | None,
) -> Path:
if await aios.path.exists(target_dir / path):
if await (target_dir / path).exists():
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
await ((target_dir / path).parent).mkdir(parents=True, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
(await aios.stat(partial_path)).st_size
if (await aios.path.exists(partial_path))
else None
(await partial_path.stat()).st_size if (await partial_path.exists()) else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
@@ -356,20 +365,19 @@ async def _download_file(
headers["Range"] = f"bytes={resume_byte_pos}-"
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.get(url, headers=headers) as r,
create_http_client(timeout_profile="long") as client,
client.stream("GET", url, headers=headers) as r,
):
if r.status == 404:
if r.status_code == 404:
raise FileNotFoundError(f"File not found: {url}")
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
while chunk := await r.content.read(8 * 1024 * 1024):
async with await partial_path.open("ab" if resume_byte_pos else "wb") as f:
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)
if progress_sender is not None:
await progress_sender.send((n_read, length, False))
final_hash = await calc_hash(
partial_path, hash_type="sha256" if len(remote_hash) == 64 else "sha1"
@@ -377,14 +385,15 @@ async def _download_file(
integrity = final_hash == remote_hash
if not integrity:
try:
await aios.remove(partial_path)
await partial_path.unlink()
except Exception as e:
logger.error(f"Error removing partial file {partial_path}: {e}")
raise Exception(
f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
)
await aios.rename(partial_path, target_dir / path)
on_progress(length, length, True)
await partial_path.rename(target_dir / path)
if progress_sender is not None:
await progress_sender.send((length, length, True))
return target_dir / path
@@ -440,11 +449,11 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
await (target_dir).mkdir(parents=True, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
repo_id, revision, "model.safetensors.index.json", target_dir, None
)
async with aiofiles.open(index_file, "r") as f:
async with await index_file.open("r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
@@ -461,10 +470,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path):
return (await aios.stat(path)).st_size
if await aios.path.exists(partial_path):
return (await aios.stat(partial_path)).st_size
if await path.exists():
return (await path.stat()).st_size
if await partial_path.exists():
return (await partial_path.stat()).st_size
return 0
@@ -475,12 +484,12 @@ async def download_progress_for_local_path(
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
if await local_path.is_dir():
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
size = (await (file_path).stat()).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
@@ -515,9 +524,10 @@ async def download_progress_for_local_path(
)
# this function still has disgusting amounts of currying, but its better
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], None],
progress_sender: Sender[RepoDownloadProgress],
max_parallel_downloads: int = 8,
skip_download: bool = False,
allow_patterns: list[str] | None = None,
@@ -526,7 +536,7 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
if await Path(str(shard.model_meta.model_id)).exists():
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
@@ -538,7 +548,7 @@ async def download_shard(
"/", "--"
)
if not skip_download:
await aios.makedirs(target_dir, exist_ok=True)
await target_dir.mkdir(parents=True, exist_ok=True)
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
@@ -558,9 +568,14 @@ async def download_shard(
)
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def huh(file: FileListEntry, recv: Receiver[tuple[int, int, bool]]):
async with recv:
async for curr, total, done in recv:
await progress_sender.send(on_progress_wrapper(file, curr, total, done))
def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
):
) -> RepoDownloadProgress:
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
@@ -596,15 +611,12 @@ async def download_shard(
else "in_progress",
start_time=start_time,
)
on_progress(
return calculate_repo_progress(
shard,
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
),
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
)
for file in filtered_file_list:
@@ -622,28 +634,31 @@ async def download_shard(
start_time=time.time(),
)
semaphore = asyncio.Semaphore(max_parallel_downloads)
semaphore = anyio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file: FileListEntry):
async def download_with_semaphore(
file: FileListEntry, sender: Sender[tuple[int, int, bool]]
):
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
revision,
file.path,
target_dir,
lambda curr_bytes, total_bytes, is_renamed: on_progress_wrapper(
file, curr_bytes, total_bytes, is_renamed
),
sender,
)
if not skip_download:
await asyncio.gather(
*[download_with_semaphore(file) for file in filtered_file_list]
)
async with anyio.create_task_group() as tg:
for file in filtered_file_list:
send, recv = channel[tuple[int, int, bool]](1)
tg.start_soon(download_with_semaphore, file, send)
tg.start_soon(huh, file, recv)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
)
on_progress(shard, final_repo_progress)
await progress_sender.send(final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
return target_dir / gguf.path, final_repo_progress
else:

View File

@@ -3,8 +3,7 @@ from fnmatch import fnmatch
from pathlib import Path
from typing import Callable, Generator, Iterable
import aiofiles
import aiofiles.os as aios
import anyio
from loguru import logger
from exo.shared.types.worker.shards import ShardMetadata
@@ -69,9 +68,9 @@ def get_hf_home() -> Path:
async def get_hf_token() -> str | None:
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
token_path = get_hf_home() / "token"
if await aios.path.exists(token_path):
async with aiofiles.open(token_path, "r") as f:
token_path = anyio.Path(get_hf_home() / "token")
if await token_path.exists():
async with await anyio.open_file(token_path, "r") as f:
return (await f.read()).strip()
return None

View File

@@ -1,7 +1,8 @@
import asyncio
from pathlib import Path
from typing import AsyncIterator, Callable
from anyio import Path
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (

View File

@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from typing import AsyncIterator, Callable, Self
import anyio
from anyio import Path, create_task_group
from anyio.abc import CancelScope, TaskGroup
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
@@ -9,7 +13,54 @@ from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
@dataclass
class ShardDownloader2:
progress_sender: Sender[RepoDownloadProgress]
max_parallel_downloads = 8
# The last item on the shard stack is currently being downloaded
shard_stack: list[tuple[ShardMetadata, bool]] = field(
init=False, default_factory=list
)
_top_scope: CancelScope | None = field(init=False, default=None)
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
def start_shard(self, shard: ShardMetadata, config_only: bool = False):
self.shard_stack.append((shard, config_only))
# Cancel current tasks
if self._top_scope:
self._top_scope.cancel()
# Create a new scope
self._top_scope = CancelScope()
async def run(self):
async with self._tg:
await anyio.sleep_forever()
def shutdown(self):
self.progress_sender.close()
self._tg.cancel_scope.cancel()
async def _new_download(self, scope: CancelScope):
(shard, config_only) = self.shard_stack[-1]
with self.progress_sender.clone() as send, scope:
allow_patterns = ["config.json"] if config_only else None
target_dir, _ = await download_shard(
shard,
send,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
)
return target_dir
@classmethod
def default(cls) -> tuple[Self, Receiver[RepoDownloadProgress]]:
send, recv = channel[RepoDownloadProgress](10)
return cls(send), recv
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?

View File

@@ -101,7 +101,13 @@ def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
"""
Initialize the MLX distributed
Initialize the MLX distributed (runs in thread pool).
Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
@@ -123,22 +129,22 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)
with open(devices_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_JACCL_DEVICES: {jaccl_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = devices_file
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
@@ -246,22 +252,15 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
# TODO: Let's move away from this custom logic to mlx_lm.load()
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [163586]
elif "glm" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [151336, 151329, 151338]
else:
eos_token_ids = None
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
# TODO: HACK for Kimi K2 wrong eos token id
eos_token_ids=[163586]
if "kimi-k2" in shard_metadata.model_meta.model_id.lower()
else None,
),
)
assert isinstance(tokenizer, TokenizerWrapper)

View File

@@ -16,13 +16,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -31,7 +33,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -42,14 +44,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -83,7 +85,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -95,13 +97,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -114,17 +140,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -144,6 +159,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -232,7 +248,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -249,24 +266,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
async def _nack_request(self, since_idx: int) -> None:
@@ -315,6 +332,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -339,10 +357,11 @@ class Worker:
# TODO: i hate callbacks
def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
_: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
shard = progress.shard
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
@@ -373,6 +392,7 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -395,35 +415,28 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(
self.node_id, self.state.topology, self.state.node_profiles
)
conns = await check_reachable(self.state.topology)
for nid in conns:
for ip in conns[nid]:
edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for nid, conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)

View File

@@ -22,7 +22,7 @@ def entrypoint(
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

View File

@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -1,7 +1,6 @@
import platform
import shutil
from subprocess import CalledProcessError
from typing import cast
from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError
@@ -81,6 +80,7 @@ async def get_metrics_async() -> Metrics:
"""
path = _get_binary_path()
result = None
try:
# TODO: Keep Macmon running in the background?
result = await run_process([path, "pipe", "-s", "1"])
@@ -90,14 +90,8 @@ async def get_metrics_async() -> Metrics:
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
raise MacMonError(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
) from e
if result:
raise MacMonError(
f"MacMon failed with return code {result.returncode}"
) from e
raise e

View File

@@ -0,0 +1,41 @@
import socket
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
)
return reachable

View File

@@ -0,0 +1,108 @@
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)

843
uv.lock generated
View File

File diff suppressed because it is too large Load Diff