54 Commits

Author SHA1 Message Date
Evan
7d2e828aba stop pinging loopback addresses 2025-12-27 12:17:35 +00:00
Evan
b5319d6b03 switch from sequence to map of connections 2025-12-27 12:00:22 +00:00
Evan
b988e08d69 pydantic types are now coherent 2025-12-27 11:20:49 +00:00
Evan
9bf5979f8a rebase fix 2025-12-27 01:04:12 +00:00
Sami Khan
91944383d3 parsing api fix 2025-12-27 01:04:12 +00:00
Evan
dcc6872724 code review followup 2025-12-27 01:04:12 +00:00
Evan
dccc2709c5 rename channel test 2025-12-24 19:52:52 +00:00
Evan
20d1246600 move macmon test 2025-12-24 19:52:52 +00:00
Evan
81bad9e01a cleanup after rebase 2025-12-24 19:52:52 +00:00
Evan
7ff67d0a28 dedup connections 2025-12-24 19:52:52 +00:00
Evan
c888b13d3f freeze those models 2025-12-24 19:52:52 +00:00
Evan
1f80705b56 format 2025-12-24 19:52:52 +00:00
Evan
b349330404 tidy 2025-12-24 19:52:52 +00:00
Evan
812ce47194 all master tests pass 2025-12-24 19:52:52 +00:00
Evan
643c6b8d28 ibv -> jaccl 2025-12-24 19:52:52 +00:00
Evan
4754f56bd4 tidying some horrible logic 2025-12-24 19:51:50 +00:00
Evan
66d01369b4 fix download test 2025-12-24 19:51:50 +00:00
Evan
d20d9e5fc8 fix all master tests except rdma placement 2025-12-24 19:51:50 +00:00
Evan
e67282282c fix topology tests 2025-12-24 19:51:33 +00:00
Evan
54daa9e2db bug 2025-12-24 19:51:33 +00:00
Evan
06125d1503 actually update the topology 2025-12-24 19:51:33 +00:00
Evan
505e756872 incorrect log 2025-12-24 19:51:33 +00:00
Evan
4cd3db0f6e handle an error 2025-12-24 19:51:33 +00:00
Evan
8b137a1e64 fix pydantic validation 2025-12-24 19:51:33 +00:00
Evan
4176c7ec25 type checks outside of tests, time to test 2025-12-24 19:51:33 +00:00
Evan
dbce607911 wuff 2025-12-24 19:51:33 +00:00
Evan
9949b93517 rework topology 2025-12-24 19:51:33 +00:00
Evan
f4feeff077 update placement 2025-12-24 19:51:33 +00:00
Evan
f529884344 mvp 2025-12-24 19:50:31 +00:00
Evan
df4c6ce24e tidy config 2025-12-24 19:50:31 +00:00
Jake Hillion
1c1792f5e8 mlx: update to 0.30.1 and align coordinator naming with MLX conventions
The Jaccl distributed backend requires MLX 0.30.1+, which includes the
RDMA over Thunderbolt support. The previous minimum version (0.29.3)
would fail at runtime with "The only valid values for backend are
'any', 'mpi' and 'ring' but 'jaccl' was provided."

Bump MLX dependency to >=0.30.1 and rename ibv_coordinators to
jaccl_coordinators to match MLX's naming conventions. This includes
the environment variable change from MLX_IBV_COORDINATOR to
MLX_JACCL_COORDINATOR.
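
As an illustration (not part of this change), initializing the renamed
backend from Python looks roughly like this, assuming MLX >= 0.30.1 and a
hypothetical coordinator address:

    import os
    import mlx.core as mx

    # jaccl peers rendezvous via the coordinator (address is illustrative)
    os.environ["MLX_JACCL_COORDINATOR"] = "10.0.0.1:50000"
    group = mx.distributed.init(backend="jaccl")  # errors on MLX < 0.30.1
    print(group.rank(), group.size())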

Test plan:

Hardware setup: 3x Mac Studio M3 Ultra connected all-to-all with TB5

- Built a DMG [0]
- Installed on all Macs and started cluster.
- Requested a 2 node Tensor + MLX RDMA instance of Llama 3.3 70B (FP16).
- It started successfully.
- Queried the chat a few times. All was good. This didn't work
  previously.
- Killed the instance and spawned Pipeline + MLX Ring Llama 3.3 70B (FP16).
  Also started successfully on two nodes and could be queried.

Still not working:
- Pipeline + MLX Ring on 3 nodes is failing. Haven't debugged that yet.

[0] https://github.com/exo-explore/exo/actions/runs/20467656904/job/58815275013
2025-12-24 16:47:01 +00:00
Jake Hillion
9afc1043ef exo: handle -c flag for multiprocessing helpers in frozen apps
When Python's multiprocessing spawns child processes on macOS (using the
"spawn" method), it also spawns helper processes like the resource tracker
by executing:

    ./frozen_app -c "from multiprocessing.resource_tracker import main; main()"

A frozen PyInstaller app doesn't understand `-c` natively - it just runs
main(). This causes the resource tracker to fail silently.

This adds a minimal `-c` handler that intercepts the flag, extracts the
inline code, and exec()s it before main() runs. This is required for the
Process() spawn in runner_supervisor.py to work correctly in the DMG.

Note that the pyinstaller docs say `freeze_support` is supposed to make
this work, but it doesn't.

Test plan:

Hardware setup: 3x Mac Studio M3 Ultra connected all-to-all with TB5

- Built a DMG[0].
- Installed on the Macs.
- Started an instance. Got an error this time in ~/.exo/exo.log. The
  last DMG from main doesn't show anything when an instance starts; this
  build now surfaces the errors.

[0] https://github.com/exo-explore/exo/actions/runs/20464409279/job/58804485197
2025-12-23 17:08:50 +00:00
Evan Quiney
70c423f5e0 feat: conform to XDG Base Directory Specification on Linux (#988)
This is an extension of #964 with some cleanup.

---------

Co-authored-by: majiayu000 <1835304752@qq.com>
2025-12-23 17:02:55 +00:00
Jake Hillion
a24bdf7680 exo: enable multiprocessing support in PyInstaller bundles
Model loading fails silently when running from the DMG-packaged app,
despite working correctly with `uv run exo`. The bundled app spawns
child processes for model inference via multiprocessing, but these
processes fail to start in a frozen (PyInstaller) environment.

Add `freeze_support()` which is required for multiprocessing to work
in frozen applications.

Test plan:

Hardware setup: 3x Mac Studio M3 Ultra connected all-to-all with TB5

- Built a DMG using a modified .github/workflows/build-app.yml[0] to avoid
  publishing it.
- Installed on all 3 Macs, replacing the existing Exo.
- Downloaded Llama 3.3 70B (FP16).
- Downloaded Qwen3 Coder 235B A22B (8-bit).

Things that work now but didn't on the previous app:
- Topology looks good, previously there was no discovery.

What didn't work:
- Started an instance with Pipeline + MLX Ring + 3 Nodes. Failed.
- Started an instance with Tensor + MLX RDMA + 2 Nodes. Failed.

Will continue debugging the instance starting issues separately.

[0] https://github.com/exo-explore/exo/actions/runs/20461320368
2025-12-23 14:34:21 +00:00
Jake Hillion
e8855959c1 build-app: add branch trigger from named branch
As I've been working on the .dmg, it's become clear we need a way to
test changes to the app. Reproducing the full DMG build locally is too
hard to be reasonable, and testing is much more convenient when the build is signed.

Add a feature to the build-app workflow where if you push specifically
to the `test-app` branch it'll perform a build. The version is stubbed
to `0.0.0-alpha.0`, which is about as low as it gets in semver, so you'll
always update away from it automatically with Sparkle. The resulting DMG
won't be pushed to S3 but will be uploaded as a GitHub Actions artifact.

I've been using similar commits to this for a while for testing. It's
worked well and not interfered with auto updating at all.

Test plan:
- Pushed this change to `test-app`.
- Generated action at
  https://github.com/exo-explore/exo/actions/runs/20447213358/job/58752909332
- Installed the DMG on a Mac. It worked as intended.
2025-12-23 12:53:30 +00:00
Jake Hillion
0a7fe5d943 ci: migrate build-app to github hosted runners 2025-12-22 19:51:48 +00:00
rltakashige
51a5191ff3 format readme (#978)
## Motivation

The README looks weird after the last update.

## Test Plan

### Manual Testing

I actually checked the file on GitHub this time.
2025-12-22 18:06:27 +00:00
Evan Quiney
1efbd26388 add architecture.md, move images to docs/imgs (#968)
## Motivation

Documentation will make contribution easier and communicate our
development philosophy and decision process. Closes #967

## Changes

Added `architecture.md` to docs/ and moved the images out of docs and
into their own docs/imgs/ folder
2025-12-22 17:57:43 +00:00
Jake Hillion
02c915a88d pyproject: drop pathlib dependency 2025-12-22 17:52:44 +00:00
rltakashige
fc41bfa1f1 Add all prerequisites to README (#975)
## Motivation

Addresses #974:
```
INFO: pip is looking at multiple versions of exo to determine which version is compatible with other requirements. This could take a while.
ERROR: Could not find a version that satisfies the requirement exo-pyo3-bindings (from exo) (from versions: none)
ERROR: No matching distribution found for exo-pyo3-bindings
```

## Changes

Describes the Rust dependency required for building from source.

## Test Plan

### Manual Testing

Tested locally; after this setup, exo runs without the exo-pyo3-bindings error.
2025-12-22 17:38:51 +00:00
Jake Hillion
dd0638b74d pyproject: add pyinstaller to dev-dependencies 2025-12-22 15:49:27 +00:00
majiayu000
e06830ce0b fix: update macOS app to use correct API port (52415)
Fixes #960

The macOS app was incorrectly using port 8000 instead of the default
exo API port 52415. This caused confusion as the README correctly
documents port 52415 but the app was connecting to a different port.
2025-12-22 13:24:09 +00:00
Jake Hillion
1df5079b98 ci: avoid pushing alpha build as latest 2025-12-22 13:00:49 +00:00
Nightguarder
1e75aeb2c2 Add Prerequisites to Readme (#936)
## Motivation

Users need to know what **prerequisites** they need in order to run exo.
A simple addition to the docs that prevents future issues from being raised.

## Changes

Updated ``README.md``:
- to include installation instructions for
**[uv](https://github.com/astral-sh/uv)** and
**[macmon](https://github.com/vladkens/macmon)**.

Updated ``CONTRIBUTING.md``:
- to verify these prerequisites are met before starting development.
- Standardized on brew installation instructions for macOS users to keep
the guide simple.

## Why It Works

By listing these prerequisites upfront, users will set up their
environment correctly before attempting to run exo.

## Test Plan

### Manual Testing

MacBook Pro M4
- Verified that ``uv`` and ``macmon`` were missing initially, causing
failures.
- After installing them via brew (as documented), ``uv run exo`` starts
successfully.

---------

Co-authored-by: Evan Quiney <evanev7@gmail.com>
2025-12-22 02:28:08 +00:00
Heath Dutton🕴️
c582bdd673 bugfix: Handle MacMon errors gracefully 2025-12-22 02:21:29 +00:00
Jake Hillion
1bae8ebbf6 ci: add build-app workflow 2025-12-22 02:12:30 +00:00
Alex Cheema
abaeb0323d Update README.md. (#956)
## Motivation

Made a mistake on the merge of the last PR.
2025-12-21 23:09:44 +00:00
Alex Cheema
7d15fbdaab readme tweaks5 (#954)
2025-12-21 22:48:35 +00:00
Alex Cheema
4a6e0fe171 Update README.md. (#949)
2025-12-21 18:31:23 +00:00
Olimbek Nizomov
f4792dce14 fix(downloads): use certifi for robust SSL certificate verification (#941)
## Description

This change updates the SSL context creation in `download_utils.py` to
explicitly use the `certifi` CA bundle. This ensures that the
application has access to a reliable, up-to-date set of root
certificates, which is critical for verifying SSL connections to
external services like Hugging Face.

## Problem

On macOS environments (and potentially others), Python's default SSL
context often fails to locate the system's root certificates. This leads
to `aiohttp.client_exceptions.ClientConnectorCertificateError` errors
when attempting to download models.

## Solution

By passing `cafile=certifi.where()` to
`ssl.create_default_context()`, we force the application to use the
trusted certificate store provided by the `certifi` package. This is a
standard best practice for cross-platform Python applications and
resolves the verification failure.
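
A minimal sketch of the approach (assuming `aiohttp` and `certifi` are
installed; the helper and URL are illustrative, not the PR's exact code):

```python
import ssl

import aiohttp
import certifi

# Use certifi's CA bundle instead of the (possibly missing) system default.
ssl_context = ssl.create_default_context(cafile=certifi.where())

async def fetch(url: str) -> bytes:
    # The connector applies the certifi-backed context to every request.
    async with aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context)
    ) as session:
        async with session.get(url) as resp:
            return await resp.read()
```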
2025-12-21 12:03:52 +00:00
rltakashige
a1b14a272e Extend eos_token_id fix for other models (#938)
## Motivation

We currently use mlx_lm's load_tokenizer instead of load. This means
that some models are missing some configuration, such as eos_token_id.
This is clear for a model like GLM, which does not finish token
generation.

## Changes

A small stopgap to allow eos_token_ids to be added, and a TODO for us
to migrate to load. The reason we don't want to do this now is that a
solid testing framework is not configured in this repo yet.

## Why It Works

It simply uses the eos_token_ids I obtained from loading a tokenizer in
mlx_lm and calling `tokenizer.eos_token_ids`.

## Test Plan

### Manual Testing

Tested on several Macs.

### Automated Testing

None yet, as described.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2025-12-20 20:18:17 +00:00
Alex Cheema
f8483cfc18 Update README.md. (#932)
2025-12-19 21:23:25 +00:00
Alex Cheema
8bafd6fe68 Update README.md (#925)
2025-12-19 14:38:40 +00:00
Jake Hillion
f16afd723d nix: get rust build working on linux 2025-12-19 13:51:15 +00:00
72 changed files with 3173 additions and 1576 deletions

298
.github/workflows/build-app.yml vendored Normal file
View File

@@ -0,0 +1,298 @@
name: Build EXO macOS DMG
on:
  push:
    tags:
      - "v*"
    branches:
      - "test-app"
jobs:
  build-macos-app:
    runs-on: "macos-26"
    env:
      SPARKLE_VERSION: 2.8.1
      SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
      SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
      SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
      SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
      SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
      SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
      AWS_REGION: ${{ secrets.AWS_REGION }}
      EXO_BUILD_NUMBER: ${{ github.run_number }}
      EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
    steps:
      # ============================================================
      # Checkout and tag validation
      # ============================================================
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Derive release version from tag
        run: |
          if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
            VERSION="0.0.0-alpha.0"
            echo "IS_ALPHA=true" >> $GITHUB_ENV
          else
            VERSION="${GITHUB_REF_NAME#v}"
            if [[ "$VERSION" == *-alpha* ]]; then
              echo "IS_ALPHA=true" >> $GITHUB_ENV
            else
              echo "IS_ALPHA=false" >> $GITHUB_ENV
            fi
          fi
          echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
      - name: Ensure tag commit is on main
        if: github.ref_type == 'tag'
        run: |
          git fetch origin main
          # Alpha tags can be on any branch, production tags must be on main
          if [[ "$IS_ALPHA" == "true" ]]; then
            echo "Alpha tag detected, skipping main branch check"
          elif ! git merge-base --is-ancestor origin/main HEAD; then
            echo "Production tag must point to a commit on main"
            exit 1
          fi
      # ============================================================
      # Install dependencies
      # ============================================================
      - name: Select Xcode 26.2
        run: |
          sudo xcode-select -s /Applications/Xcode_26.2.app
          if ! xcrun -f metal >/dev/null 2>&1; then
            echo "Metal toolchain is not installed."
            exit 1
          fi
      - name: Install Homebrew packages
        run: brew install just awscli macmon
      - name: Install UV
        uses: astral-sh/setup-uv@v6
        with:
          enable-cache: true
          cache-dependency-glob: uv.lock
      - name: Setup Python
        run: |
          uv python install
          uv sync --locked
      - name: Build dashboard
        run: |
          cd dashboard
          npm ci
          npm run build
      - name: Install Sparkle CLI
        run: |
          CLI_URL="${SPARKLE_CLI_URL:-https://github.com/sparkle-project/Sparkle/releases/download/${SPARKLE_VERSION}/Sparkle-${SPARKLE_VERSION}.tar.xz}"
          echo "Downloading Sparkle CLI from: $CLI_URL"
          mkdir -p /tmp/sparkle
          curl --fail --location --output /tmp/sparkle.tar.xz "$CLI_URL"
          tar -xJf /tmp/sparkle.tar.xz -C /tmp/sparkle --strip-components=1
          echo "SPARKLE_BIN=/tmp/sparkle/bin" >> $GITHUB_ENV
      - name: Prepare code-signing keychain
        env:
          MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}
          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
          PROVISIONING_PROFILE: ${{ secrets.PROVISIONING_PROFILE }}
        run: |
          KEYCHAIN_PATH="$HOME/Library/Keychains/build.keychain-db"
          # Create fresh keychain
          security create-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
          # Disable auto-lock (no timeout, no lock-on-sleep)
          security set-keychain-settings "$KEYCHAIN_PATH"
          # Add to search list while preserving existing keychains
          security list-keychains -d user -s "$KEYCHAIN_PATH" $(security list-keychains -d user | tr -d '"')
          # Set as default and unlock
          security default-keychain -s "$KEYCHAIN_PATH"
          security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
          # Import certificate with full access for codesign
          echo "$MACOS_CERTIFICATE" | base64 --decode > /tmp/cert.p12
          security import /tmp/cert.p12 -k "$KEYCHAIN_PATH" -P "$MACOS_CERTIFICATE_PASSWORD" \
            -T /usr/bin/codesign -T /usr/bin/security -T /usr/bin/productbuild
          rm /tmp/cert.p12
          # Allow codesign to access the key without prompting
          security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$MACOS_CERTIFICATE_PASSWORD" "$KEYCHAIN_PATH"
          # Verify keychain is unlocked and identity is available
          echo "Verifying signing identity..."
          security find-identity -v -p codesigning "$KEYCHAIN_PATH"
          # Setup provisioning profile
          mkdir -p "$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles"
          echo "$PROVISIONING_PROFILE" | base64 --decode > "$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles/EXO.provisionprofile"
          # Export keychain path for other steps
          echo "BUILD_KEYCHAIN_PATH=$KEYCHAIN_PATH" >> $GITHUB_ENV
      # ============================================================
      # Build the bundle
      # ============================================================
      - name: Build PyInstaller bundle
        run: uv run pyinstaller packaging/pyinstaller/exo.spec
      - name: Build Swift app
        env:
          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
          SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
          SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
        run: |
          cd app/EXO
          security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
          SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
          xcodebuild clean build \
            -scheme EXO \
            -configuration Release \
            -derivedDataPath build \
            MARKETING_VERSION="$RELEASE_VERSION" \
            CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
            EXO_BUILD_TAG="$RELEASE_VERSION" \
            EXO_BUILD_COMMIT="$GITHUB_SHA" \
            SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
            SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
            CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
            CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
          mkdir -p ../../output
          cp -R build/Build/Products/Release/EXO.app ../../output/EXO.app
      - name: Inject PyInstaller runtime
        run: |
          rm -rf output/EXO.app/Contents/Resources/exo
          mkdir -p output/EXO.app/Contents/Resources
          cp -R dist/exo output/EXO.app/Contents/Resources/exo
      - name: Codesign PyInstaller runtime
        env:
          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
        run: |
          cd output
          security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
          SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
          RUNTIME_DIR="EXO.app/Contents/Resources/exo"
          find "$RUNTIME_DIR" -type f \( -perm -111 -o -name "*.dylib" -o -name "*.so" \) -print0 |
            while IFS= read -r -d '' file; do
              /usr/bin/codesign --force --timestamp --options runtime \
                --sign "$SIGNING_IDENTITY" "$file"
            done
      - name: Sign, notarize, and create DMG
        env:
          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}
          APPLE_NOTARIZATION_USERNAME: ${{ secrets.APPLE_NOTARIZATION_USERNAME }}
          APPLE_NOTARIZATION_PASSWORD: ${{ secrets.APPLE_NOTARIZATION_PASSWORD }}
          APPLE_NOTARIZATION_TEAM: ${{ secrets.APPLE_NOTARIZATION_TEAM }}
        run: |
          cd output
          security unlock-keychain -p "$MACOS_CERTIFICATE_PASSWORD" "$BUILD_KEYCHAIN_PATH"
          SIGNING_IDENTITY=$(security find-identity -v -p codesigning "$BUILD_KEYCHAIN_PATH" | awk -F '"' '{print $2}')
          /usr/bin/codesign --deep --force --timestamp --options runtime \
            --sign "$SIGNING_IDENTITY" EXO.app
          mkdir -p dmg-root
          cp -R EXO.app dmg-root/
          ln -s /Applications dmg-root/Applications
          DMG_NAME="EXO-${RELEASE_VERSION}.dmg"
          hdiutil create -volname "EXO" -srcfolder dmg-root -ov -format UDZO "$DMG_NAME"
          /usr/bin/codesign --force --timestamp --options runtime \
            --sign "$SIGNING_IDENTITY" "$DMG_NAME"
          if [[ -n "$APPLE_NOTARIZATION_USERNAME" ]]; then
            SUBMISSION_OUTPUT=$(xcrun notarytool submit "$DMG_NAME" \
              --apple-id "$APPLE_NOTARIZATION_USERNAME" \
              --password "$APPLE_NOTARIZATION_PASSWORD" \
              --team-id "$APPLE_NOTARIZATION_TEAM" \
              --wait --timeout 15m 2>&1)
            echo "$SUBMISSION_OUTPUT"
            SUBMISSION_ID=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="id:" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')
            STATUS=$(echo "$SUBMISSION_OUTPUT" | awk 'tolower($1)=="status:" {print $2; exit}')
            if [[ -n "$SUBMISSION_ID" ]]; then
              xcrun notarytool log "$SUBMISSION_ID" \
                --apple-id "$APPLE_NOTARIZATION_USERNAME" \
                --password "$APPLE_NOTARIZATION_PASSWORD" \
                --team-id "$APPLE_NOTARIZATION_TEAM" > notarization-log.txt || true
              echo "===== Notarization Log ====="
              cat notarization-log.txt
              echo "============================"
            fi
            if [[ "$STATUS" != "Accepted" ]]; then
              echo "Notarization failed with status: ${STATUS:-Unknown}"
              exit 1
            fi
            xcrun stapler staple "$DMG_NAME"
          fi
      - name: Generate Sparkle appcast
        env:
          SPARKLE_DOWNLOAD_PREFIX: ${{ env.SPARKLE_DOWNLOAD_PREFIX }}
          SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
          IS_ALPHA: ${{ env.IS_ALPHA }}
        run: |
          set -euo pipefail
          cd output
          DOWNLOAD_PREFIX="${SPARKLE_DOWNLOAD_PREFIX:-https://assets.exolabs.net}"
          echo "$SPARKLE_ED25519_PRIVATE" > sparkle_ed25519.key
          chmod 600 sparkle_ed25519.key
          CHANNEL_FLAG=""
          if [[ "$IS_ALPHA" == "true" ]]; then
            CHANNEL_FLAG="--channel alpha"
            echo "Generating appcast for alpha channel"
          fi
          $SPARKLE_BIN/generate_appcast \
            --ed-key-file sparkle_ed25519.key \
            --download-url-prefix "$DOWNLOAD_PREFIX" \
            $CHANNEL_FLAG \
            .
      # ============================================================
      # Upload artifacts
      # ============================================================
      - name: Upload DMG
        uses: actions/upload-artifact@v4
        with:
          name: EXO-dmg-${{ env.RELEASE_VERSION }}
          path: output/EXO-${{ env.RELEASE_VERSION }}.dmg
      - name: Upload to S3
        if: env.SPARKLE_S3_BUCKET != '' && github.ref_type == 'tag'
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          AWS_REGION: ${{ env.AWS_REGION }}
          SPARKLE_S3_BUCKET: ${{ env.SPARKLE_S3_BUCKET }}
          SPARKLE_S3_PREFIX: ${{ env.SPARKLE_S3_PREFIX }}
          IS_ALPHA: ${{ env.IS_ALPHA }}
        run: |
          set -euo pipefail
          cd output
          PREFIX="${SPARKLE_S3_PREFIX:-}"
          if [[ -n "$PREFIX" && "${PREFIX: -1}" != "/" ]]; then
            PREFIX="${PREFIX}/"
          fi
          DMG_NAME="EXO-${RELEASE_VERSION}.dmg"
          aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
          if [[ "$IS_ALPHA" != "true" ]]; then
            aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
          fi
          aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

View File

@@ -5,10 +5,21 @@ Thank you for your interest in contributing to EXO!
## Getting Started
To run EXO from source:
**Prerequisites:**
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
```bash
brew install uv
```
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
```bash
brew install macmon
```
```bash
git clone https://github.com/exo-explore/exo.git
cd exo/dashboard
npm install && npm run build
npm install && npm run build && cd ..
uv run exo
```

214
README.md
View File

@@ -1,55 +1,223 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: light)" srcset="/docs/exo-logo-black-bg.jpg">
<img alt="exo logo" src="/docs/exo-logo-transparent.png" width="50%" height="50%">
<source media="(prefers-color-scheme: light)" srcset="/docs/imgs/exo-logo-black-bg.jpg">
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
[![GitHub Repo stars](https://img.shields.io/github/stars/exo-explore/exo)](https://github.com/exo-explore/exo/stargazers)
[![License: Apache-2.0](https://img.shields.io/badge/License-Apache2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
</div>
---
exo connects all your devices into an AI cluster. Not only does exo enable running models larger than would fit on a single device, but with [day-0 support for RDMA over Thunderbolt](https://x.com/exolabs/status/2001817749744476256?s=20), it also makes models run faster as you add more devices.
## Features
- **Automatic Device Discovery**: Devices running EXO automatically discover each other on your local network - no manual configuration.
- **RDMA over Thunderbolt**: EXO ships with Day-0 support for RDMA over Thunderbolt 5, enabling 99% reduction in latency between devices.
- **Auto Parallel**: EXO automatically splits up models to run distributed across devices.
- **Tensor Parallelism**: EXO supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: EXO uses [ml-explore/mlx](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
- **Automatic Device Discovery**: Devices running exo automatically discover each other - no manual configuration.
- **RDMA over Thunderbolt**: exo ships with [day-0 support for RDMA over Thunderbolt 5](https://x.com/exolabs/status/2001817749744476256?s=20), enabling a 99% reduction in latency between devices.
- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
<details>
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
<details>
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
---
## Quick Start
You need at least one Mac device running macOS Tahoe 26.2 (released December 12th 2025).
Devices running exo automatically discover each other, without needing any manual configuration. Each device provides an API and a dashboard for interacting with your cluster (runs at `http://localhost:52415`).
You can download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-latest.dmg). It will ask for permission to modify system settings and install a new Network profile. We hope to make this smoother in the future!
There are two ways to run exo:
To run from source, clone the repo, build the dashboard with `cd dashboard && npm install && npm run build` and run `uv run exo`.
### Run from Source (Mac & Linux)
After starting with either of these methods go to `http://localhost:52415` in your browser, and you'll have EXO.
**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
Clone the repo, build the dashboard, and run exo:
```bash
# Clone exo
git clone https://github.com/exo-explore/exo
# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..
# Run exo
uv run exo
```
This starts the exo dashboard and API at http://localhost:52415/
### macOS App
exo ships a macOS app that runs in the background on your Mac.
<img src="docs/imgs/macos-app-one-macbook.png" alt="exo macOS App - running on a MacBook" width="35%" />
The macOS app requires macOS Tahoe 26.2 or later.
Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-latest.dmg).
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
---
## Requirements
### Using the API
- Mac devices with Apple Silicon (M-series chips)
- macOS Tahoe 26.2 or later (released December 12th 2025)
- Older macOS versions may work without RDMA, but only 26.2+ is officially supported
- For RDMA over Thunderbolt: a high quality Thunderbolt 5 cable
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request, and deleting the instance.
We intend to add support for other hardware platforms [like the DGX Spark](https://x.com/exolabs/status/1978525767739883736) in the future, but they are not currently supported. If you'd like support for a new hardware platform, please search for an existing feature request and add a thumbs up so we know what hardware is important to the community.
---
**1. Preview instance placements**
The `/instance/previews` endpoint will preview all valid placements for your model.
```bash
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
```
Sample response:
```json
{
"previews": [
{
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"sharding": "Pipeline",
"instance_meta": "MlxRing",
"instance": {...},
"memory_delta_by_node": {"local": 729808896},
"error": null
}
// ...possibly more placements...
]
}
```
This will return all valid placements for this model. Pick a placement that you like.
To pick the first one, pipe into `jq`:
```bash
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1
```
---
**2. Create a model instance**
Send a POST to `/instance` with your desired placement in the `instance` field (the full payload must match the types in `CreateInstanceParams`), which you can copy from step 1:
```bash
curl -X POST http://localhost:52415/instance \
-H 'Content-Type: application/json' \
-d '{
"instance": {...}
}'
```
Sample response:
```json
{
"message": "Command received.",
"command_id": "e9d1a8ab-...."
}
```
---
**3. Send a chat completion**
Now, make a POST to `/v1/chat/completions` (the same format as OpenAI's API):
```bash
curl -N -X POST http://localhost:52415/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"messages": [
{"role": "user", "content": "What is Llama 3.2 1B?"}
],
"stream": true
}'
```
---
**4. Delete the instance**
When you're done, delete the instance by its ID (find it via `/state` or `/instance` endpoints):
```bash
curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
```
**Other useful API endpoints:**
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
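If you prefer Python over curl, the same walkthrough looks roughly like this (a sketch using `httpx`, which exo already depends on; response shapes follow the examples above):
```python
import httpx

BASE = "http://localhost:52415"

# 1. Preview placements and pick the first valid one
previews = httpx.get(
    f"{BASE}/instance/previews", params={"model_id": "llama-3.2-1b"}
).json()
placement = next(p for p in previews["previews"] if p["error"] is None)

# 2. Create the instance from the chosen placement
httpx.post(f"{BASE}/instance", json={"instance": placement["instance"]})

# 3. Chat (non-streaming for brevity); in practice, poll /state until the
#    instance is ready before sending the first request.
reply = httpx.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": placement["model_id"],
        "messages": [{"role": "user", "content": "What is Llama 3.2 1B?"}],
    },
    timeout=120,
)
print(reply.json())
```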
---
## Hardware Accelerator Support
On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
---
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to EXO.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -19,6 +19,7 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -212,7 +212,7 @@ struct ContentView: View {
private var dashboardButton: some View {
Button {
guard let url = URL(string: "http://localhost:8000/") else { return }
guard let url = URL(string: "http://localhost:52415/") else { return }
NSWorkspace.shared.open(url)
} label: {
HStack {

View File

@@ -35,7 +35,7 @@ struct BugReportService {
}
func sendReport(
baseURL: URL = URL(string: "http://127.0.0.1:8000")!,
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {

View File

@@ -15,7 +15,7 @@ final class ClusterStateService: ObservableObject {
private let endpoint: URL
init(
baseURL: URL = URL(string: "http://127.0.0.1:8000")!,
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
session: URLSession = .shared
) {
self.baseURL = baseURL

View File

@@ -861,6 +861,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -900,6 +901,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1516,6 +1518,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1525,6 +1528,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1937,6 +1941,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2607,6 +2612,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2794,6 +2800,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2938,6 +2945,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2959,6 +2967,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -96,7 +96,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile: RawNodeProfile;
nodeProfile?: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -105,9 +105,13 @@ interface RawTopologyConnection {
sendBackMultiaddr?: { multiaddr?: string; address?: string; ip_address?: string } | string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem = RawTopologyConnection | [string, string, { sinkMultiaddr?: { ip_address?: string; address?: string } }?];
interface RawTopology {
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -198,9 +202,17 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === 'string' ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode = typeof node === 'object' ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -238,7 +250,7 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
}
nodes[node.nodeId] = {
nodes[nodeId] = {
system_info: {
model_id: profile?.modelId ?? 'Unknown',
chip: profile?.chipId,
@@ -260,14 +272,34 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr: { multiaddr?: string; address?: string; ip_address?: string } | string | undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as { sinkMultiaddr?: { ip_address?: string; address?: string } } | undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (typeof multi === 'string') {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -276,8 +308,8 @@ function transformTopology(raw: RawTopology, profiles?: RawNodeProfiles): Topolo
}
edges.push({
source: conn.localNodeId,
target: conn.sendBackNodeId,
source: localNodeId,
target: sendBackNodeId,
sendBackIp
});
}

64
docs/architecture.md Normal file
View File

@@ -0,0 +1,64 @@
# EXO Architecture overview
EXO uses an _Event Sourcing_ architecture and Erlang-style _message passing_. To facilitate this, we've written a channel library extending anyio channels, with inspiration from tokio::sync::mpsc.
Each logical module - designed to be functional independently of the others - communicates with the rest of the system by sending messages on topics.
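As a rough illustration of the style (plain `anyio` memory streams standing in for the internal channel library, which is not shown here):
```python
import anyio

async def main() -> None:
    send, recv = anyio.create_memory_object_stream[str](max_buffer_size=16)

    async def system_a() -> None:  # publishes messages on a "topic"
        async with send:
            await send.send("NodeConnected")

    async def system_b() -> None:  # consumes them independently
        async with recv:
            async for message in recv:
                print("received:", message)

    async with anyio.create_task_group() as tg:
        tg.start_soon(system_a)
        tg.start_soon(system_b)

anyio.run(main)
```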
## Systems
There are currently 5 major systems:
- Master
Executes placement and orders events through a single writer
- Worker
Schedules work on a node, gathers system information, etc.
- Runner
Executes inference jobs (for now) in an isolated process from the worker for fault-tolerance.
- API
Runs a Python webserver exposing state and commands to client applications
- Election
Implements a distributed algorithm for master election in unstable networking conditions
## Topics
There are currently 5 topics:
- Commands
The API and Worker instruct the master when the event log isn't sufficient. Namely, placement and catchup requests go through Commands at the moment.
- Local Events
All nodes write events here; the master reads those events and orders them
- Global Events
The master writes events here; all nodes read from this topic and fold the produced events into their `State`
- Election Messages
Before establishing a cluster, nodes communicate here to negotiate a master node.
- Connection Messages
The networking system writes mDNS-discovered hardware connections here.
## Event Sourcing
Much has been written about event sourcing; for us, it centralizes handling of faulty connections and message ACKing under the following model.
Whenever a device produces side effects, it captures those side effects in an `Event`. `Event`s are then "applied" to their model of `State`, which is globally distributed across the cluster. Whenever a command is received, it is combined with state to produce side effects, captured in yet more events. The rule of thumb is "`Event`s are past tense, `Command`s are imperative". Telling a node to perform some action like "place this model" or "give me a copy of the event log" is represented by a command (the worker's `Task`s are also commands), while "this node is using 300GB of RAM" is an event. Notably, `Event`s SHOULD never cause side effects on their own. There are a few exceptions to this; we're still working out the specifics of generalizing the distributed event sourcing model to make it better suit our needs. A hypothetical sketch of the rule of thumb appears below.
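A sketch of the Event/Command distinction (illustrative names, not exo's actual types):
```python
from pydantic import BaseModel

class NodeMemoryReported(BaseModel):
    """Event: past tense, a fact that already happened; applying it to
    State must not cause side effects."""
    node_id: str
    ram_used_bytes: int

class PlaceModel(BaseModel):
    """Command: imperative; combined with State it produces side effects,
    which are captured as new Events."""
    model_id: str
    min_nodes: int
```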
## Purity
A significant goal of the current design is to make data flow explicit. Classes should either represent simple data (`CamelCaseModel`s typically, and `TaggedModel`s for unions) or active `System`s (Erlang `Actor`s), with all transformations of that data being "referentially transparent" - destructure and construct new data, don't mutate in place. We have had varying degrees of success with this, and are still exploring where purity makes sense.
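For example, a frozen pydantic model makes "don't mutate in place" mechanical (a sketch; `CamelCaseModel` itself is not shown here):
```python
from pydantic import BaseModel, ConfigDict

class Memory(BaseModel):
    model_config = ConfigDict(frozen=True)  # in-place mutation raises

    ram_total: int
    ram_available: int

before = Memory(ram_total=64, ram_available=32)
# Destructure and construct new data instead of mutating the old value.
after = before.model_copy(update={"ram_available": 16})
```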

(Seven binary image files changed: five added (514 KiB, 519 KiB, 486 KiB, 131 KiB, 171 KiB) and two moved unchanged (7.9 KiB, 295 KiB).)
View File

@@ -91,6 +91,10 @@
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
@@ -100,6 +104,11 @@
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list

View File

@@ -0,0 +1,118 @@
# -*- mode: python ; coding: utf-8 -*-
import importlib.util
import shutil
from pathlib import Path

from PyInstaller.utils.hooks import collect_submodules

PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"

if not ENTRYPOINT.is_file():
    raise SystemExit(f"Unable to locate Exo entrypoint: {ENTRYPOINT}")
if not DASHBOARD_DIR.is_dir():
    raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
    raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")

block_cipher = None

def _module_directory(module_name: str) -> Path:
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        raise SystemExit(f"Module '{module_name}' is not available in the current environment.")
    if spec.submodule_search_locations:
        return Path(next(iter(spec.submodule_search_locations))).resolve()
    if spec.origin:
        return Path(spec.origin).resolve().parent
    raise SystemExit(f"Unable to determine installation directory for '{module_name}'.")

MLX_PACKAGE_DIR = _module_directory("mlx")
MLX_LIB_DIR = MLX_PACKAGE_DIR / "lib"
if not MLX_LIB_DIR.is_dir():
    raise SystemExit(f"mlx Metal libraries are missing: {MLX_LIB_DIR}")

def _safe_collect(package_name: str) -> list[str]:
    try:
        return collect_submodules(package_name)
    except ImportError:
        return []

HIDDEN_IMPORTS = sorted(
    set(
        collect_submodules("mlx")
        + _safe_collect("mlx_lm")
        + _safe_collect("transformers")
    )
)

DATAS: list[tuple[str, str]] = [
    (str(DASHBOARD_DIR), "dashboard"),
    (str(MLX_LIB_DIR), "mlx/lib"),
    (str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

MACMON_PATH = shutil.which("macmon")
if MACMON_PATH is None:
    raise SystemExit(
        "macmon binary not found in PATH. "
        "Install it via: brew install macmon"
    )

BINARIES: list[tuple[str, str]] = [
    (MACMON_PATH, "."),
]

a = Analysis(
    [str(ENTRYPOINT)],
    pathex=[str(SOURCE_ROOT)],
    binaries=BINARIES,
    datas=DATAS,
    hiddenimports=HIDDEN_IMPORTS,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    noarchive=False,
)

pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name="exo",
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)

coll = COLLECT(
    exe,
    a.binaries,
    a.zipfiles,
    a.datas,
    strip=False,
    upx=False,
    upx_exclude=[],
    name="exo",
)

View File

@@ -5,17 +5,31 @@ description = "Exo"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"typeguard>=4.4.4",
"pydantic>=2.11.7",
"httpx>=0.28.1",
"base58>=2.1.1",
"cryptography>=45.0.5",
"fastapi>=0.116.1",
"filelock>=3.18.0",
"aiosqlite>=0.21.0",
"networkx>=3.5",
"protobuf>=6.32.0",
"rich>=14.1.0",
"rustworkx>=0.17.1",
"sqlmodel>=0.0.24",
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"psutil>=7.0.0",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx>=0.29.3",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
@@ -29,9 +43,11 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
"pytest-env",
"ruff>=0.11.13",
"trio>=0.32.0",
]
# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
@@ -108,4 +124,12 @@ extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindi
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
[tool.pytest.ini_options]
anyio_mode = "auto"
pythonpath = "."
asyncio_mode = "auto"
markers = [
"slow: marks tests as slow (deselected by default)"
]
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"

View File

@@ -1,4 +1,39 @@
from __future__ import annotations

import sys
from collections.abc import Sequence
from multiprocessing import freeze_support
from typing import Final

from exo.main import main

INLINE_CODE_FLAG: Final[str] = "-c"

def _maybe_run_inline_code(argv: Sequence[str]) -> bool:
    """
    Reproduce the bare minimum of Python's `-c` flag so multiprocessing
    helper processes (for example the resource tracker) can execute.
    """
    try:
        flag_index = argv.index(INLINE_CODE_FLAG)
    except ValueError:
        return False
    code_index = flag_index + 1
    if code_index >= len(argv):
        return False
    inline_code = argv[code_index]
    sys.argv = ["-c", *argv[code_index + 1 :]]
    namespace: dict[str, object] = {"__name__": "__main__"}
    exec(inline_code, namespace, namespace)
    return True

if __name__ == "__main__":
    if _maybe_run_inline_code(sys.argv):
        sys.exit(0)
    freeze_support()
    main()

View File

@@ -207,6 +207,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -262,6 +263,7 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -426,9 +428,8 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
return total_available

View File

@@ -158,6 +158,7 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -200,9 +201,7 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,10 +6,11 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,10 +20,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.topology import NodeInfo
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -51,19 +52,16 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
cycles = topology.get_cycles() + [[node] for node in all_nodes]
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
candidate_cycles, node_profiles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
@@ -71,13 +69,15 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -86,11 +86,7 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
(node.node_profile.memory.ram_available for node in cycle),
start=Memory(),
),
)
@@ -99,14 +95,16 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -114,20 +112,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
cycle_digraph,
)
mlx_ibv_coordinators = get_mlx_ibv_coordinators(
selected_cycle,
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
ibv_coordinators=mlx_ibv_coordinators,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
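
Condensing the reworked place_instance flow above into one sketch (illustrative only; the Thunderbolt and leaf-node tie-breaks shown in the diff happen after the last step here):

from collections.abc import Mapping

from exo.master.placement_utils import (
    NodeWithProfile,
    filter_cycles_by_memory,
    get_smallest_cycles,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile


def candidate_cycles_sketch(
    topology: Topology,
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
    command: PlaceInstance,
) -> list[list[NodeWithProfile]]:
    # Real cycles plus every node as a singleton "cycle".
    cycles = topology.get_cycles() + [[node] for node in topology.list_nodes()]
    # Honour the requested minimum instance size.
    candidates = [c for c in cycles if len(c) >= command.min_nodes]
    # Keep cycles whose combined available RAM fits the model; this step
    # also attaches profiles, so later ranking can use memory directly.
    fitting = filter_cycles_by_memory(
        candidates, node_profiles, command.model_meta.storage_size
    )
    if len(fitting) == 0:
        raise ValueError("No cycles found with sufficient memory")
    # Prefer the fewest nodes; callers then pick by TB, leaf, and RAM.
    return get_smallest_cycles(fitting)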

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import TypeGuard, cast
from collections.abc import Generator, Mapping
from loguru import logger
from pydantic import BaseModel
@@ -9,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -24,27 +23,32 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
for cycle in cycles:
if not narrow_all_nodes(cycle):
if not all(node in node_profiles for node in cycle):
continue
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
return filtered_cycles
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -135,11 +139,9 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
selected_cycle: list[NodeWithProfile],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -176,17 +178,16 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
@@ -194,8 +195,7 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_devices_matrix(
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,6 +204,7 @@ def get_mlx_ibv_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -214,86 +215,55 @@ def get_mlx_ibv_devices_matrix(
if i == j:
continue
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
break
else:
logger.warning(
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current ibv backend requires all-to-all rdma connections"
"Current jaccl backend requires all-to-all RDMA connections"
)
return matrix
def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
node_i: NodeId,
node_j: NodeId,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def get_mlx_ibv_coordinators(
selected_cycle: list[NodeInfo],
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX IBV (rank 0 device).
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
return "0.0.0.0"
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
return ip
logger.warning(
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
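
To make the two data shapes concrete: a hypothetical three-node, all-to-all RDMA cluster (interface names, addresses, and port invented) would yield something like:

from exo.shared.types.common import NodeId

node_a, node_b, node_c = NodeId(), NodeId(), NodeId()

# get_mlx_jaccl_devices_matrix: matrix[i][j] is the RDMA interface used
# on the i -> j edge, None on the diagonal (an all-to-all mesh is
# required, otherwise a ValueError is raised).
jaccl_devices: list[list[str | None]] = [
    [None, "rdma_en3", "rdma_en4"],
    ["rdma_en3", None, "rdma_en5"],
    ["rdma_en4", "rdma_en5", None],
]

# get_mlx_jaccl_coordinators: rank 0 (node_a here) binds on 0.0.0.0;
# every other node dials an IP from which it can reach rank 0, all
# sharing one ephemeral port.
jaccl_coordinators: dict[NodeId, str] = {
    node_a: "0.0.0.0:61234",
    node_b: "169.254.0.2:61234",
    node_c: "169.254.0.3:61234",
}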

View File

@@ -1,67 +1,36 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
return _create_connection
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
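
For reference, how the reworked helpers compose (a sketch; add_node/add_connection signatures as exercised throughout these diffs):

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId

topology = Topology()
node_a, node_b = NodeId(), NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
# Connections are now plain edge payloads keyed by (source, sink).
topology.add_connection(node_a, node_b, create_connection(1))
topology.add_connection(node_b, node_a, create_connection(2))
# RDMA edges live alongside socket edges between the same pair of nodes.
topology.add_connection(node_a, node_b, create_rdma_connection(3))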

View File

@@ -2,6 +2,7 @@ from datetime import datetime, timezone
from typing import Sequence
import anyio
import pytest
from loguru import logger
from exo.master.main import Master
@@ -18,15 +19,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
NodeGatheredInfo,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -39,6 +38,7 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
@@ -81,21 +81,14 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
NodeGatheredInfo(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
)
),
@@ -159,7 +152,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()

View File

@@ -1,5 +1,3 @@
from typing import Callable
import pytest
from loguru import logger
@@ -7,14 +5,20 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -26,11 +30,6 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -74,30 +73,33 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
# act
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1
@@ -123,12 +125,11 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -137,7 +138,7 @@ def test_get_instance_placements_one_node_exact_fit(
n_layers=10,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -148,12 +149,11 @@ def test_get_instance_placements_one_node_exact_fit(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -162,7 +162,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
n_layers=10,
),
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -173,12 +173,11 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -189,7 +188,7 @@ def test_get_instance_placements_one_node_not_fit(
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {})
place_instance(cic, topology, {}, profiles)
def test_get_transition_events_no_change(instance: Instance):
@@ -235,190 +234,102 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size = Memory.from_bytes(1000)
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_d))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
logger.info(list(topology.list_connections()))
cic = place_instance_command(
model_meta=model_meta,
)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
instance = list(placements.values())[0]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -428,7 +339,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {})
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -436,10 +347,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.ibv_devices is not None
assert instance.ibv_coordinators is not None
assert instance.jaccl_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
@@ -448,21 +359,21 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
# Verify coordinators are set for all nodes
assert len(instance.ibv_coordinators) == 3
assert len(instance.jaccl_coordinators) == 3
for node_id in assigned_nodes:
assert node_id in instance.ibv_coordinators
coordinator = instance.ibv_coordinators[node_id]
assert node_id in instance.jaccl_coordinators
coordinator = instance.jaccl_coordinators[node_id]
assert ":" in coordinator
# Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
if node_id == assigned_nodes[0]:

View File

@@ -1,56 +1,48 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_coordinators,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
# assert
assert len(filtered_cycles) == 1
@@ -58,64 +50,65 @@ def test_filter_cycles_by_memory(
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_cycles_by_insufficient_memory():
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology.add_node(node1)
topology.add_node(node2)
topology.add_node(node1_id)
topology.add_node(node2_id)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
connection1 = create_connection(1)
connection2 = create_connection(2)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_filter_multiple_cycles_by_memory():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
# assert
assert len(filtered_cycles) == 1
@@ -127,31 +120,38 @@ def test_filter_multiple_cycles_by_memory(
}
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(smallest_cycles) == 1
@@ -168,9 +168,6 @@ def test_get_smallest_cycles(
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -179,19 +176,25 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -199,7 +202,11 @@ def test_get_shard_assignments(
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
)
cycles = topology.get_cycles()
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
selected_cycle = cycles[0]
# act
@@ -228,28 +235,21 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_hosts_from_subgraph():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -257,108 +257,47 @@ def test_get_hosts_from_subgraph(
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_ibv_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
def test_get_mlx_jaccl_coordinators():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
# act
coordinators = get_mlx_ibv_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -387,11 +326,11 @@ def test_get_mlx_ibv_coordinators(
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)

View File

@@ -1,13 +1,14 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
from exo.shared.types.topology import SocketConnection
@pytest.fixture
@@ -16,20 +17,15 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -43,162 +39,85 @@ def node_profile() -> NodePerformanceProfile:
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
topology.add_node(node_id)
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
assert topology.node_is_leaf(node_id)
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_add_connection(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
data = topology.get_connection_profile(connection)
data = list(conn for _, _, conn in topology.list_connections())
# assert
assert data == connection.connection_profile
assert data == [connection]
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
topology: Topology, connection: SocketConnection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
# act
topology.remove_connection(connection)
topology.remove_connection(node_a, node_b, connection)
# assert
assert topology.get_connection_profile(connection) is None
assert list(topology.get_all_connections_between(node_a, node_b)) == []
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
topology.remove_node(connection.local_node_id)
topology.remove_node(node_b)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
assert list(topology.out_edges(node_a)) == []
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
def test_list_nodes(topology: Topology, connection: SocketConnection):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}

View File

@@ -15,6 +15,7 @@ def buffer() -> OrderedBuffer[Event]:
return OrderedBuffer[Event]()
@pytest.mark.asyncio
async def test_initial_state(buffer: OrderedBuffer[Event]):
"""Tests that a new buffer is empty and starts at index 1."""
assert buffer.next_idx_to_release == 0
@@ -22,6 +23,7 @@ async def test_initial_state(buffer: OrderedBuffer[Event]):
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
"""Tests ingesting and draining a simple, ordered sequence of events."""
events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)]
@@ -33,6 +35,7 @@ async def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):
assert not buffer.store
@pytest.mark.asyncio
async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
"""Tests that out-of-order events are buffered and drained in the correct sequence."""
event1 = make_indexed_event(0)
@@ -48,6 +51,7 @@ async def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
"""Tests that draining stops when there is a gap in the event indices."""
event1 = make_indexed_event(0)
@@ -64,6 +68,7 @@ async def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):
assert 2 in buffer.store
@pytest.mark.asyncio
async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
"""Tests that once a gap is filled, the rest of the sequence is drained."""
event0 = make_indexed_event(0)
@@ -82,6 +87,7 @@ async def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):
assert buffer.next_idx_to_release == 3
@pytest.mark.asyncio
async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
"""Tests that if multiple events for the same index are ingested, the first one wins."""
event2_first = make_indexed_event(1)
@@ -100,6 +106,7 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):
assert drained[1][1].event_id != event2_second[1].event_id
@pytest.mark.asyncio
async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
"""Tests that events with an index lower than next_idx_to_release are dropped."""
buffer.ingest(*make_indexed_event(0))
@@ -117,6 +124,7 @@ async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
"""Tests reusing the buffer after it has been fully drained."""
buffer.ingest(*make_indexed_event(0))
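
Taken together, the contract these tests pin down can be sketched as follows (assuming, as the ingest calls above suggest, that make_indexed_event returns an (idx, event) pair and that drain yields such pairs):

buf: OrderedBuffer[Event] = OrderedBuffer[Event]()
buf.ingest(*make_indexed_event(1))  # out of order: buffered, not released
assert buf.drain() == []            # the gap at index 0 blocks draining
buf.ingest(*make_indexed_event(0))  # gap filled
assert [idx for idx, _ in buf.drain()] == [0, 1]
assert buf.next_idx_to_release == 2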

View File

@@ -11,10 +11,8 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeGatheredInfo,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -27,13 +25,23 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacTBConnections,
MacTBIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -47,16 +55,12 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -188,7 +192,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
topology = copy.deepcopy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -196,8 +200,12 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -205,103 +213,69 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacTBIdentifiers():
profile.tb_interfaces = info.idents
case MacTBConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
return state.model_copy(
update={
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
"topology": topology,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})
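The reducer pattern above is uniform: match on the event type, copy what the handler touches, and return a new State via model_copy. A minimal driving sketch — the event_apply import path is assumed, and the payload values are invented:

from datetime import datetime, timezone

from exo.shared.apply import event_apply  # hypothetical module path
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeGatheredInfo
from exo.shared.types.profiling import MemoryUsage
from exo.shared.types.state import State

state = State()
event = NodeGatheredInfo(
    node_id=NodeId("node-1"),
    # `when` is the master-assigned timestamp, carried as an ISO string
    when=datetime.now(timezone.utc).isoformat(),
    info=MemoryUsage.from_bytes(
        ram_total=16 * 2**30,
        ram_available=8 * 2**30,
        swap_total=0,
        swap_available=0,
    ),
)
new_state = event_apply(event, state)  # returns a fresh State built via model_copy
assert NodeId("node-1") in new_state.last_seen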

View File

@@ -1,35 +1,47 @@
import os
import sys
from pathlib import Path
EXO_HOME_RELATIVE_PATH = os.environ.get("EXO_HOME", ".exo")
EXO_HOME = Path.home() / EXO_HOME_RELATIVE_PATH
_EXO_HOME_ENV = os.environ.get("EXO_HOME", None)
EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR")
EXO_MODELS_DIR = Path(EXO_MODELS_DIR_ENV) if EXO_MODELS_DIR_ENV else EXO_HOME / "models"
EXO_GLOBAL_EVENT_DB = EXO_HOME / "global_events.db"
EXO_WORKER_EVENT_DB = EXO_HOME / "worker_events.db"
EXO_MASTER_STATE = EXO_HOME / "master_state.json"
EXO_WORKER_STATE = EXO_HOME / "worker_state.json"
EXO_MASTER_LOG = EXO_HOME / "master.log"
EXO_WORKER_LOG = EXO_HOME / "worker.log"
EXO_LOG = EXO_HOME / "exo.log"
EXO_TEST_LOG = EXO_HOME / "exo_test.log"
def _get_xdg_dir(env_var: str, fallback: str) -> Path:
"""Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo."""
EXO_NODE_ID_KEYPAIR = EXO_HOME / "node_id.keypair"
if _EXO_HOME_ENV is not None:
return Path.home() / _EXO_HOME_ENV
EXO_WORKER_KEYRING_FILE = EXO_HOME / "worker_keyring"
EXO_MASTER_KEYRING_FILE = EXO_HOME / "master_keyring"
if sys.platform != "linux":
return Path.home() / ".exo"
EXO_IPC_DIR = EXO_HOME / "ipc"
xdg_value = os.environ.get(env_var, None)
if xdg_value is not None:
return Path(xdg_value) / "exo"
return Path.home() / fallback / "exo"
EXO_CONFIG_HOME = _get_xdg_dir("XDG_CONFIG_HOME", ".config")
EXO_DATA_HOME = _get_xdg_dir("XDG_DATA_HOME", ".local/share")
EXO_CACHE_HOME = _get_xdg_dir("XDG_CACHE_HOME", ".cache")
# Models directory (data)
_EXO_MODELS_DIR_ENV = os.environ.get("EXO_MODELS_DIR", None)
EXO_MODELS_DIR = (
EXO_DATA_HOME / "models"
if _EXO_MODELS_DIR_ENV is None
else Path.home() / _EXO_MODELS_DIR_ENV
)
# Log files (data/logs or cache)
EXO_LOG = EXO_CACHE_HOME / "exo.log"
EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
# lower bounds define timeouts for flops and memory bandwidth - these are the values for the M1 chip.
LB_TFLOPS = 2.3
LB_MEMBW_GBPS = 68
LB_DISK_GBPS = 1.5
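For readers skimming the hunk above, the resolution rules the new constants encode, condensed into one behaviourally-equivalent function (a sketch, not repo code):

import os
import sys
from pathlib import Path

def resolve(env_var: str, fallback: str) -> Path:
    # EXO_HOME wins everywhere, preserving backwards compatibility
    if (exo_home := os.environ.get("EXO_HOME")) is not None:
        return Path.home() / exo_home
    # non-Linux platforms keep the traditional ~/.exo tree
    if sys.platform != "linux":
        return Path.home() / ".exo"
    # Linux honours XDG, falling back to the spec defaults
    if (xdg := os.environ.get(env_var)) is not None:
        return Path(xdg) / "exo"
    return Path.home() / fallback / "exo"

print(resolve("XDG_CONFIG_HOME", ".config"))  # ~/.config/exo when nothing is set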

View File

@@ -24,6 +24,8 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -1,5 +1,7 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
@@ -48,7 +50,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await target_dir.mkdir(parents=True, exist_ok=True)
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
@@ -58,14 +60,14 @@ async def get_config_data(model_id: str) -> ConfigData:
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with await config_path.open("r") as f:
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await target_dir.mkdir(parents=True, exist_ok=True)
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
@@ -75,7 +77,7 @@ async def get_safetensors_size(model_id: str) -> Memory:
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with await index_path.open("r") as f:
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
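The mechanical pattern in this file: anyio.Path method calls become their aiofiles / aiofiles.os equivalents. A minimal sketch of the idiom (path names illustrative):

from pathlib import Path

import aiofiles
import aiofiles.os as aios

async def read_model_config(target_dir: Path) -> str:
    # was: await target_dir.mkdir(parents=True, exist_ok=True)
    await aios.makedirs(target_dir, exist_ok=True)
    # was: async with await config_path.open("r") as f:
    async with aiofiles.open(target_dir / "config.json", "r") as f:
        return await f.read()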

View File

@@ -1,5 +1,8 @@
"""Pytest configuration and shared fixtures for shared package tests."""
import asyncio
from typing import Generator
import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
@@ -9,6 +12,21 @@ from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""Create an event loop for the test session."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()
@pytest.fixture(autouse=True)
def reset_event_loop():
"""Reset the event loop for each test to ensure clean state."""
# This ensures each test gets a fresh event loop state
def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
@@ -34,7 +52,7 @@ def caplog(caplog: LogCaptureFixture):
format="{message}",
level=0,
filter=lambda record: record["level"].no >= caplog.handler.level,
enqueue=True,
enqueue=True, # Set to 'True' if your test is spawning child processes.
)
yield caplog
logger.remove(handler_id)

View File

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)
assert new_state == State(downloads={NodeId("node-1"): [event]})
assert new_state.downloads == {NodeId("node-1"): [event]}
def test_apply_two_node_download_progress():
@@ -39,7 +39,4 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -46,6 +46,7 @@ def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
@pytest.mark.anyio
async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:
"""
Start a round by injecting an ElectionMessage with higher clock.
@@ -101,6 +102,7 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> No
assert election.seniority == 2
@pytest.mark.anyio
async def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:
"""
If a peer with clearly higher seniority participates in the round, they should win.
@@ -154,6 +156,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:
assert election.seniority == 0
@pytest.mark.anyio
async def test_ignores_older_messages() -> None:
"""
Messages with a lower clock than the current round are ignored by the receiver.
@@ -202,6 +205,7 @@ async def test_ignores_older_messages() -> None:
# Not asserting on the result; focus is on ignore behavior.
@pytest.mark.anyio
async def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:
"""
Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.
@@ -247,6 +251,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:
# Not asserting on who won; just that both rounds were broadcast.
@pytest.mark.anyio
async def test_promotion_new_seniority_counts_participants() -> None:
"""
When we win against two peers in the same round, our seniority becomes
@@ -295,6 +300,7 @@ async def test_promotion_new_seniority_counts_participants() -> None:
assert election.seniority == 3
@pytest.mark.anyio
async def test_connection_message_triggers_new_round_broadcast() -> None:
"""
A connection message increments the clock and starts a new campaign.
@@ -346,6 +352,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
# After cancellation (before election finishes), no seniority changes asserted here.
@pytest.mark.anyio
async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
"""
With equal seniority, the node that has seen more commands should win the election.

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
def test_state_serialization_roundtrip() -> None:
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(connection)
state.topology.add_connection(node_a, node_b, connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)

View File

@@ -0,0 +1,118 @@
"""Tests for XDG Base Directory Specification compliance."""
import os
import sys
from pathlib import Path
from unittest import mock
def test_xdg_paths_on_linux():
"""Test that XDG paths are used on Linux when XDG env vars are set."""
with (
mock.patch.dict(
os.environ,
{
"XDG_CONFIG_HOME": "/tmp/test-config",
"XDG_DATA_HOME": "/tmp/test-data",
"XDG_CACHE_HOME": "/tmp/test-cache",
},
clear=False,
),
mock.patch.object(sys, "platform", "linux"),
):
# Re-import to pick up mocked values
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert Path("/tmp/test-config/exo") == constants.EXO_CONFIG_HOME
assert Path("/tmp/test-data/exo") == constants.EXO_DATA_HOME
assert Path("/tmp/test-cache/exo") == constants.EXO_CACHE_HOME
def test_xdg_default_paths_on_linux():
"""Test that XDG default paths are used on Linux when env vars are not set."""
# Remove XDG env vars and EXO_HOME
env = {
k: v
for k, v in os.environ.items()
if not k.startswith("XDG_") and k != "EXO_HOME"
}
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "linux"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".config" / "exo" == constants.EXO_CONFIG_HOME
assert home / ".local/share" / "exo" == constants.EXO_DATA_HOME
assert home / ".cache" / "exo" == constants.EXO_CACHE_HOME
def test_legacy_exo_home_takes_precedence():
"""Test that EXO_HOME environment variable takes precedence for backward compatibility."""
with mock.patch.dict(
os.environ,
{
"EXO_HOME": ".custom-exo",
"XDG_CONFIG_HOME": "/tmp/test-config",
},
clear=False,
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".custom-exo" == constants.EXO_CONFIG_HOME
assert home / ".custom-exo" == constants.EXO_DATA_HOME
def test_macos_uses_traditional_paths():
"""Test that macOS uses traditional ~/.exo directory."""
# Remove EXO_HOME to ensure we test the default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_HOME"}
with (
mock.patch.dict(os.environ, env, clear=True),
mock.patch.object(sys, "platform", "darwin"),
):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
home = Path.home()
assert home / ".exo" == constants.EXO_CONFIG_HOME
assert home / ".exo" == constants.EXO_DATA_HOME
assert home / ".exo" == constants.EXO_CACHE_HOME
def test_node_id_in_config_dir():
"""Test that node ID keypair is in the config directory."""
import exo.shared.constants as constants
assert constants.EXO_NODE_ID_KEYPAIR.parent == constants.EXO_CONFIG_HOME
def test_models_in_data_dir():
"""Test that models directory is in the data directory."""
# Clear EXO_MODELS_DIR to test default behavior
env = {k: v for k, v in os.environ.items() if k != "EXO_MODELS_DIR"}
with mock.patch.dict(os.environ, env, clear=True):
import importlib
import exo.shared.constants as constants
importlib.reload(constants)
assert constants.EXO_MODELS_DIR.parent == constants.EXO_DATA_HOME

View File

@@ -1,203 +1,219 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
class TopologySnapshot(BaseModel):
nodes: list[NodeInfo]
connections: list[Connection]
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
model_config = ConfigDict(frozen=True, extra="forbid")
@dataclass
class Topology:
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
# the _graph can be used as a int -> NodeId map.
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
nodes=list(self.list_nodes()), connections=self.map_connections()
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node in snapshot.nodes:
for node_id in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node)
topology.add_node(node_id)
for connection in snapshot.connections:
topology.add_connection(connection)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for conn in snapshot.connections[source][sink]:
topology.add_connection(source, sink, conn)
return topology
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
return
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
]
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
return []
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
return node_id in self._vertex_indices
def add_connection(
self,
connection: Connection,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
if connection in self.get_all_connections_between(source, sink):
return
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
try:
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._node_id_to_rx_id_map:
if node_id not in self._vertex_indices:
return
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
rx_idx = self._vertex_indices[node_id]
self._graph.remove_node(rx_idx)
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
del self._vertex_indices[node_id]
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
return
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[list[NodeInfo]]:
def get_cycles(self) -> list[list[NodeId]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeInfo]]:
def get_cycles_tb(self) -> list[list[NodeId]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
tb_graph.add_edge(u, v, conn)
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeInfo]] = []
cycles: list[list[NodeId]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

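Pulling the new API together, a usage sketch of the reworked NodeId-keyed multigraph, using only operations visible in this diff (the multiaddr value is illustrative):

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import SocketConnection

topo = Topology()
a, b = NodeId("node-a"), NodeId("node-b")
conn = SocketConnection(sink_multiaddr=Multiaddr(address="/ip4/169.254.0.2/tcp/10001"))
topo.add_connection(a, b, conn)  # endpoints are added implicitly
topo.add_connection(a, b, conn)  # exact duplicates are dropped
assert list(topo.get_all_connections_between(a, b)) == [conn]

# snapshots nest connections as source -> sink -> [edges] and round-trip
restored = Topology.from_snapshot(topo.to_snapshot())
assert set(restored.list_nodes()) == {a, b}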
View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import SocketConnection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,25 +76,15 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodePerformanceMeasured(BaseEvent):
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overridden by the master when the event is indexed, rather than the local time on the device
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
info: GatheredInfo # NB: this model is UNTAGGED!!! Beware of ser/de errors.
class NodeDownloadProgress(BaseEvent):
@@ -107,11 +97,15 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
class TopologyEdgeDeleted(BaseEvent):
edge: Connection
source: NodeId
sink: NodeId
edge: SocketConnection
Event = (
@@ -125,10 +119,8 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeGatheredInfo
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -1,10 +1,11 @@
import re
from typing import ClassVar
from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,12 +1,14 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryPerformanceProfile(CamelCaseModel):
class MemoryUsage(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -44,7 +46,6 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -53,15 +54,16 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float
pass

View File

@@ -0,0 +1,64 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
# Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None
device_name_key: str
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if self.domain_uuid_key is None:
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
)
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel):
SPThunderboltDataType: list[TBConnectivityData]
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType

View File

@@ -1,37 +1,32 @@
from exo.shared.types.common import NodeId
from enum import Enum
from loguru import logger
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
def is_thunderbolt(self) -> bool:
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
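The classification is just an address-prefix test: 169.254.0.0/16 is the self-assigned (link-local) range that Thunderbolt bridge interfaces typically fall into. A one-line check, as a sketch:

from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import SocketConnection

tb = SocketConnection(sink_multiaddr=Multiaddr(address="/ip4/169.254.1.2/tcp/52415"))
assert tb.is_thunderbolt()  # anything in 169.254.0.0/16 counts as a TB link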

View File

@@ -29,8 +29,8 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
ibv_devices: list[list[str | None]]
ibv_coordinators: dict[NodeId, str]
jaccl_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]
# TODO: Single node instance

View File

@@ -36,9 +36,6 @@ class Sender[T](AnyioSender[T]):
raise ClosedResourceError
return Receiver(_state=self._state)
def __enter__(self) -> Self:
return self
class Receiver[T](AnyioReceiver[T]):
def clone(self) -> "Receiver[T]":

View File

@@ -0,0 +1,231 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]
class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)
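A minimal harness sketch for the gatherer. Assumptions flagged: exo.utils.channels is presumed to expose a channel[T]() constructor returning a (Sender, Receiver) pair (mirroring the mp_channel[str]() usage in the channel test later in this diff), and Receiver is presumed to support receive():

import anyio

from exo.utils.channels import channel  # assumed constructor, see above
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer

async def main() -> None:
    sender, receiver = channel[GatheredInfo]()
    gatherer = InfoGatherer(info_sender=sender)
    async with anyio.create_task_group() as tg:
        tg.start_soon(gatherer.run)
        first = await receiver.receive()  # whichever monitor reports first
        print(type(first).__name__)
        gatherer.shutdown()  # cancels the gatherer's internal task group

anyio.run(main)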

View File

@@ -0,0 +1,70 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
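A parse sketch for the new models; the payload shape follows RawMacmonMetrics, every value is invented, and the absolute import path is assumed from the relative .macmon import earlier in this diff:

import json

from exo.utils.info_gatherer.macmon import MacmonMetrics  # path assumed

payload = {
    "timestamp": "2025-12-27T12:00:00Z",  # parsed but ignored
    "temp": {"cpu_temp_avg": 45.2, "gpu_temp_avg": 41.7},
    "memory": {"ram_total": 17179869184, "ram_usage": 8589934592,
               "swap_total": 0, "swap_usage": 0},
    "ecpu_usage": [1020, 0.12],  # freq MHz, usage %
    "pcpu_usage": [3204, 0.48],
    "gpu_usage": [1398, 0.33],
    "all_power": 12.5, "ane_power": 0.0, "cpu_power": 6.0, "gpu_power": 4.0,
    "gpu_ram_power": 0.5, "ram_power": 1.0, "sys_power": 12.5,
    "some_future_field": "ignored",  # extra="ignore" drops unknown keys
}
metrics = MacmonMetrics.from_raw_json(json.dumps(payload))
assert metrics.system_profile.gpu_usage == 0.33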

View File

@@ -0,0 +1,56 @@
import socket
from collections.abc import Mapping
from ipaddress import ip_address
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(
our_node_id: NodeId,
topology: Topology,
profiles: Mapping[NodeId, NodePerformanceProfile],
) -> Mapping[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
our_profile = profiles.get(our_node_id, None)
if our_profile is None:
return {}
our_interfaces = our_profile.network_interfaces
async with create_task_group() as tg:
for node_id in topology.list_nodes():
if node_id not in profiles or node_id == our_node_id:
continue
for iface in profiles[node_id].network_interfaces:
if ip_address(iface.ip_address).is_loopback:
# Definitely a loopback address
continue
if iface in our_interfaces:
# Skip duplicates with our own interfaces
continue
tg.start_soon(check_reachability, iface.ip_address, node_id, reachable)
return reachable
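Driving the probe directly, as a sketch; the module path is assumed, and 192.0.2.10 is a documentation-range address standing in for a peer:

import anyio

from exo.master.reachability import check_reachability  # hypothetical module path
from exo.shared.types.common import NodeId

async def main() -> None:
    out: dict[NodeId, set[str]] = {}
    # connect_ex returns 0 on a completed TCP handshake, an errno otherwise
    await check_reachability("192.0.2.10", NodeId("node-b"), out)
    print(out)  # {} if unreachable, else {NodeId('node-b'): {'192.0.2.10'}}

anyio.run(main)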

View File

@@ -19,11 +19,20 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):
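FrozenModel is the camelCase config plus frozen=True, which makes instances immutable and hashable — the value-semantics the new connection types lean on. A tiny example:

from exo.utils.pydantic_ext import FrozenModel

class Endpoint(FrozenModel):
    host_name: str  # serialises as "hostName" via the camelCase alias

a = Endpoint(host_name="mac-studio-1")
b = Endpoint(host_name="mac-studio-1")
assert a == b and len({a, b}) == 1  # value-equal and hashable
# a.host_name = "other" would raise: frozen models reject mutation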

View File

@@ -0,0 +1,77 @@
"""Tests for macmon error handling.
These tests verify that MacMon errors are handled gracefully without
crashing the application or spamming logs.
"""
import platform
from subprocess import CalledProcessError
from unittest.mock import AsyncMock, patch
import pytest
from exo.worker.utils.macmon import MacMonError, get_metrics_async
@pytest.mark.skipif(
platform.system().lower() != "darwin" or "arm" not in platform.machine().lower(),
reason="MacMon only supports macOS with Apple Silicon",
)
class TestMacMonErrorHandling:
"""Test MacMon error handling."""
async def test_called_process_error_wrapped_as_macmon_error(self) -> None:
"""CalledProcessError should be wrapped as MacMonError."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=b"some error message",
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "some error message" in str(exc_info.value)
async def test_called_process_error_with_no_stderr(self) -> None:
"""CalledProcessError with no stderr should be handled gracefully."""
mock_error = CalledProcessError(
returncode=1,
cmd=["macmon", "pipe", "-s", "1"],
stderr=None,
)
with (
patch(
"exo.worker.utils.macmon.shutil.which", return_value="/usr/bin/macmon"
),
patch(
"exo.worker.utils.macmon.run_process", new_callable=AsyncMock
) as mock_run,
):
mock_run.side_effect = mock_error
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon failed with return code 1" in str(exc_info.value)
assert "no stderr" in str(exc_info.value)
async def test_macmon_not_found_raises_macmon_error(self) -> None:
"""When macmon is not found in PATH, MacMonError should be raised."""
with patch("exo.worker.utils.macmon.shutil.which", return_value=None):
with pytest.raises(MacMonError) as exc_info:
await get_metrics_async()
assert "MacMon not found in PATH" in str(exc_info.value)

View File

@@ -1,6 +1,7 @@
import multiprocessing as mp
import time
import pytest
from anyio import fail_after
from loguru import logger
@@ -27,8 +28,8 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
async def test_channel_setup():
@pytest.mark.anyio
async def test_channel_ipc():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -235,6 +235,7 @@ class L1C(TaggedModel):
L1 = L1A | L1B | L1C
@pytest.mark.anyio
async def test_tagged_union_is_fast():
# payload along the "C" path (worst case for DFS if branches are tried A->B->C)
payload = {"L1C": {"child": {"L2C": {"child": {"L3C": {"x": 123}}}}}}

View File

@@ -1,15 +1,19 @@
import asyncio
import hashlib
import os
import shutil
import ssl
import time
import traceback
from datetime import timedelta
from typing import Literal, cast
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
import anyio
import httpx
from anyio import Path, to_thread
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
@@ -20,11 +24,10 @@ from pydantic import (
TypeAdapter,
)
from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
@@ -129,37 +132,16 @@ async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def ensure_exo_home() -> Path:
home = Path(EXO_HOME)
await home.mkdir(parents=True, exist_ok=True)
return home
async def has_exo_home_read_access() -> bool:
try:
return await to_thread.run_sync(os.access, EXO_HOME, os.R_OK)
except OSError:
return False
async def has_exo_home_write_access() -> bool:
try:
return await to_thread.run_sync(os.access, EXO_HOME, os.W_OK)
except OSError:
return False
async def ensure_models_dir() -> Path:
models_dir = Path(EXO_MODELS_DIR)
await models_dir.mkdir(parents=True, exist_ok=True)
return models_dir
await aios.makedirs(EXO_MODELS_DIR, exist_ok=True)
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
if not await model_dir.exists():
if not await aios.path.exists(model_dir):
return False
await to_thread.run_sync(shutil.rmtree, model_dir)
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
@@ -167,14 +149,14 @@ async def seed_models(seed_dir: str | Path):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = await ensure_models_dir()
async for path in source_dir.iterdir():
if await path.is_dir() and path.name.startswith("models--"):
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await dest_path.exists():
if await aios.path.exists(dest_path):
logger.info("Skipping moving model to .cache directory")
else:
try:
await path.rename(dest_path)
await aios.rename(str(path), str(dest_path))
except Exception:
logger.error(f"Error seeding model {path} to {dest_path}")
logger.error(traceback.format_exc())
@@ -186,16 +168,16 @@ async def fetch_file_list_with_cache(
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
await target_dir.mkdir(parents=True, exist_ok=True)
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
if await cache_file.exists():
async with await cache_file.open("r") as f:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
await cache_file.parent.mkdir(parents=True, exist_ok=True)
async with await cache_file.open("w") as f:
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
@@ -210,7 +192,7 @@ async def fetch_file_list_with_retry(
except Exception as e:
if attempt == n_attempts - 1:
raise e
await anyio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
)
@@ -224,55 +206,66 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_client(timeout_profile="short") as client,
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
response = await client.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Failed to fetch file list: {response.status_code}")
data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
if item.type == "file":
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def get_download_headers() -> dict[str, str]:
return {**(await get_auth_headers()), "Accept-Encoding": "identity"}
def create_http_client(
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> httpx.AsyncClient:
) -> aiohttp.ClientSession:
if timeout_profile == "short":
timeout = httpx.Timeout(
connect=10,
read=30,
write=30,
pool=30,
)
total_timeout = 30
connect_timeout = 10
sock_read_timeout = 30
sock_connect_timeout = 10
else:
timeout = httpx.Timeout(
connect=60,
read=1800,
write=1800,
pool=1800,
)
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
return httpx.AsyncClient(timeout=timeout)
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -> str:
hasher = hashlib.sha1() if hash_type == "sha1" else hashlib.sha256()
if hash_type == "sha1":
header = f"blob {(await path.stat()).st_size}\0".encode()
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
hasher.update(header)
async with await path.open("rb") as f:
async with aiofiles.open(path, "rb") as f:
while chunk := await f.read(8 * 1024 * 1024):
hasher.update(chunk)
return hasher.hexdigest()
@@ -288,28 +281,24 @@ async def file_meta(
)
headers = await get_download_headers()
async with (
create_http_client(timeout_profile="short") as client,
create_http_session(timeout_profile="short") as session,
session.head(url, headers=headers) as r,
):
r = await client.head(url, headers=headers)
if r.status_code == 307:
if r.status == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
if "x-linked-size" in r.headers and "x-linked-etag" in r.headers:
content_length = int(r.headers["x-linked-size"])
etag = trim_etag(r.headers["x-linked-etag"])
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers["location"]
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
# this can totally fail in weird ways if the HF endpoint behaves weirdly
content_length = int(
r.headers.get("x-linked-size", None)
or r.headers.get("content-length", None)
or 0
)
etag = cast(
str | None,
r.headers.get("x-linked-etag", None) or r.headers.get("etag", None),
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -321,13 +310,13 @@ async def download_file_with_retry(
revision: str,
path: str,
target_dir: Path,
progress_sender: Sender[tuple[int, int, bool]] | None,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, progress_sender
repo_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
@@ -336,7 +325,7 @@ async def download_file_with_retry(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
             logger.error(traceback.format_exc())
-            await anyio.sleep(min(8, 0.1 * (2.0**attempt)))
+            await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
     raise Exception(
         f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
     )
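
The sleep implements capped exponential backoff: 0.1 s doubling per attempt and saturating at 8 s, so most of the 30 attempts wait at the cap, while FileNotFoundError short-circuits the loop entirely. Illustration, not repo code:

delays = [min(8, 0.1 * (2.0**attempt)) for attempt in range(10)]
# -> [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8, 8, 8]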
@@ -347,16 +336,18 @@ async def _download_file(
     revision: str,
     path: str,
     target_dir: Path,
-    progress_sender: Sender[tuple[int, int, bool]] | None,
+    on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
-    if await (target_dir / path).exists():
+    if await aios.path.exists(target_dir / path):
         return target_dir / path
-    await ((target_dir / path).parent).mkdir(parents=True, exist_ok=True)
+    await aios.makedirs((target_dir / path).parent, exist_ok=True)
     length, etag = await file_meta(repo_id, revision, path)
     remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
     partial_path = target_dir / f"{path}.partial"
     resume_byte_pos = (
-        (await partial_path.stat()).st_size if (await partial_path.exists()) else None
+        (await aios.stat(partial_path)).st_size
+        if (await aios.path.exists(partial_path))
+        else None
     )
     if resume_byte_pos != length:
         url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
@@ -365,19 +356,20 @@ async def _download_file(
         headers["Range"] = f"bytes={resume_byte_pos}-"
         n_read = resume_byte_pos or 0
         async with (
-            create_http_client(timeout_profile="long") as client,
-            client.stream("GET", url, headers=headers) as r,
+            create_http_session(timeout_profile="long") as session,
+            session.get(url, headers=headers) as r,
         ):
-            if r.status_code == 404:
+            if r.status == 404:
                 raise FileNotFoundError(f"File not found: {url}")
-            assert r.status_code in [200, 206], (
-                f"Failed to download {path} from {url}: {r.status_code}"
+            assert r.status in [200, 206], (
+                f"Failed to download {path} from {url}: {r.status}"
             )
-            async with await partial_path.open("ab" if resume_byte_pos else "wb") as f:
-                async for chunk in r.aiter_bytes(8 * 1024 * 1024):
+            async with aiofiles.open(
+                partial_path, "ab" if resume_byte_pos else "wb"
+            ) as f:
+                while chunk := await r.content.read(8 * 1024 * 1024):
                     n_read = n_read + (await f.write(chunk))
-                    if progress_sender is not None:
-                        await progress_sender.send((n_read, length, False))
+                    on_progress(n_read, length, False)
     final_hash = await calc_hash(
         partial_path, hash_type="sha256" if len(remote_hash) == 64 else "sha1"
@@ -385,15 +377,14 @@ async def _download_file(
     integrity = final_hash == remote_hash
     if not integrity:
         try:
-            await partial_path.unlink()
+            await aios.remove(partial_path)
         except Exception as e:
             logger.error(f"Error removing partial file {partial_path}: {e}")
         raise Exception(
             f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
         )
-    await partial_path.rename(target_dir / path)
-    if progress_sender is not None:
-        await progress_sender.send((length, length, True))
+    await aios.rename(partial_path, target_dir / path)
+    on_progress(length, length, True)
     return target_dir / path
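
Resumption rides on plain HTTP range requests: a Range: bytes=<partial-size>- header plus append mode, with status 206 signalling that the server honoured the range. A stripped-down sketch under those assumptions:

import aiofiles
import aiohttp
from pathlib import Path

async def resume_download(
    session: aiohttp.ClientSession, url: str, partial: Path, resume_from: int
) -> None:
    # Ask only for the bytes we do not have yet, then append them.
    async with session.get(url, headers={"Range": f"bytes={resume_from}-"}) as r:
        assert r.status == 206, "server ignored the Range header"
        async with aiofiles.open(partial, "ab") as f:
            while chunk := await r.content.read(8 * 1024 * 1024):
                await f.write(chunk)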
@@ -449,11 +440,11 @@ def calculate_repo_progress(
 async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
     target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
-    await (target_dir).mkdir(parents=True, exist_ok=True)
+    await aios.makedirs(target_dir, exist_ok=True)
     index_file = await download_file_with_retry(
-        repo_id, revision, "model.safetensors.index.json", target_dir, None
+        repo_id, revision, "model.safetensors.index.json", target_dir
     )
-    async with await index_file.open("r") as f:
+    async with aiofiles.open(index_file, "r") as f:
         index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
     return index_data.weight_map
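
For reference, model.safetensors.index.json is the standard Hugging Face sharded-checkpoint index: weight_map ties every tensor name to the shard file that stores it. The entries below are illustrative, not from a real model:

index = {
    "metadata": {"total_size": 140517474304},  # made-up number
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00030.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00030.safetensors",
    },
}
# The set of files a node must fetch is just the distinct values:
shard_files = sorted(set(index["weight_map"].values()))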
@@ -470,10 +461,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
 async def get_downloaded_size(path: Path) -> int:
     partial_path = path.with_suffix(path.suffix + ".partial")
-    if await path.exists():
-        return (await path.stat()).st_size
-    if await partial_path.exists():
-        return (await partial_path.stat()).st_size
+    if await aios.path.exists(path):
+        return (await aios.stat(path)).st_size
+    if await aios.path.exists(partial_path):
+        return (await aios.stat(partial_path)).st_size
     return 0
@@ -484,12 +475,12 @@ async def download_progress_for_local_path(
     total_files = 0
     total_bytes = 0
-    if await local_path.is_dir():
+    if await aios.path.isdir(local_path):
         for root, _, files in os.walk(local_path):
             for f in files:
                 if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
                     file_path = Path(root) / f
-                    size = (await (file_path).stat()).st_size
+                    size = (await aios.stat(file_path)).st_size
                     rel_path = str(file_path.relative_to(local_path))
                     file_progress[rel_path] = RepoFileDownloadProgress(
                         repo_id=repo_id,
@@ -524,10 +515,9 @@ async def download_progress_for_local_path(
     )
 # this function still has disgusting amounts of currying, but it's better
 async def download_shard(
     shard: ShardMetadata,
-    progress_sender: Sender[RepoDownloadProgress],
+    on_progress: Callable[[ShardMetadata, RepoDownloadProgress], None],
     max_parallel_downloads: int = 8,
     skip_download: bool = False,
     allow_patterns: list[str] | None = None,
@@ -536,7 +526,7 @@ async def download_shard(
     logger.info(f"Downloading {shard.model_meta.model_id=}")
     # Handle local paths
-    if await Path(str(shard.model_meta.model_id)).exists():
+    if await aios.path.exists(str(shard.model_meta.model_id)):
         logger.info(f"Using local model path {shard.model_meta.model_id}")
         local_path = Path(str(shard.model_meta.model_id))
         return local_path, await download_progress_for_local_path(
@@ -548,7 +538,7 @@ async def download_shard(
"/", "--"
     )
     if not skip_download:
-        await target_dir.mkdir(parents=True, exist_ok=True)
+        await aios.makedirs(target_dir, exist_ok=True)
     if not allow_patterns:
         allow_patterns = await resolve_allow_patterns(shard)
@@ -568,14 +558,9 @@ async def download_shard(
     )
     file_progress: dict[str, RepoFileDownloadProgress] = {}
-    async def huh(file: FileListEntry, recv: Receiver[tuple[int, int, bool]]):
-        async with recv:
-            async for curr, total, done in recv:
-                await progress_sender.send(on_progress_wrapper(file, curr, total, done))
     def on_progress_wrapper(
         file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
-    ) -> RepoDownloadProgress:
+    ):
         start_time = (
             file_progress[file.path].start_time
             if file.path in file_progress
@@ -611,12 +596,15 @@ async def download_shard(
             else "in_progress",
             start_time=start_time,
         )
-        return calculate_repo_progress(
+        on_progress(
             shard,
-            str(shard.model_meta.model_id),
-            revision,
-            file_progress,
-            all_start_time,
+            calculate_repo_progress(
+                shard,
+                str(shard.model_meta.model_id),
+                revision,
+                file_progress,
+                all_start_time,
+            ),
         )
     for file in filtered_file_list:
@@ -634,31 +622,28 @@ async def download_shard(
             start_time=time.time(),
         )
-    semaphore = anyio.Semaphore(max_parallel_downloads)
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
-    async def download_with_semaphore(
-        file: FileListEntry, sender: Sender[tuple[int, int, bool]]
-    ):
+    async def download_with_semaphore(file: FileListEntry):
         async with semaphore:
             await download_file_with_retry(
                 str(shard.model_meta.model_id),
                 revision,
                 file.path,
                 target_dir,
-                sender,
+                lambda curr_bytes, total_bytes, is_renamed: on_progress_wrapper(
+                    file, curr_bytes, total_bytes, is_renamed
+                ),
             )
     if not skip_download:
-        async with anyio.create_task_group() as tg:
-            for file in filtered_file_list:
-                send, recv = channel[tuple[int, int, bool]](1)
-                tg.start_soon(download_with_semaphore, file, send)
-                tg.start_soon(huh, file, recv)
+        await asyncio.gather(
+            *[download_with_semaphore(file) for file in filtered_file_list]
+        )
     final_repo_progress = calculate_repo_progress(
         shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
     )
-    await progress_sender.send(final_repo_progress)
+    on_progress(shard, final_repo_progress)
     if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
         return target_dir / gguf.path, final_repo_progress
     else:
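
The replacement for the channel/task-group plumbing is a plain semaphore-bounded gather. A generic, self-contained sketch of that pattern (the sleep stands in for download_file_with_retry):

import asyncio

async def fetch_all(paths: list[str], max_parallel: int = 8) -> None:
    semaphore = asyncio.Semaphore(max_parallel)

    async def fetch(path: str) -> None:
        async with semaphore:  # at most max_parallel bodies run at once
            await asyncio.sleep(0.01)  # stand-in for the real download

    await asyncio.gather(*[fetch(p) for p in paths])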


@@ -3,7 +3,8 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Callable, Generator, Iterable
-import anyio
+import aiofiles
+import aiofiles.os as aios
 from loguru import logger
 from exo.shared.types.worker.shards import ShardMetadata
@@ -68,9 +69,9 @@ def get_hf_home() -> Path:
 async def get_hf_token() -> str | None:
     """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-    token_path = anyio.Path(get_hf_home() / "token")
-    if await token_path.exists():
-        async with await anyio.open_file(token_path, "r") as f:
+    token_path = get_hf_home() / "token"
+    if await aios.path.exists(token_path):
+        async with aiofiles.open(token_path, "r") as f:
             return (await f.read()).strip()
     return None
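
A typical consumer of this token, sketched under the common Hugging Face convention of a Bearer header; the repo's own get_download_headers may differ:

async def get_auth_headers() -> dict[str, str]:
    token = await get_hf_token()  # defined above
    return {"Authorization": f"Bearer {token}"} if token else {}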


@@ -1,8 +1,7 @@
 import asyncio
+from pathlib import Path
 from typing import AsyncIterator, Callable
-from anyio import Path
 from exo.shared.models.model_cards import MODEL_CARDS
 from exo.shared.models.model_meta import get_model_meta
 from exo.shared.types.worker.shards import (


@@ -1,11 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from datetime import timedelta
-from typing import AsyncIterator, Callable, Self
-import anyio
-from anyio import Path, create_task_group
-from anyio.abc import CancelScope, TaskGroup
+from pathlib import Path
+from typing import AsyncIterator, Callable
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
@@ -13,54 +9,7 @@ from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
-from exo.utils.channels import Receiver, Sender, channel
-from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
-@dataclass
-class ShardDownloader2:
-    progress_sender: Sender[RepoDownloadProgress]
-    max_parallel_downloads = 8
-    # The last item on the shard stack is currently being downloaded
-    shard_stack: list[tuple[ShardMetadata, bool]] = field(
-        init=False, default_factory=list
-    )
-    _top_scope: CancelScope | None = field(init=False, default=None)
-    _tg: TaskGroup = field(init=False, default_factory=create_task_group)
-    def start_shard(self, shard: ShardMetadata, config_only: bool = False):
-        self.shard_stack.append((shard, config_only))
-        # Cancel current tasks
-        if self._top_scope:
-            self._top_scope.cancel()
-        # Create a new scope
-        self._top_scope = CancelScope()
-    async def run(self):
-        async with self._tg:
-            await anyio.sleep_forever()
-    def shutdown(self):
-        self.progress_sender.close()
-        self._tg.cancel_scope.cancel()
-    async def _new_download(self, scope: CancelScope):
-        (shard, config_only) = self.shard_stack[-1]
-        with self.progress_sender.clone() as send, scope:
-            allow_patterns = ["config.json"] if config_only else None
-            target_dir, _ = await download_shard(
-                shard,
-                send,
-                max_parallel_downloads=self.max_parallel_downloads,
-                allow_patterns=allow_patterns,
-            )
-            return target_dir
-    @classmethod
-    def default(cls) -> tuple[Self, Receiver[RepoDownloadProgress]]:
-        send, recv = channel[RepoDownloadProgress](10)
-        return cls(send), recv
+from exo.worker.download.download_utils import RepoDownloadProgress
 # TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?


@@ -101,13 +101,7 @@ def mlx_distributed_init(
     bound_instance: BoundInstance,
 ) -> mx.distributed.Group:
     """
-    Initialize the MLX distributed (runs in thread pool).
-    Either hosts or mlx_ibv_devices must be provided:
-    - hosts: traditional host-based connectivity using MLX_HOSTFILE
-    - mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
-    - mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
-    - strict: if True, raise an error if the distributed backend is not available
+    Initialize the MLX distributed group.
     """
     rank = bound_instance.bound_shard.device_rank
     logger.info(f"Starting initialization for rank {rank}")
@@ -129,22 +123,22 @@ def mlx_distributed_init(
             group = mx.distributed.init(backend="ring", strict=True)
         case MlxJacclInstance(
-            ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
+            jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
         ):
             # Use RDMA connectivity matrix
             devices_file = f"./hosts_{rank}.json"
-            ibv_devices_json = json.dumps(ibv_devices)
+            jaccl_devices_json = json.dumps(jaccl_devices)
             with open(devices_file, "w") as f:
-                _ = f.write(ibv_devices_json)
+                _ = f.write(jaccl_devices_json)
-            ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
+            jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
-            logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
-            logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
-            os.environ["MLX_IBV_DEVICES"] = devices_file
+            logger.info(f"rank {rank} MLX_JACCL_DEVICES: {jaccl_devices_json}")
+            logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
+            os.environ["MLX_JACCL_DEVICES"] = devices_file
             os.environ["MLX_RANK"] = str(rank)
-            os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator
+            os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
             group = mx.distributed.init(backend="jaccl", strict=True)
     logger.info(f"Rank {rank} mlx distributed initialization complete")
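
Condensed, the jaccl branch writes the per-rank connectivity matrix to disk and exports the MLX_JACCL_* variables before initializing the backend. A sketch with made-up values; the IP:PORT coordinator format follows the removed docstring:

import json
import os

import mlx.core as mx

rank = 0
devices_file = f"./hosts_{rank}.json"
with open(devices_file, "w") as f:
    # Illustrative connectivity matrix; the real one comes from MlxJacclInstance.
    json.dump([["rdma0", ""], ["", "rdma1"]], f)

os.environ["MLX_JACCL_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = "10.0.0.1:5000"  # assumed IP:PORT
group = mx.distributed.init(backend="jaccl", strict=True)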
@@ -252,15 +246,22 @@ def shard_and_load(
 def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
     # TODO: Let's move away from this custom logic to mlx_lm.load()
+    if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
+        eos_token_ids = [163586]
+    elif "glm" in shard_metadata.model_meta.model_id.lower():
+        eos_token_ids = [151336, 151329, 151338]
+    else:
+        eos_token_ids = None
     tokenizer = cast(
         TokenizerWrapper,
         load_tokenizer(
             model_path,
             tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
-            # TODO: HACK for Kimi K2 wrong eos token id
-            eos_token_ids=[163586]
-            if "kimi-k2" in shard_metadata.model_meta.model_id.lower()
-            else None,
+            eos_token_ids=eos_token_ids,
         ),
     )
     assert isinstance(tokenizer, TokenizerWrapper)


@@ -16,15 +16,13 @@ from exo.shared.types.events import (
     ForwarderEvent,
     IndexedEvent,
     NodeDownloadProgress,
-    NodeMemoryMeasured,
-    NodePerformanceMeasured,
+    NodeGatheredInfo,
     TaskCreated,
     TaskStatusUpdated,
     TopologyEdgeCreated,
     TopologyEdgeDeleted,
 )
 from exo.shared.types.multiaddr import Multiaddr
-from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
 from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -33,7 +31,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
-from exo.shared.types.topology import Connection
+from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -44,14 +42,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
+from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
+from exo.utils.info_gatherer.net_profile import check_reachable
 from exo.worker.download.download_utils import (
     map_repo_download_progress_to_download_progress_data,
 )
 from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
 from exo.worker.plan import plan
 from exo.worker.runner.runner_supervisor import RunnerSupervisor
-from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
-from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -85,7 +83,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
-        self._tg: TaskGroup | None = None
+        self._tg: TaskGroup = create_task_group()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -97,37 +95,13 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
-        # TODO: CLEANUP HEADER
-        async def resource_monitor_callback(
-            node_performance_profile: NodePerformanceProfile,
-        ) -> None:
-            await self.event_sender.send(
-                NodePerformanceMeasured(
-                    node_id=self.node_id,
-                    node_profile=node_performance_profile,
-                    when=str(datetime.now(tz=timezone.utc)),
-                ),
-            )
+        info_send, info_recv = channel[GatheredInfo]()
+        info_gatherer: InfoGatherer = InfoGatherer(info_send)
-        async def memory_monitor_callback(
-            memory_profile: MemoryPerformanceProfile,
-        ) -> None:
-            await self.event_sender.send(
-                NodeMemoryMeasured(
-                    node_id=self.node_id,
-                    memory=memory_profile,
-                    when=str(datetime.now(tz=timezone.utc)),
-                )
-            )
-        # END CLEANUP
-        async with create_task_group() as tg:
-            self._tg = tg
+        async with self._tg as tg:
+            tg.start_soon(info_gatherer.run)
+            tg.start_soon(self._forward_info, info_recv)
             tg.start_soon(self.plan_step)
-            tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
-            tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
             tg.start_soon(self._connection_message_event_writer)
             tg.start_soon(self._resend_out_for_delivery)
             tg.start_soon(self._event_applier)
@@ -140,6 +114,17 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -159,7 +144,6 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
-                assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -248,8 +232,7 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
-        if self._tg:
-            self._tg.cancel_scope.cancel()
+        self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -266,24 +249,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
-                    edge=Connection(
-                        local_node_id=self.node_id,
-                        send_back_node_id=msg.node_id,
-                        send_back_multiaddr=Multiaddr(
+                    source=self.node_id,
+                    sink=msg.node_id,
+                    edge=SocketConnection(
+                        sink_multiaddr=Multiaddr(
                             address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                         ),
-                    )
+                    ),
                 )
case ConnectionMessageType.Disconnected:
                 return TopologyEdgeDeleted(
-                    edge=Connection(
-                        local_node_id=self.node_id,
-                        send_back_node_id=msg.node_id,
-                        send_back_multiaddr=Multiaddr(
+                    source=self.node_id,
+                    sink=msg.node_id,
+                    edge=SocketConnection(
+                        sink_multiaddr=Multiaddr(
                             address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                         ),
-                    )
+                    ),
                 )
async def _nack_request(self, since_idx: int) -> None:
@@ -332,7 +315,6 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
-        assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -357,11 +339,10 @@ class Worker:
# TODO: i hate callbacks
         def download_progress_callback(
-            _: ShardMetadata, progress: RepoDownloadProgress
+            shard: ShardMetadata, progress: RepoDownloadProgress
         ) -> None:
             nonlocal self
             nonlocal last_progress_time
-            shard = progress.shard
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
@@ -392,7 +373,6 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
-        assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -415,28 +395,35 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
-            conns = await check_reachable(self.state.topology)
+            conns = await check_reachable(
+                self.node_id, self.state.topology, self.state.node_profiles
+            )
             for nid in conns:
                 for ip in conns[nid]:
-                    edge = Connection(
-                        local_node_id=self.node_id,
-                        send_back_node_id=nid,
+                    edge = SocketConnection(
                         # nonsense multiaddr
-                        send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
+                        sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
                         if "." in ip
                         # nonsense multiaddr
                         else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
                     )
                     if edge not in edges:
                         logger.debug(f"ping discovered {edge=}")
-                        await self.event_sender.send(TopologyEdgeCreated(edge=edge))
+                        await self.event_sender.send(
+                            TopologyEdgeCreated(
+                                source=self.node_id, sink=nid, edge=edge
+                            )
+                        )
             for nid, conn in self.state.topology.out_edges(self.node_id):
-                if (
-                    nid not in conns
-                    or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
+                if not isinstance(conn, SocketConnection):
+                    continue
+                if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
+                    nid, set()
                 ):
                     logger.debug(f"ping failed to discover {conn=}")
-                    await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
+                    await self.event_sender.send(
+                        TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
+                    )
             await anyio.sleep(10)
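
The loop is reconcile-by-diff: compare what the ping pass reports as reachable with the edges already recorded in the topology, then emit create events for the new pairs and delete events for the vanished ones. A type-simplified sketch of that core, not the repo's signatures:

def reconcile(
    known: set[tuple[str, str]],      # (node_id, ip) edges already recorded
    reachable: dict[str, set[str]],   # node_id -> ips that answered the ping
) -> tuple[set[tuple[str, str]], set[tuple[str, str]]]:
    live = {(nid, ip) for nid, ips in reachable.items() for ip in ips}
    return live - known, known - live  # (edges to create, edges to delete)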


@@ -22,7 +22,7 @@ def entrypoint(
) -> None:
     if (
         isinstance(bound_instance.instance, MlxJacclInstance)
-        and len(bound_instance.instance.ibv_devices) >= 2
+        and len(bound_instance.instance.jaccl_devices) >= 2
     ):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"


@@ -1,6 +0,0 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]


@@ -1,6 +1,7 @@
import platform
import shutil
from subprocess import CalledProcessError
from typing import cast
from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError
@@ -80,7 +81,6 @@ async def get_metrics_async() -> Metrics:
"""
path = _get_binary_path()
-    result = None
try:
# TODO: Keep Macmon running in the background?
result = await run_process([path, "pipe", "-s", "1"])
@@ -90,8 +90,14 @@ async def get_metrics_async() -> Metrics:
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
except CalledProcessError as e:
-        if result:
-            raise MacMonError(
-                f"MacMon failed with return code {result.returncode}"
-            ) from e
-        raise e
+        stderr_msg = "no stderr"
+        stderr_output = cast(bytes | str | None, e.stderr)
+        if stderr_output is not None:
+            stderr_msg = (
+                stderr_output.decode()
+                if isinstance(stderr_output, bytes)
+                else str(stderr_output)
+            )
+        raise MacMonError(
+            f"MacMon failed with return code {e.returncode}: {stderr_msg}"
+        ) from e
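
The normalization above in isolation: CalledProcessError.stderr can be bytes, str, or None depending on how the child was launched, so it is decoded defensively before being folded into the error message. Sketch:

from subprocess import CalledProcessError

def stderr_text(e: CalledProcessError) -> str:
    if e.stderr is None:
        return "no stderr"
    return e.stderr.decode() if isinstance(e.stderr, bytes) else str(e.stderr)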


@@ -1,41 +0,0 @@
import socket
from anyio import create_task_group, to_thread
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
)
return reachable


@@ -1,108 +0,0 @@
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)

uv.lock (generated): diff suppressed because it is too large.