Compare commits

...

56 Commits

Author SHA1 Message Date
ciaranbor
92de82bdf6 Add traces page to UI 2026-01-30 18:59:02 +00:00
ciaranbor
e792b19f5d Integrate trace collection with exo runners 2026-01-30 18:59:02 +00:00
ciaranbor
edc3a88b12 Minor tweaks 2026-01-30 18:58:30 +00:00
ciaranbor
0ae0c788d5 Average per step 2026-01-30 18:58:30 +00:00
ciaranbor
09abf44d49 Add statistics interpretation 2026-01-30 18:58:30 +00:00
ciaranbor
ac4b78349b Require external evals 2026-01-30 18:58:30 +00:00
ciaranbor
96c440c3b1 Instrument distributed runner 2026-01-30 18:58:30 +00:00
ciaranbor
c14d63cf61 Add tracing utils 2026-01-30 18:58:30 +00:00
Evan Quiney
cd946742f7 fix skipping logic in worker plan (#1342)
the worker plan function had some skipping logic missing, leading to
double-submitting tasks.
2026-01-30 14:31:40 +00:00
rltakashige
a5bc38ad1f Check all nodes to evict (#1341)
## Motivation

If nodes have uneven memory, one node may evict cache that remains on
another node. This will break prefill on some setups.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-30 13:42:09 +00:00
Evan Quiney
2a4e0d4629 make node-ids unique per-session (#1338)
we currently have no strict reuqirements that node ids persist across
sessions, so we can generate fresh nodeids each time

this avoids issues like #1332, but prevents further features such as
caching downloads or node-id dialling

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-30 13:33:31 +00:00
Evan Quiney
46a14153dd switch to ModelCard.load outside of download log (#1339)
some attempts to load model cards (i.e. build_base_shard) always went
through networking rather than using downloaded model cards. we should
always default to ModelCard.load in these scenarios
2026-01-30 11:20:20 +00:00
Evan
9ba61f3733 improve log message in shard downloader
closes #1336
2026-01-30 10:35:01 +00:00
rltakashige
d9eca75895 Add usage stats (#1333)
## Motivation

(Probably) the final missing piece of the Chat Completions API 

## Changes

Add UsageStats 

## Why It Works

OpenCode reviewed my PR and gave me stats:

<img width="1150" height="802" alt="image"
src="https://github.com/user-attachments/assets/ebc06bae-797f-4087-87d5-2f26cf60fc48"
/>


## Test Plan

### Automated Testing
No tests were broken.
2026-01-30 10:23:08 +00:00
rltakashige
9dabde7e57 Fix bench after recent updates (#1331)
## Motivation

A lot of changes happened without much attention to the state of exo
bench.

## Changes

Use TaggedModel for BenchChatCompletion so it serialises properly.
Don't break after gpt oss tool call to preserve parity with the rest of
the codebase.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<img width="2856" height="678" alt="image"
src="https://github.com/user-attachments/assets/2e18cf0d-c0f8-467c-9763-1a6a59c8a327"
/>

Also tested GPT OSS tool calling in OpenCode
2026-01-29 19:14:40 +00:00
ciaranbor
a31942ce12 Ciaran/image non streaming (#1328)
## Motivation

The dashboard UI attempted to parse all image generation responses as
SSE streams, even when streaming was disabled. This broke non-streaming
image generation.

## Changes

- Parse JSON responses directly when not streaming, use SSE parser only
when stream=true AND partialImages > 0
- explicitly disable partial images when not streaming

## Why It Works

Both API and dashboard now use the same condition (stream &&
partialImages > 0) to determine response format, ensuring correct
parsing.

## Test Plan

### Manual Testing

Non-streamed image generation results appear in the UI. Streamed image
generation still works
2026-01-29 17:24:32 +00:00
Alex Cheema
7cc313b22a Treat Swift/Xcode build warnings as errors (#1322)
## Motivation

Warnings that go unchecked tend to accumulate and hide real issues.
Treating them as errors ensures they are addressed immediately, both
locally during development and in CI.

## Changes

Added `SWIFT_TREAT_WARNINGS_AS_ERRORS = YES` and
`GCC_TREAT_WARNINGS_AS_ERRORS = YES` to the **project-level** Debug and
Release build configurations in `project.pbxproj`. This applies to all
targets (EXO, EXOTests, EXOUITests).

## Why It Works

Xcode's `SWIFT_TREAT_WARNINGS_AS_ERRORS` and
`GCC_TREAT_WARNINGS_AS_ERRORS` build settings promote Swift and C/ObjC
warnings to errors at compile time. Setting them at the project level
means all targets inherit the policy without needing per-target or
CI-level overrides.

## Test Plan

### Manual Testing
- Built the EXO scheme in Release configuration with `xcodebuild` — no
warning-as-error failures from Swift or C/ObjC sources.

### Automated Testing
- CI already builds with `-configuration Release`, so it will
automatically enforce warnings-as-errors via the inherited project
settings — no CI changes needed.
2026-01-29 17:15:49 +00:00
rltakashige
2837225dc7 Load pipeline layers sequentially (#1329)
## Motivation

Slightly annoyed by needing this change, but same story as for tensor
loading...
2026-01-29 17:08:38 +00:00
Jake Hillion
e4c6a7dbb4 nix: add Python packaging with uv2nix
Add uv2nix to build Python packages from uv.lock. This creates a fully
Nix-managed Python environment with the Rust bindings injected via overlay.

Changes:
- Add pyproject-nix, uv2nix, and pyproject-build-systems flake inputs
- Create python/parts.nix with overlays to inject Nix-built Rust wheel
- Export packages.exo on macOS (wraps exo/exo-master/exo-worker with dashboard)
- Add checks.lint (ruff, all platforms) and checks.pytest (macOS only)
- Simplify CI typecheck job using nicknovitski/nix-develop action
- Delete .github/actions/typecheck composite action (no longer needed)
- Add no-build-package for MLX packages in pyproject.toml (use wheels)

The Python build is currently macOS-only since MLX requires Metal. Linux
support will be added once the pyproject dependencies are simplified.

Test plan:
- Run `nix flake check` on macOS to verify pytest and lint pass
- Build exo package on macOS: `nix build .#exo`
- Verify CI pipeline passes with simplified typecheck job
2026-01-29 16:35:58 +00:00
Evan
b1e88a3d06 shfmt
adds shfmt, a shell formatter, and formats the bash files
2026-01-29 15:24:36 +00:00
Jake Hillion
ebeddfb308 mlx: build with Nix (#1285)
In order to make testing and deployment simpler and more reproducible,
we want to provide a Nix derivation for our macOS .app build. We already
build the Rust and dashboard with Nix, but so far the Python has been
blocked because we haven't had an MLX build.

This change adds a Metal compiler derivation that uses `requireFile` to
be provided a NAR of the unfree macOS Metal compiler. It is documented
how to get this file, but effectively you have to trigger the download,
mount the DMG, and NAR the result. Once this is added to the store by
hash we can build MLX using it. The MLX build itself is quite self
explanatory.

Test plan:
- CI. We follow the instructions to grab the Metal compiler. Once this
is in Cachix we should really never do this again, and I can pin the
path too to ensure it doesn't leave.
- MLX tests run as part of the MLX derivation's build. They pass.
- `NIXPKGS_ALLOW_UNFREE=1 nix build .#mlx.passthru.tests.mlxTest
--impure --option sandbox false`

---------

Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-29 14:07:00 +00:00
Alex Cheema
9111575997 Add startup delay and update network setup message (#1309)
## Summary
- Add 20-second startup delay to wait for macOS to finish network setup
after boot
- Update user-facing message to clarify the service configures local
networking, disables Thunderbolt Bridge (preventing packet storms), and
installs a Network Location

## Test plan
- [ ] Manual verification of Swift syntax
- [ ] Test network setup on macOS device after reboot

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-29 13:05:50 +00:00
Sami Khan
ffacabe7e4 Fix uninstall button error (#1306)
## Motivation

Fix "Network setup script failed" error when clicking uninstall button
and resolve Xcode compiler warnings.

## Changes

- NetworkSetupHelper.swift: Add || true guards and explicit return 0 in
find_and_enable_thunderbolt_bridge to prevent script failures with set
-euo pipefail
- ThunderboltBridgeService.swift: Use withCString and
withUnsafeMutablePointer for Authorization API calls to fix pointer
lifetime warnings
- EXOApp.swift: Mark showNotification as nonisolated to fix main actor
isolation warning

## Why It Works

- The uninstall script's Thunderbolt re-enable function could exit
non-zero in edge cases (no bridges, no matches). Since this is a cleanup
step, failures should not abort uninstall.
- Swift requires explicit pointer lifetime management when passing
strings/structs to C APIs.
- showNotification is called from a nonisolated delegate method and uses
thread-safe APIs.

## Test Plan

### Manual Testing
Hardware: MacBook Pro

- Clicked Uninstall button, verified it completes without error
- Built in Xcode, verified no warnings   

### Automated Testing
N/A
2026-01-29 12:57:48 +00:00
rltakashige
9e58a57599 Add RDMA caveats to README.md (#1316)
## Motivation

Running RDMA from source is not well documented as is. Several
surprising things that took time to debug internally too.

App should be updated to detect MacOS versions in future.
2026-01-28 18:44:00 +00:00
Evan Quiney
748a026071 fix configdata validation for kimi-k2 (#1314)
## motivation
our shard downloader could not correctly fetch data for kimi-k2, as it
deferred some values to a text_config field.
## changes
config_data now prioritizes this field if it exists in information like
layer_count
2026-01-28 14:29:36 +00:00
Alex Cheema
f1a2d054ec Update tagline to "Run frontier AI locally" (#1313)
- Update README tagline from "Run your own AI cluster at home with
everyday devices" to "Run frontier AI locally"
2026-01-28 12:38:14 +00:00
Alex Cheema
b3c8f85fc8 Update MLX to 0.30.4 (#1311)
## Summary
- Bump mlx from 0.30.3 to 0.30.4

## Test plan
- [x] `uv lock` succeeds
- [x] Type checking passes (`uv run basedpyright`)
- [x] Run inference tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 04:30:21 -08:00
rltakashige
a562114ba5 Add Kimi K2.5 support (#1302)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-28 05:44:19 +00:00
Evan Quiney
991d278119 replace nix fmt with treefmt in just lint (#1301)
man evaluating the nix flake is so slow. treefmt speeeedy
2026-01-27 17:03:01 +00:00
rltakashige
c55cbf6739 Add mlx lm style tensor sharding for Minimax (#1299)
## Motivation

Broken right now. We'll potentially add a better one later

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Used for evals without any issue.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-27 15:29:06 +00:00
Alex Cheema
bd4f0bf048 Fix download speed/ETA display for re-downloads (#1294)
## Motivation

After the download verification fix, when files are re-downloaded due to
upstream changes (size mismatch), the download progress displays
correctly (completion %, bytes, file counts), but speed shows 0 B/s and
ETA shows "--" for both overall and per-file progress.

## Changes

- Modified `on_progress_wrapper` in `src/exo/download/download_utils.py`
to detect re-download scenarios
- Added re-download detection: when `curr_bytes < previous_downloaded`,
the file was deleted and download restarted
- On re-download: reset `start_time` to current time and set
`downloaded_this_session = curr_bytes`
- Added two tests to `test_download_verification.py` covering
re-download and continuing download scenarios

## Why It Works

The bug occurred because:
1. `file_progress` is initialized with the OLD local file size (e.g.,
1.5GB)
2. When `_download_file` detects size mismatch, it deletes the file and
starts fresh
3. Progress callback receives small `curr_bytes` (e.g., 8KB) but
compares against old size
4. `downloaded_this_session = 0 + (8KB - 1.5GB) = -1.5GB` (negative!)
5. Negative session bytes → 0 or negative speed → ETA shows "--"

The fix detects when `curr_bytes < previous_downloaded` (indicating
re-download started) and resets tracking to treat it as a fresh
download.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Download a model, modify a file to change its size, restart exo,
verify speed/ETA display correctly during re-download

### Automated Testing
- Added `TestProgressResetOnRedownload` class with two tests:
- `test_progress_resets_correctly_on_redownload`: Verifies progress
resets correctly when re-download starts
- `test_progress_accumulates_on_continuing_download`: Verifies
continuing downloads still accumulate correctly
- All 11 download tests pass
- Type checking (basedpyright): 0 errors
- Linting (ruff): All checks passed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 21:56:58 +00:00
rltakashige
cd8c01b7c8 Fix kv prefix cache (#1262)
## Motivation

OpenCode sends very large prompts, most of which are repeated on the
next call.

## Changes

Add prefix caching, reducing average time in prefill (in testing) from
40 seconds to 4. This massively improves user experience.

Also evicts KV caches from this prefix cache in a LRU-style manner. 

## Why It Works

We no longer prefill repeatedly but rather use kv cache stored in
memory. A future update may want to use storage to make the prefix cache
larger.

## Test Plan

### Manual Testing
Tested speedup on OpenCode

### Automated Testing
Added a lot of tests

---------

Co-authored-by: David Hind <davehind@yahoo.co.uk>
2026-01-26 20:13:58 +00:00
rltakashige
59e991ce15 Only ignore message if actually empty (#1292)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-26 19:33:23 +00:00
ciaranbor
ffba340e70 Ciaran/image quantization (#1272)
## Motivation

Enable users to select and use quantized variants (8-bit, 4-bit) of
image models

## Changes

Use exolabs HF org for image models

## Why It Works

Quantized versions have been uploaded to exolabs HF org

## Test Plan

Loaded and ran different quantized variants. Confirmed lower memory
usage and different outputs for the same seed. Verified chat completion
still works.
2026-01-26 19:25:05 +00:00
rltakashige
9968abe816 Leo/fix basic model shard (#1291)
## Motivation

Some models, on some configurations, would have several issues that
caused the model to be stuck on loading.

## Changes

Several loading issues were with upstream mlx lm shard loading for
tensor parallel.
GLM 4.7 Flash now uses GLM 4.7 Lite.
A final portion of the issues were from mlx memory not being properly
released before calling mx.eval(model), causing the system to run out of
memory.

## Test Plan

### Manual Testing
Done a bunch (thanks @AlexCheema), hopefully exhaustive. 

### Automated Testing
A bunch of automated testing is imminent but not landed yet.

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
2026-01-26 17:49:09 +00:00
Alex Cheema
0e30b0830f Fix download system for upstream file changes (#1290)
## Motivation

When upstream files change on Hugging Face, exo's download system
doesn't detect the change and downloads get stuck. The only workaround
is deleting `~/.exo/models/` and the cache.

Root causes:
1. Existing files are never re-verified against remote metadata
2. File list cache is never invalidated, causing stale sizes to be used

## Changes

1. **Verify existing files against remote size** (`_download_file`):
Before returning early for existing files, verify the local file size
matches remote. If mismatched, delete and re-download. If network fails
(offline), fall back to trusting local file.

2. **Always try fresh file list first** (`fetch_file_list_with_cache`):
Always attempt to fetch fresh data from Hugging Face. On success, update
the cache. On failure, fall back to cached data if available.

3. **Clear cache on model delete** (`delete_model`): When a model is
deleted, also delete its cache entry to prevent stale metadata.

## Why It Works

- **Online**: Stale local files are detected via size mismatch and
re-downloaded. Fresh file list is always fetched and cache is updated.
- **Offline with cache**: Existing files are trusted. Cached file list
is used as fallback.
- **Offline without cache**: Fails gracefully (can't download without
knowing what files to get).

The size check is O(1) so there's no performance impact. Hash
verification still happens after download completes (existing behavior).

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Download a model, manually modify a local file's content, restart exo,
verify it re-downloads

### Automated Testing
Added 9 new tests in
`src/exo/download/tests/test_download_verification.py`:
- Re-download when file size changes upstream
- Skip download when file size matches
- Offline fallback uses local file
- Fetch fresh file list and update cache
- Fall back to cache when fetch fails
- Error propagates when no cache exists
- Model delete clears cache
- Delete when only cache exists
- Delete nonexistent model

All tests pass: `uv run pytest src/exo/download/tests/ -v`

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 09:14:58 -08:00
Alex Cheema
44453c4c8b Remove change-detection checks from info gatherer monitors (#1283)
## Summary
- When a node times out, its info gets cleared from state. The monitor
functions only sent data when something changed, leaving no mechanism to
re-populate this info after a timeout.
- Removes change-detection checks from `_monitor_misc`,
`_monitor_system_profiler_thunderbolt_data`, `_watch_system_info`, and
`_monitor_thunderbolt_bridge_status` so data is sent periodically
regardless of whether it changed.

## Test plan
- [ ] Verify type checker passes: `uv run basedpyright`
- [ ] Verify linter passes: `uv run ruff check`
- [ ] Verify tests pass: `uv run pytest`
- [ ] Manually test that node info is re-populated after a timeout by
observing cluster behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-26 12:23:22 +00:00
Jake Hillion
1290e8ed9f dashboard: fix prettier-svelte rebuilding on every file change
The prettier-svelte package was rebuilding whenever any file in the
repository changed because dashboardStubSrc referenced inputs.self
directly. Since inputs.self's store path hash is computed from the
entire repository contents, any file modification invalidated the
derivation.

Added dashboardLockfileSrc using lib.cleanSourceWith to filter
inputs.self to only include package.json and package-lock.json from
the dashboard directory. Updated dashboardStubSrc to reference this
filtered source instead of inputs.self directly.

This ensures prettier-svelte only rebuilds when the lockfiles actually
change, significantly improving build caching for unrelated changes.

Test plan:
- Built prettier-svelte with nix build .#prettier-svelte
- Modified src/exo/main.py and rebuilt - same store path (no rebuild)
- Modified dashboard/package.json and rebuilt - different store path (rebuild triggered)
- Ran nix flake check successfully
2026-01-26 12:02:05 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
Alex Cheema
ff4a2022f7 Revert state compaction (#1259) (#1275)
## Summary

Reverts the state compaction feature (#1259) to investigate issues with
nodes staying as "unknown" after joining a cluster.

## Test plan

- [ ] Verify nodes properly show up after joining cluster
- [ ] Verify state catchup works correctly without compaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)
2026-01-23 16:29:48 -08:00
rltakashige
cee48f6f34 Parse GPT OSS tool calling (#1271)
## Motivation

<img width="3162" height="858" alt="image"
src="https://github.com/user-attachments/assets/e552f373-620a-4522-894b-6f93fd7f1e50"
/>

## Changes

OpenAI Harmony StreamableParser does parsing for us.

## Why It Works

<img width="3230" height="588" alt="image"
src="https://github.com/user-attachments/assets/81f8a43e-c04b-4bd0-9fd0-65e9b5f6ea1d"
/>
2026-01-23 20:43:53 +00:00
Evan Quiney
2b67e84a03 state compaction (#1259)
## motivation

a node joining a long-running cluster would bring down networking. this
attempts to mitigate that issue by compacting the state for catching up
new devices

## changes

introduces a new topic ("state_catchup") over which a full state can be
sent. currently the master sends the worker + api this new state, and
they update only if they have no other events applied - otherwise usual
NACK systems function

## testing

manually tested on two and eight nodes - its an improvement, not a fix

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-23 20:32:49 +00:00
Alex Cheema
7204fdeb4a Restore Thunderbolt Bridge LaunchDaemon (#1270)
## Motivation

The LaunchDaemon approach for disabling Thunderbolt Bridge was removed
in commit 43f12f5d and replaced with dynamic cycle detection. However,
the LaunchDaemon runs automatically on reboot, ensuring the bridge is
always disabled before it can cause packet storms.

## Changes

- Restore `NetworkSetupHelper.promptAndInstallIfNeeded()` to install a
LaunchDaemon that disables Thunderbolt Bridge on startup
- Show user prompt explaining what will be installed before requesting
admin password
- Remove old cleanup-only logic from `EXOApp.swift`
- Installer removes any existing installation before installing fresh
(handles upgrades)

## Why It Works

The LaunchDaemon runs at boot with `RunAtLoad=true` and periodically
(every ~30 min), destroying bridge0 and disabling Thunderbolt Bridge
before it can cause packet storms. The daemon is only installed
once—`daemonAlreadyInstalled()` checks script content and plist config
match before prompting.

## Test Plan

### Manual Testing
- Run app first time → should see prompt → click Install → enter admin
password → daemon installed
- Run app again → no prompt (already installed)
- Reboot → bridge0 should be destroyed/disabled automatically
- Check daemon: `launchctl list | grep io.exo.networksetup`
- Check files: `/Library/LaunchDaemons/io.exo.networksetup.plist`,
`/Library/Application Support/EXO/disable_bridge.sh`

### Automated Testing
N/A - requires admin privileges and system-level changes

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 20:25:37 +00:00
Evan Quiney
ec345a4315 fix: deprioritise uncertain ethernet devices (#1267)
we were placing coordinators on uncertain devices (enX+) that are listed
as "USB LAN" - these could be thunderbolt ports breaking RDMA instances
2026-01-23 20:13:28 +00:00
ciaranbor
9967dfa734 Prevent conversation collision (#1266)
## Motivation

When a user switched conversations while a response was still streaming,
the streaming content would be written to the currently selected
conversation instead of the original one. For streamed image generation,
each partial image would be written to the open conversation

## Changes

Added helper methods to track and update the correct conversation during
streaming:
- updateConversationMessage() - Update a message in a specific
conversation by ID
- syncActiveMessagesIfNeeded() - Sync this.messages from target
conversation only if it's active
- conversationExists() - Check if a conversation still exists (handles
mid-stream deletion)
  - persistConversation() - Persist a specific conversation to storage
- addMessageToConversation() - Add a message directly to a specific
conversation


## Why It Works

Capturing the conversation ID at the start of the request ensures we
know which conversation to update

## Test Plan

### Manual Testing

Tested switching conversation during generation across each model type
2026-01-23 19:59:08 +00:00
ciaranbor
23fd37fe4d Add FLUX.1-Krea-dev model (#1269)
## Why It Works

Same implementation as FLUX.1-dev, just different weights
2026-01-23 19:48:24 +00:00
Alex Cheema
d229df38f9 Fix placement filter to use subset matching instead of exact match (#1265)
## Motivation

When using the dashboard's instance placement filter (clicking nodes in
the topology), it was filtering to placements that use exactly the
selected nodes. This isn't the expected behavior - users want to see
placements that include all selected nodes, but may also include
additional nodes.

For example, selecting nodes [A, B] should show placements using [A, B],
[A, B, C], [A, B, C, D], etc. - not just [A, B].

## Changes

- Added `required_nodes` parameter to `place_instance()` in
`placement.py`
- Filter cycles early in placement to only those containing all required
nodes (subset matching)
- Simplified `api.py` by removing the subgraph topology filtering and
passing `required_nodes` directly to placement
- Renamed internal `node_ids` variable to `placement_node_ids` to avoid
shadowing the parameter

## Why It Works

By filtering cycles at the placement level using
`required_nodes.issubset(cycle.node_ids)`, we ensure that only cycles
containing all the user-selected nodes are considered. This happens
early in the placement algorithm, so we don't waste time computing
placements that would be filtered out later.

## Test Plan

### Manual Testing
- Select nodes in the dashboard topology view
- Verify that placements shown include all selected nodes (but may
include additional nodes)
- Verify that placements not containing the selected nodes are filtered
out

### Automated Testing
- Existing placement tests pass
- `uv run pytest src/exo/master/tests/ -v` - 37 tests pass

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 19:40:31 +00:00
Alex Cheema
8a595fee2f Fix Thunderbolt bridge cycle detection to include 2-node cycles (#1261)
## Motivation

Packet storms occur with Thunderbolt bridge enabled on 2 machines
connected by Thunderbolt, not just 3+ node cycles as previously assumed.
The cycle detection was too conservative and missed this case.

## Changes

- Changed the minimum cycle length from >2 (3+ nodes) to >=2 (2+ nodes)
- Updated the early return threshold from `< 3` to `< 2` enabled nodes
- Updated docstring to reflect the new behavior

## Why It Works

A Thunderbolt bridge loop between just 2 machines can still create
broadcast storms when both have the bridge enabled. The previous
threshold of 3+ was based on an incorrect assumption that 2-node
connections wouldn't cause this problem.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
- Tested with 2 machines connected via Thunderbolt with bridge enabled
- Confirmed packet storms occur in this configuration
- Verified the fix correctly detects and handles 2-node cycles

### Automated Testing
- Existing topology tests cover cycle detection logic

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 19:34:48 +00:00
ciaranbor
c8571a17a3 Fix guidance (#1264)
## Motivation

Previously, we only handled user-provided guidance parameter for CFG
models.

## Changes

Just pass the parameter to model setup
2026-01-23 19:13:45 +00:00
Evan Quiney
771a86331b fix instance port assignment (#1268)
we were overassigning the port 52414 to instances because of an error in placement
2026-01-23 18:37:40 +00:00
Jake Hillion
6dbbe7797b downloads: add download and delete buttons to downloads UI
The downloads page showed model download progress but provided no way
for users to trigger downloads or remove completed models from disk.

Added API endpoints (POST /download/start, DELETE /download/{node_id}/{model_id})
that send StartDownload and DeleteDownload commands via the download_command_sender.
Updated the dashboard downloads page with per-model buttons: a download button
for incomplete downloads and a delete button for completed ones.

This allows users to manage downloads directly from the UI without needing
to trigger downloads through other means.

Test plan:
- Deployed on a 3 machine cluster. Did several downloads/deletions - all
  work and the dashboard updates relatively fluently. It takes roughly 5
  seconds to render a 131GB model deletion which isn't too bad.
2026-01-23 18:11:17 +00:00
Jake Hillion
9357503c6f downloads: refactor to run at node level
The Worker previously owned the ShardDownloader directly via dependency
injection, which prevented --no-worker nodes from downloading and made
it impossible for multiple Workers to share a single downloader instance.

Moved download functionality to a new DownloadCoordinator component at
the Node level that communicates via the DOWNLOAD_COMMANDS pub/sub topic.
Workers now send StartDownload commands instead of calling the downloader
directly, and receive progress updates through the event-sourced state.

This decouples downloads from the Worker lifecycle and enables future
features like UI-triggered downloads to specific nodes and multi-worker
download sharing.

Test plan:
- Mostly tested in the next PR that adds explicit downloads/deletions to
  the dashboard.
- Started a model that isn't downloaded - it works.
2026-01-23 18:04:09 +00:00
ciaranbor
ba19940828 Fix regenerate for image models (#1263)
## Motivation

The 'regenerate' button was hardcoded to chat completion. Clicking
'regenerate' for image request would result in an error after the model
is loaded

## Changes

Store request type and dispatch to appropriate request upon regeneration

## Why It Works

We make sure to repeat the same request type as was performed originally

## Test Plan

### Manual Testing

Checked 'regenerate' works for chat completion, image generation, image
editing
2026-01-23 16:33:01 +00:00
Jake Hillion
f255345a1a dashboard: decouple prettier-svelte from dashboard source
The prettier-svelte formatter depended on the full dashboard build
(dashboardFull), causing the devshell to rebuild whenever any dashboard
source file changed.

Created a deps-only dream2nix derivation (deps.nix) that uses a stub
source containing only package.json, package-lock.json, and minimal
files for vite to succeed. Updated prettier-svelte to use this
derivation instead of dashboardFull.

The stub source is constant unless lockfiles change, so prettier-svelte
and the devshell no longer rebuild when dashboard source files are
modified.

Test plan:
- nix flake check passed
- nix fmt successfully formatted svelte files
2026-01-23 15:16:48 +00:00
ciaranbor
a1939c89f2 Enable UI settings for image editing (#1258)
## Motivation

Image editing was missing UI controls for quality, output format, and
advanced parameters that text-to-image generation already supported.

## Changes

- Added quality, output_format, and advanced_params to image edit API
endpoints
- Extended isImageModel check to include image editing models

## Why It Works

The API now accepts and forwards these settings for image edits, and the
UI displays the appropriate controls for image editing models.

## Test Plan

### Manual Testing

Verified parameters can be set in UI and that they progagate through to
model inference
2026-01-23 13:37:25 +00:00
ciaranbor
cb9c9ee55c Enable generating multiple images. Optionally stream partial images (#1251)
## Motivation

Support OpenAI API `n` setting

## Changes

- Users can select `n` to generate more than one image with the same
prompt
- each image uses a different seed -> different results
- `stream` and `partial_images` settings can be overwritten in UI
2026-01-23 11:19:58 +00:00
80 changed files with 7333 additions and 2585 deletions

View File

@@ -1,12 +0,0 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,73 +26,14 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
nix:
name: Build and check (${{ matrix.system }})
@@ -123,6 +64,63 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -134,3 +132,14 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
@@ -107,6 +107,10 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
### Run from Source (Linux)
**Prerequisites:**
@@ -230,7 +234,7 @@ This removes:
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
Please refer to the caveats for immediate troubleshooting.
To enable RDMA on macOS, follow these steps:
@@ -247,6 +251,14 @@ To enable RDMA on macOS, follow these steps:
After that, RDMA will be enabled in macOS and exo will take care of the rest.
**Important Caveats**
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
2. The cables must support TB5.
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
---
### Using the API

View File

@@ -342,6 +342,8 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -397,6 +399,8 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -45,8 +45,8 @@ struct EXOApp: App {
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Install LaunchDaemon to disable Thunderbolt Bridge on startup (prevents network loops)
NetworkSetupHelper.promptAndInstallIfNeeded()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
controller.scheduleLaunch(after: 15)
@@ -136,36 +136,6 @@ struct EXOApp: App {
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
// Dispatch async to ensure app is ready before showing alert
DispatchQueue.main.async {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to configure local network discovery on your device. This requires granting permission once."
alert.alertStyle = .informational
alert.addButton(withTitle: "Continue")
alert.addButton(withTitle: "Later")
let response = alert.runModal()
guard response == .alertFirstButtonReturn else {
Logger().info("User deferred legacy network setup cleanup")
return
}
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
}
/// Helper for managing EXO's launch-at-login registration
@@ -255,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
private func showNotification(title: String, body: String) {
nonisolated private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title

View File

@@ -11,6 +11,100 @@ enum NetworkSetupHelper {
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1786
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
/// Prompts user and installs the LaunchDaemon if not already installed.
/// Shows an alert explaining what will be installed before requesting admin privileges.
static func promptAndInstallIfNeeded() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
// If already correctly installed, skip
if daemonAlreadyInstalled() {
return
}
// Show alert on main thread
let shouldInstall = await MainActor.run {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
return alert.runModal() == .alertFirstButtonReturn
}
guard shouldInstall else {
logger.info("User deferred network setup daemon installation")
return
}
do {
try installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -30,6 +124,100 @@ enum NetworkSetupHelper {
return scriptExists || legacyScriptExists || plistExists
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
# First, completely remove any existing installation
launchctl bootout system/"$LABEL" 2>/dev/null || true
rm -f "$PLIST_DEST"
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rm -f "$LOG_OUT" "$LOG_ERR"
# Install fresh
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func makeUninstallScript() -> String {
"""
set -euo pipefail
@@ -56,11 +244,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic 2>/dev/null || true
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
# Delete the exo network location if it exists
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -70,12 +258,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
')
') || true
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -84,7 +272,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
')
') || true
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -92,8 +280,9 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge
find_and_enable_thunderbolt_bridge || true
echo "EXO network components removed successfully"
"""

View File

@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled

View File

@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,4 +151,3 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -865,7 +865,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -905,7 +904,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1522,7 +1520,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1532,7 +1529,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1945,7 +1941,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2653,7 +2648,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2696,7 +2690,6 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2869,7 +2862,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3014,7 +3006,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3036,7 +3027,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -3,6 +3,61 @@
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to ONLY include package.json and package-lock.json
# This ensures prettier-svelte only rebuilds when lockfiles change
dashboardLockfileSrc = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
isDashboardDir = baseName == "dashboard" && type == "directory";
isPackageFile =
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
&& (baseName == "package.json" || baseName == "package-lock.json");
in
isDashboardDir || isPackageFile;
};
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -42,11 +97,12 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -89,7 +89,10 @@
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsTextToImage(currentModel);
return (
modelSupportsTextToImage(currentModel) ||
modelSupportsImageEditing(currentModel)
);
});
const isEditOnlyWithoutImage = $derived(
@@ -646,6 +649,23 @@
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg

View File

@@ -110,6 +110,36 @@
setImageGenerationParams({ negativePrompt: value || null });
}
function handleNumImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ numImages: 1 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 1) {
setImageGenerationParams({ numImages: num });
}
}
}
function handleStreamChange(enabled: boolean) {
setImageGenerationParams({ stream: enabled });
}
function handlePartialImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ partialImages: 0 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 0) {
setImageGenerationParams({ partialImages: num });
}
}
}
function clearSteps() {
setImageGenerationParams({ numInferenceSteps: null });
}
@@ -134,90 +164,92 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{/if}
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -325,6 +357,59 @@
</div>
</div>
<!-- Number of Images (not in edit mode) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>IMAGES:</span
>
<input
type="number"
min="1"
value={params.numImages}
oninput={handleNumImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Stream toggle -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>STREAM:</span
>
<button
type="button"
onclick={() => handleStreamChange(!params.stream)}
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
? 'bg-exo-yellow'
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
>
<div
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
? 'right-0.5 bg-exo-black'
: 'left-0.5 bg-exo-light-gray'}"
></div>
</button>
</div>
<!-- Partial Images (only when streaming) -->
{#if params.stream}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>PARTIALS:</span
>
<input
type="number"
min="0"
value={params.partialImages}
oninput={handlePartialImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">

View File

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,8 @@
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -28,6 +30,7 @@
etaMs: number;
status: "completed" | "downloading";
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -269,6 +272,12 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -285,6 +294,7 @@
? "completed"
: "downloading",
files,
shardMetadata,
};
const existing = modelMap.get(modelId);
@@ -469,6 +479,52 @@
>
{pct.toFixed(1)}%
</span>
{#if model.status !== "completed" && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
{#if model.status === "completed"}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() =>
deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"

View File

@@ -0,0 +1,172 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -0,0 +1,347 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

65
flake.lock generated
View File

@@ -21,7 +21,9 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1765953015,
@@ -149,19 +151,44 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"type": "github"
},
"original": {
@@ -178,7 +205,10 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
}
},
"rust-analyzer-src": {
@@ -239,6 +269,29 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,6 +24,26 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -48,6 +68,7 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -58,6 +79,11 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -79,14 +105,24 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,7 +1,7 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt
treefmt || nix fmt
lint:
uv run ruff check --fix

View File

@@ -0,0 +1,79 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

56
nix/metal-toolchain.nix Normal file
View File

@@ -0,0 +1,56 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

158
nix/mlx.nix Normal file
View File

@@ -0,0 +1,158 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -17,16 +17,16 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"mflux==0.15.4",
"python-multipart>=0.0.21",
]
@@ -63,6 +63,7 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

93
python/parts.nix Normal file
View File

@@ -0,0 +1,93 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -0,0 +1,284 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(model_id)
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Emit pending status to reset UI state, then remove from local tracking
if model_id in self.download_status:
current_status = self.download_status[model_id]
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)

View File

@@ -24,7 +24,15 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
@@ -35,13 +43,6 @@ from exo.shared.types.worker.downloads import (
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
class HuggingFaceAuthenticationError(Exception):
@@ -120,11 +121,20 @@ async def ensure_models_dir() -> Path:
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
models_dir = await ensure_models_dir()
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
deleted = False
if await aios.path.exists(model_dir):
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
deleted = True
# Also clear cache
if await aios.path.exists(cache_dir):
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
return deleted
async def seed_models(seed_dir: str | Path):
@@ -150,16 +160,28 @@ async def fetch_file_list_with_cache(
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
# Always try fresh first
try:
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
# Update cache with fresh data
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
return file_list
except Exception as e:
# Fetch failed - try cache fallback
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
# No cache available, propagate the error
raise
async def fetch_file_list_with_retry(
@@ -331,8 +353,28 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
if await aios.path.exists(target_dir / path):
return target_dir / path
target_path = target_dir / path
if await aios.path.exists(target_path):
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
try:
remote_size, _ = await file_meta(model_id, revision, path)
if local_size != remote_size:
logger.info(
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
)
await aios.remove(target_path)
else:
return target_path
except Exception as e:
# Offline or network error - trust local file
logger.debug(
f"Could not verify {path} against remote (offline?): {e}, using local file"
)
return target_path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -481,6 +523,11 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
return ["*"]
def is_image_model(shard: ShardMetadata) -> bool:
tasks = shard.model_card.tasks
return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path):
@@ -522,22 +569,40 @@ async def download_shard(
file_list, allow_patterns=allow_patterns, key=lambda x: x.path
)
)
# For image models, skip root-level safetensors files since weights
# are stored in component subdirectories (e.g., transformer/, vae/)
if is_image_model(shard):
filtered_file_list = [
f
for f in filtered_file_list
if "/" in f.path or not f.path.endswith(".safetensors")
]
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
else time.time()
)
downloaded_this_session = (
file_progress[file.path].downloaded_this_session.in_bytes
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
if file.path in file_progress
else curr_bytes
previous_progress = file_progress.get(file.path)
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
speed = (
downloaded_this_session / (time.time() - start_time)
if time.time() - start_time > 0

View File

@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.error("Error downloading shard:", e)
logger.warning(f"Error downloading shard: {type(e).__name__}")
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

View File

@@ -0,0 +1,451 @@
"""Tests for download verification and cache behavior."""
import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter
from exo.download.download_utils import (
delete_model,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
"""Set up a temporary models directory for testing."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestFileVerification:
"""Tests for file size verification in _download_file."""
async def test_redownload_when_file_size_changes_upstream(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with mismatched sizes are re-downloaded."""
# Import inside test to allow patching
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file with wrong size
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content") # 13 bytes
remote_size = 1000 # Different from local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
# Set up mock HTTP response for re-download
mock_response = MagicMock()
mock_response.status = 200
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
side_effect=[b"x" * remote_size, b""]
)
mock_session = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_response
)
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_session
)
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
# Mock calc_hash to return the expected hash
with patch(
"exo.download.download_utils.calc_hash",
new_callable=AsyncMock,
return_value=remote_hash,
):
await _download_file(model_id, "main", "test.safetensors", target_dir)
# file_meta should be called twice: once for verification, once for download
assert mock_file_meta.call_count == 2
async def test_skip_download_when_file_size_matches(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with matching sizes are not re-downloaded."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
local_content = b"local content"
async with aiofiles.open(local_file, "wb") as f:
await f.write(local_content)
remote_size = len(local_content) # Same as local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return immediately without downloading
assert result == local_file
mock_file_meta.assert_called_once()
mock_session_factory.assert_not_called()
async def test_offline_fallback_uses_local_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that local files are used when network is unavailable."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content")
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return local file without attempting download
assert result == local_file
mock_session_factory.assert_not_called()
class TestFileListCache:
"""Tests for file list caching behavior."""
async def test_fetch_fresh_and_update_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that fresh data is fetched and cache is updated."""
models_dir = tmp_path / "models"
file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=100),
]
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
return_value=file_list,
) as mock_fetch,
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == file_list
mock_fetch.assert_called_once()
# Verify cache was written
cache_file = (
models_dir
/ "caches"
/ model_id.normalize()
/ f"{model_id.normalize()}--main--file_list.json"
)
assert await aios.path.exists(cache_file)
async with aiofiles.open(cache_file, "r") as f:
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
await f.read()
)
assert cached_data == file_list
async def test_fallback_to_cache_when_fetch_fails(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that cached data is used when fetch fails."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
# Create cache file
cached_file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
)
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == cached_file_list
async def test_error_propagates_when_no_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that errors propagate when fetch fails and no cache exists."""
models_dir = tmp_path / "models"
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
pytest.raises(Exception, match="Network error"),
):
await fetch_file_list_with_cache(model_id, "main")
class TestModelDeletion:
"""Tests for model deletion including cache cleanup."""
async def test_delete_model_clears_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that deleting a model also deletes its cache."""
models_dir = tmp_path / "models"
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
# Create model and cache directories
await aios.makedirs(model_dir, exist_ok=True)
await aios.makedirs(cache_dir, exist_ok=True)
# Add some files
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
await f.write("model data")
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is True
assert not await aios.path.exists(model_dir)
assert not await aios.path.exists(cache_dir)
async def test_delete_model_only_cache_exists(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting when only cache exists (model already deleted)."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
# Only create cache directory
await aios.makedirs(cache_dir, exist_ok=True)
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
# Returns False because model dir didn't exist
assert result is False
# But cache should still be cleaned up
assert not await aios.path.exists(cache_dir)
async def test_delete_nonexistent_model(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting a model that doesn't exist."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is False
class TestProgressResetOnRedownload:
"""Tests for progress tracking when files are re-downloaded."""
async def test_progress_resets_correctly_on_redownload(
self, model_id: ModelId
) -> None:
"""Test that progress tracking resets when a file is re-downloaded.
When a file is deleted and re-downloaded (due to size mismatch),
the progress tracking should reset rather than calculating negative
downloaded_this_session values.
"""
# Simulate file_progress dict as it exists in download_shard
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with old file progress (simulating existing large file)
old_file_size = 1_500_000_000 # 1.5 GB
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(old_file_size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(old_file_size),
speed=0,
eta=timedelta(0),
status="not_started",
start_time=time.time() - 10, # Started 10 seconds ago
)
# Simulate the logic from on_progress_wrapper after re-download starts
# This is the exact logic from the fixed on_progress_wrapper
curr_bytes = 100_000 # 100 KB - new download just started
previous_progress = file_progress.get("model.safetensors")
# Detect re-download: curr_bytes < previous downloaded
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is True, "Should detect re-download scenario"
assert downloaded_this_session == curr_bytes, (
"downloaded_this_session should equal curr_bytes on re-download"
)
assert downloaded_this_session > 0, (
"downloaded_this_session should be positive, not negative"
)
# Calculate speed (should be positive)
elapsed = time.time() - start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
assert speed >= 0, "Speed should be non-negative"
async def test_progress_accumulates_on_continuing_download(
self, model_id: ModelId
) -> None:
"""Test that progress accumulates correctly for continuing downloads.
When a download continues from where it left off (resume),
the progress should accumulate correctly.
"""
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with partial download progress
initial_downloaded = 500_000 # 500 KB already downloaded
start_time = time.time() - 5 # Started 5 seconds ago
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(initial_downloaded),
downloaded_this_session=Memory.from_bytes(initial_downloaded),
total=Memory.from_bytes(1_000_000),
speed=100_000,
eta=timedelta(seconds=5),
status="in_progress",
start_time=start_time,
)
# Progress callback with more bytes downloaded
curr_bytes = 600_000 # 600 KB - continuing download
previous_progress = file_progress.get("model.safetensors")
# This is NOT a re-download (curr_bytes > previous downloaded)
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
downloaded_this_session = curr_bytes
used_start_time = time.time()
else:
used_start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is False, (
"Should NOT detect re-download for continuing download"
)
assert used_start_time == start_time, "Should preserve original start_time"
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
assert downloaded_this_session == expected_session, (
f"Should accumulate: {downloaded_this_session} == {expected_session}"
)
assert downloaded_this_session == 600_000, (
"downloaded_this_session should equal total downloaded so far"
)

View File

@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self
import anyio
from anyio.abc import TaskGroup
@@ -12,6 +13,8 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -21,7 +24,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -29,6 +31,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -36,6 +39,7 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -49,8 +53,26 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -58,6 +80,7 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
@@ -67,11 +90,12 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -99,13 +123,25 @@ class Node:
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
event_index_counter,
)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -170,13 +206,27 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -185,6 +235,10 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -226,6 +280,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -268,6 +323,11 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -1,8 +1,11 @@
import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, cast
from uuid import uuid4
@@ -32,7 +35,9 @@ from exo.shared.models.model_cards import (
ModelCard,
ModelId,
)
from exo.shared.tracing import compute_stats, load_trace_file
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
@@ -42,6 +47,7 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -59,8 +65,19 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
StreamOptions,
ToolCall,
TraceCategoryStats,
TraceEventResponse,
TraceListItem,
TraceListResponse,
TraceRankStats,
TraceResponse,
TraceStatsResponse,
Usage,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -73,12 +90,16 @@ from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
StartDownload,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -104,7 +125,9 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk,
command_id: CommandId,
usage: Usage | None,
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -129,21 +152,10 @@ def chunk_to_response(
finish_reason=chunk.finish_reason,
)
],
usage=usage,
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -154,12 +166,14 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -258,10 +272,16 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -278,7 +298,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
model_card = await ModelCard.load(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -306,7 +326,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_card = await ModelCard.load(model_id)
try:
placements = get_instance_placements(
@@ -343,14 +363,9 @@ class API:
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -363,7 +378,9 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(1, len(list(topology.list_nodes())) + 1)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
]
)
# TODO: PDD
@@ -381,8 +398,9 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=topology,
topology=self.state.topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
@@ -421,14 +439,16 @@ class API:
instance = new_instances[0]
shard_assignments = instance.shard_assignments
node_ids = list(shard_assignments.node_to_runner.keys())
placement_node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if node_ids:
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
@@ -436,7 +456,7 @@ class API:
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
@@ -448,7 +468,14 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
return PlacementPreviewResponse(previews=previews)
@@ -502,9 +529,10 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, stream_options: StreamOptions | None = None
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
include_usage = stream_options.include_usage if stream_options else False
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
@@ -520,8 +548,10 @@ class API:
yield "data: [DONE]\n\n"
return
usage = chunk.usage if include_usage else None
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
chunk, command_id, usage=usage
)
logger.debug(f"chunk_response: {chunk_response}")
@@ -537,8 +567,9 @@ class API:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
@@ -563,6 +594,9 @@ class API:
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.usage is not None:
usage = chunk.usage
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -584,6 +618,7 @@ class API:
finish_reason=finish_reason,
)
],
usage=usage,
)
async def _collect_chat_completion_with_stats(
@@ -591,7 +626,7 @@ class API:
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -644,7 +679,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -653,7 +688,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -671,7 +706,7 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, payload.stream_options),
media_type="text/event-stream",
)
@@ -680,7 +715,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -700,12 +735,12 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await ModelCard.load(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -751,7 +786,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
command = ImageGeneration(
request_params=payload,
@@ -835,6 +870,7 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),
@@ -995,7 +1031,7 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
@@ -1016,7 +1052,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: str,
model: ModelId,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1024,6 +1060,9 @@ class API:
stream: bool,
partial_images: int,
bench: bool,
quality: Literal["high", "medium", "low"],
output_format: Literal["png", "jpeg", "webp"],
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
@@ -1052,6 +1091,9 @@ class API:
stream=stream,
partial_images=partial_images,
bench=bench,
quality=quality,
output_format=output_format,
advanced_params=advanced_params,
),
)
@@ -1086,16 +1128,26 @@ class API:
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1103,6 +1155,9 @@ class API:
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
if stream_bool and partial_images_int > 0:
@@ -1133,12 +1188,22 @@ class API:
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1146,6 +1211,9 @@ class API:
stream=False,
partial_images=0,
bench=True,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
return await self._collect_image_generation_with_stats(
@@ -1257,3 +1325,135 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(
self, payload: StartDownloadParams
) -> StartDownloadResponse:
command = StartDownload(
target_node_id=payload.target_node_id,
shard_metadata=payload.shard_metadata,
)
await self._send_download(command)
return StartDownloadResponse(command_id=command.command_id)
async def delete_download(
self, node_id: NodeId, model_id: ModelId
) -> DeleteDownloadResponse:
command = DeleteDownload(
target_node_id=node_id,
model_id=ModelId(model_id),
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
def _get_traces_dir(self) -> Path:
return Path.home() / ".exo" / "traces"
def _get_trace_path(self, task_id: str) -> Path:
return self._get_traces_dir() / f"trace_{task_id}.json"
async def list_traces(self) -> TraceListResponse:
traces_dir = self._get_traces_dir()
traces: list[TraceListItem] = []
if not traces_dir.exists():
return TraceListResponse(traces=[])
for trace_file in sorted(
traces_dir.glob("trace_*.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
# Extract task_id from filename (trace_{task_id}.json)
task_id = trace_file.stem.removeprefix("trace_")
stat = trace_file.stat()
created_at = datetime.fromtimestamp(
stat.st_mtime, tz=timezone.utc
).isoformat()
traces.append(
TraceListItem(
task_id=task_id,
created_at=created_at,
file_size=stat.st_size,
)
)
return TraceListResponse(traces=traces)
async def get_trace(self, task_id: str) -> TraceResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
return TraceResponse(
task_id=task_id,
traces=[
TraceEventResponse(
name=event.name,
start_us=event.start_us,
duration_us=event.duration_us,
rank=event.rank,
category=event.category,
)
for event in trace_events
],
)
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
stats = compute_stats(trace_events)
return TraceStatsResponse(
task_id=task_id,
total_wall_time_us=stats.total_wall_time_us,
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in stats.by_category.items()
},
by_rank={
rank: TraceRankStats(
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in rank_stats.items()
}
)
for rank, rank_stats in stats.by_rank.items()
},
)
async def get_trace_raw(self, task_id: str) -> FileResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
return FileResponse(
path=trace_path,
media_type="application/json",
filename=f"trace_{task_id}.json",
)

View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from pathlib import Path
import anyio
from anyio.abc import TaskGroup
@@ -11,6 +12,7 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.tracing import TraceEvent, export_trace, is_tracing_enabled
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
@@ -35,6 +37,8 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
TracesCollected,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -86,6 +90,8 @@ class Master:
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
async def run(self):
logger.info("Starting Master")
@@ -187,13 +193,14 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
@@ -201,6 +208,17 @@ class Master:
)
self.command_task_mapping[command.command_id] = task_id
if is_tracing_enabled():
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
@@ -229,13 +247,14 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
@@ -243,6 +262,17 @@ class Master:
)
self.command_task_mapping[command.command_id] = task_id
if is_tracing_enabled():
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
@@ -335,6 +365,10 @@ class Master:
local_event.origin,
)
for event in self._multi_buffer.drain():
if isinstance(event, TracesCollected):
self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
@@ -373,3 +407,38 @@ class Master:
event=event.event,
)
)
def _handle_traces_collected(self, event: TracesCollected) -> None:
task_id = event.task_id
if task_id not in self._pending_traces:
self._pending_traces[task_id] = {}
self._pending_traces[task_id][event.rank] = event.traces
if (
task_id in self._expected_ranks
and set(self._pending_traces[task_id].keys())
>= self._expected_ranks[task_id]
):
self._merge_and_save_traces(task_id)
def _merge_and_save_traces(self, task_id: TaskId) -> None:
all_traces: list[TraceEvent] = []
for trace_data in self._pending_traces[task_id].values():
for t in trace_data:
all_traces.append(
TraceEvent(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
)
output_path = Path.home() / ".exo" / "traces" / f"trace_{task_id}.json"
export_trace(all_traces, output_path)
logger.info(f"Merged traces saved to {output_path}")
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]

View File

@@ -35,7 +35,7 @@ from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
port = random.randint(49153, 65535)
return port - 1 if port <= 52415 else 52414
return port - 1 if port <= 52415 else port
def add_instance_to_placements(
@@ -54,9 +54,18 @@ def place_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
candidate_cycles = [
cycle
for cycle in candidate_cycles
if required_nodes.issubset(cycle.node_ids)
]
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
)

View File

@@ -257,7 +257,13 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

View File

@@ -216,6 +216,8 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -25,6 +25,7 @@ from exo.shared.types.events import (
TestEvent,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TracesCollected,
)
from exo.shared.types.profiling import (
NodeIdentity,
@@ -55,7 +56,11 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
TestEvent()
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| TracesCollected()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -53,3 +53,9 @@ EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "").lower() in (
"1",
"true",
"yes",
)

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated
from typing import Annotated, Any
import aiofiles
import aiofiles.os as aios
@@ -7,7 +7,14 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt, field_validator
from pydantic import (
AliasChoices,
BaseModel,
Field,
PositiveInt,
field_validator,
model_validator,
)
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
@@ -121,6 +128,14 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2.5": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2.5"),
storage_size=Memory.from_gb(617),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -413,9 +428,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
model_id=ModelId("exolabs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -428,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -457,7 +472,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
model_id=ModelId("exolabs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
@@ -470,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +499,49 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -499,9 +556,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
model_id=ModelId("exolabs/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -509,10 +566,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -533,9 +590,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -543,10 +600,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -568,6 +625,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _generate_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
# quantizations = [8, 6, 5, 4, 3]
quantizations = [8, 4]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
model_id = ModelId(base_card.model_id + f"-{quant}bit")
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
@@ -575,15 +718,18 @@ if EXO_ENABLE_IMAGE_MODELS:
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
validation_alias=AliasChoices(
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
)
)
@property
def supports_tensor(self) -> bool:
@@ -598,30 +744,32 @@ class ConfigData(BaseModel):
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
@model_validator(mode="before")
@classmethod
def defer_to_text_config(cls, data: dict[str, Any]):
text_config = data.get("text_config")
if text_config is None:
return data
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
for field in [
"architectures",
"hidden_size",
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
]:
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
data[field] = val
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
return data
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
@@ -643,11 +791,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

View File

@@ -248,8 +248,8 @@ class Topology:
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
Only returns cycles with >=2 nodes (2+ machines in a loop), as
1 node doesn't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
@@ -257,7 +257,7 @@ class Topology:
if status.enabled
}
if len(enabled_nodes) < 3:
if len(enabled_nodes) < 2:
return []
thunderbolt_ips = _get_ips_with_interface_type(
@@ -288,7 +288,7 @@ class Topology:
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) > 2
if len(cycle) >= 2
]

450
src/exo/shared/tracing.py Normal file
View File

@@ -0,0 +1,450 @@
from __future__ import annotations
import json
import logging
import time
from collections import defaultdict
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast, final
from exo.shared.constants import EXO_TRACING_ENABLED
logger = logging.getLogger(__name__)
# Context variable to track the current trace category for hierarchical nesting
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
@final
@dataclass(frozen=True)
class TraceEvent:
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
@dataclass
class CategoryStats:
total_us: int = 0
count: int = 0
min_us: int = 0
max_us: int = 0
def add(self, duration_us: int) -> None:
if self.count == 0:
self.min_us = duration_us
self.max_us = duration_us
else:
self.min_us = min(self.min_us, duration_us)
self.max_us = max(self.max_us, duration_us)
self.total_us += duration_us
self.count += 1
@property
def avg_us(self) -> float:
return self.total_us / self.count if self.count > 0 else 0.0
@final
@dataclass
class TraceStats:
total_wall_time_us: int = 0
by_category: dict[str, CategoryStats] = field(default_factory=dict)
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
# Global trace buffer - each rank accumulates traces here
_trace_buffer: list[TraceEvent] = []
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled via environment variable."""
return EXO_TRACING_ENABLED
def _record_span(
name: str, start_us: int, duration_us: int, rank: int, category: str
) -> None:
_trace_buffer.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
@contextmanager
def trace(
name: str,
rank: int,
category: str = "compute",
) -> Generator[None, None, None]:
"""Context manager to trace any operation.
Nested traces automatically inherit the parent category, creating hierarchical
categories like "sync/compute" or "async/comms".
Args:
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
rank: This rank's ID
category: Category for grouping in trace viewer ("comm", "compute", "step")
Example:
with trace(f"sync {t}", rank, "sync"):
with trace("joint_blocks", rank, "compute"):
# Recorded with category "sync/compute"
hidden_states = some_computation(...)
"""
if not is_tracing_enabled():
yield
return
# Combine with parent category if nested
parent = _current_category.get()
full_category = f"{parent}/{category}" if parent else category
# Set as current for nested traces
token = _current_category.set(full_category)
try:
start_us = int(time.time() * 1_000_000)
start_perf = time.perf_counter()
yield
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
_record_span(name, start_us, duration_us, rank, full_category)
finally:
_current_category.reset(token)
def get_trace_buffer() -> list[TraceEvent]:
return list(_trace_buffer)
def clear_trace_buffer() -> None:
_trace_buffer.clear()
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
trace_events: list[dict[str, object]] = []
for event in traces:
# Chrome trace format uses "X" for complete events (with duration)
chrome_event: dict[str, object] = {
"name": event.name,
"cat": event.category,
"ph": "X",
"ts": event.start_us,
"dur": event.duration_us,
"pid": 0,
"tid": event.rank,
"args": {"rank": event.rank},
}
trace_events.append(chrome_event)
ranks_seen = set(t.rank for t in traces)
for rank in ranks_seen:
trace_events.append(
{
"name": "thread_name",
"ph": "M", # Metadata event
"pid": 0,
"tid": rank,
"args": {"name": f"Rank {rank}"},
}
)
chrome_trace = {"traceEvents": trace_events}
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(chrome_trace, f, indent=2)
except OSError as e:
logger.warning("Failed to export trace to %s: %s", output_path, e)
def export_local_traces(rank: int) -> None:
if not is_tracing_enabled():
return
local_traces = get_trace_buffer()
if local_traces:
output_path = Path.home() / ".exo" / "traces" / f"trace_{rank}.json"
try:
export_trace(local_traces, output_path)
except Exception as e:
logger.warning("Failed to export local traces for rank %d: %s", rank, e)
clear_trace_buffer()
def merge_trace_files(trace_dir: Path | None = None) -> Path | None:
if trace_dir is None:
trace_dir = Path.home() / ".exo" / "traces"
if not trace_dir.exists():
return None
trace_files = sorted(trace_dir.glob("trace_*.json"))
if not trace_files:
return None
merged_events: list[dict[str, object]] = []
for trace_file in trace_files:
file_rank = int(trace_file.stem.split("_")[1])
with open(trace_file) as f:
raw = f.read()
data = cast(dict[str, list[dict[str, object]]], json.loads(raw))
events: list[dict[str, object]] = data.get("traceEvents", [])
for event in events:
event["tid"] = file_rank
if "args" in event and isinstance(event["args"], dict):
event["args"]["rank"] = file_rank
merged_events.extend(events)
output_path = Path.home() / ".exo" / "traces" / "merged_trace.json"
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump({"traceEvents": merged_events}, f, indent=2)
except OSError as e:
logger.warning("Failed to write merged trace to %s: %s", output_path, e)
return None
return output_path
def _format_duration(us: int | float) -> str:
if us < 1000:
return f"{us:.0f}µs"
elif us < 1_000_000:
return f"{us / 1000:.2f}ms"
else:
return f"{us / 1_000_000:.2f}s"
def load_trace_file(path: Path) -> list[TraceEvent]:
"""Load a Chrome Trace Format JSON file into TraceEvent objects."""
with open(path) as f:
data = cast(dict[str, list[dict[str, object]]], json.load(f))
events = data.get("traceEvents", [])
traces: list[TraceEvent] = []
for event in events:
# Skip metadata events
if event.get("ph") == "M":
continue
name = str(event.get("name", ""))
category = str(event.get("cat", ""))
ts_value = event.get("ts", 0)
dur_value = event.get("dur", 0)
tid_value = event.get("tid", 0)
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
# Get rank from tid or args
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
args = event.get("args")
if isinstance(args, dict):
args_dict = cast(dict[str, object], args)
rank_from_args = args_dict.get("rank")
if isinstance(rank_from_args, (int, float, str)):
rank = int(rank_from_args)
traces.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
return traces
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
"""Compute comprehensive statistics from trace events."""
stats = TraceStats()
if not traces:
return stats
# Calculate wall time from earliest start to latest end
min_start = min(t.start_us for t in traces)
max_end = max(t.start_us + t.duration_us for t in traces)
stats.total_wall_time_us = max_end - min_start
# Initialize nested dicts
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
lambda: defaultdict(CategoryStats)
)
for event in traces:
# By category
by_category[event.category].add(event.duration_us)
# By rank and category
by_rank[event.rank][event.category].add(event.duration_us)
stats.by_category = dict(by_category)
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
return stats
def print_stats(stats: TraceStats) -> None:
"""Print formatted trace statistics."""
print("=== Trace Statistics ===")
print()
print(f"Wall Time: {_format_duration(stats.total_wall_time_us)}")
print()
# Parse hierarchical categories (e.g., "sync/compute" -> phase="sync", subcat="compute")
if stats.by_category:
phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
has_hierarchical = False
for cat, cat_stats in stats.by_category.items():
if "/" in cat:
phase, subcat = cat.split("/", 1)
phases[phase][subcat] = cat_stats
has_hierarchical = True
else:
phases[cat]["_total"] = cat_stats
if has_hierarchical:
print("By Phase:")
for phase in sorted(phases.keys()):
subcats = phases[phase]
# Skip phases that only have _total (non-hierarchical top-level categories)
non_total_subcats = {k: v for k, v in subcats.items() if k != "_total"}
if not non_total_subcats:
continue
phase_total = sum(s.total_us for s in non_total_subcats.values())
print(f" {phase}:")
for subcat, subcat_stats in sorted(
non_total_subcats.items(),
key=lambda x: x[1].total_us,
reverse=True,
):
pct = (
subcat_stats.total_us / phase_total * 100 if phase_total else 0
)
# Use parent phase's step count for per-step average
phase_step_count = subcats.get("_total", CategoryStats()).count
if phase_step_count > 0:
avg_per_step = subcat_stats.total_us / phase_step_count
else:
avg_per_step = subcat_stats.avg_us # fallback
print(
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
)
print()
else:
# Fall back to flat category display if no hierarchical categories
print("By Category:")
total_time = sum(c.total_us for c in stats.by_category.values())
for category, cat_stats in sorted(
stats.by_category.items(), key=lambda x: x[1].total_us, reverse=True
):
pct = (cat_stats.total_us / total_time * 100) if total_time > 0 else 0
print(
f" {category:12s} {_format_duration(cat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(cat_stats.avg_us):>8s} "
f"count: {cat_stats.count}"
)
print()
# By Rank
if stats.by_rank:
print("By Rank:")
for rank in sorted(stats.by_rank.keys()):
rank_stats = stats.by_rank[rank]
print(f" Rank {rank}:")
# Parse hierarchical categories for this rank
rank_phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
has_hierarchical = False
for cat, cat_stats in rank_stats.items():
if "/" in cat:
phase, subcat = cat.split("/", 1)
rank_phases[phase][subcat] = cat_stats
has_hierarchical = True
else:
rank_phases[cat]["_total"] = cat_stats
if has_hierarchical:
for phase in sorted(rank_phases.keys()):
subcats = rank_phases[phase]
non_total_subcats = {
k: v for k, v in subcats.items() if k != "_total"
}
if not non_total_subcats:
continue
phase_total = sum(s.total_us for s in non_total_subcats.values())
print(f" {phase}:")
for subcat, subcat_stats in sorted(
non_total_subcats.items(),
key=lambda x: x[1].total_us,
reverse=True,
):
pct = (
subcat_stats.total_us / phase_total * 100
if phase_total
else 0
)
# Use parent phase's step count for per-step average
phase_step_count = subcats.get("_total", CategoryStats()).count
if phase_step_count > 0:
avg_per_step = subcat_stats.total_us / phase_step_count
else:
avg_per_step = subcat_stats.avg_us # fallback
print(
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
)
else:
# Flat display fallback
for category, cat_stats in sorted(
rank_stats.items(), key=lambda x: x[1].total_us, reverse=True
):
print(f" {category}: {_format_duration(cat_stats.total_us)}")
print()
if __name__ == "__main__":
import sys
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("trace.json")
if not path.exists():
print(f"Error: File not found: {path}")
sys.exit(1)
traces = load_trace_file(path)
if not traces:
print("No trace events found in file.")
sys.exit(0)
computed_stats = compute_stats(traces)
print_stats(computed_stats)

View File

@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -115,8 +116,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
class StreamingChoiceResponse(BaseModel):
@@ -169,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
class StreamOptions(BaseModel):
include_usage: bool = False
class ChatCompletionTaskParams(TaggedModel):
model_config = ConfigDict(extra="ignore")
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
@@ -183,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None
@@ -352,3 +360,58 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
class TraceResponse(CamelCaseModel):
task_id: str
traces: list[TraceEventResponse]
class TraceCategoryStats(CamelCaseModel):
total_us: int
count: int
min_us: int
max_us: int
avg_us: float
class TraceRankStats(CamelCaseModel):
by_category: dict[str, TraceCategoryStats]
class TraceStatsResponse(CamelCaseModel):
task_id: str
total_wall_time_us: int
by_category: dict[str, TraceCategoryStats]
by_rank: dict[int, TraceRankStats]
class TraceListItem(CamelCaseModel):
task_id: str
created_at: str
file_size: int
class TraceListResponse(CamelCaseModel):
traces: list[TraceListItem]

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None

View File

@@ -1,7 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -9,7 +10,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
class ImageGeneration(BaseCommand):
@@ -62,6 +63,19 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -79,3 +93,8 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import final
from pydantic import Field
@@ -10,7 +11,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
class EventId(Id):
@@ -109,6 +110,22 @@ class TopologyEdgeDeleted(BaseEvent):
conn: Connection
@final
class TraceEventData(FrozenModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
class TracesCollected(BaseEvent):
task_id: TaskId
rank: int
traces: list[TraceEventData]
Event = (
TestEvent
| TaskCreated
@@ -127,6 +144,7 @@ Event = (
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
)

View File

@@ -0,0 +1,12 @@
"""Shared types for MLX-related functionality."""
from collections.abc import Sequence
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]

View File

@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
ecpu_usage: float = 0.0
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
class NetworkInterfaceInfo(CamelCaseModel):

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)

View File

@@ -6,6 +6,7 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
Usage,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -24,12 +25,14 @@ class GenerationResponse(BaseRunnerResponse):
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -44,6 +47,7 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -55,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -349,13 +349,8 @@ class InfoGatherer:
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await self.info_sender.send(await MiscData.gather())
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
@@ -365,15 +360,12 @@ class InfoGatherer:
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
@@ -398,22 +390,17 @@ class InfoGatherer:
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
nics = await get_network_interfaces()
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_thunderbolt_bridge_status(self):
if self.thunderbolt_bridge_poll_interval is None:
return
prev: ThunderboltBridgeInfo | None = None
while True:
curr = await ThunderboltBridgeInfo.gather()
if curr is not None and prev != curr:
prev = curr
if curr is not None:
await self.info_sender.send(curr)
await anyio.sleep(self.thunderbolt_bridge_poll_interval)

View File

@@ -1,6 +1,6 @@
import socket
import sys
from subprocess import CalledProcessError, run
from subprocess import CalledProcessError
import psutil
from anyio import run_process
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
"""
hostname = socket.gethostname()
# TODO: better non mac support
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
if sys.platform != "darwin":
return hostname
try:
@@ -28,21 +27,20 @@ async def get_friendly_name() -> str:
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
"""Parse networksetup -listallhardwareports to get interface types."""
if sys.platform != "darwin":
return {}
try:
result = run(
["networksetup", "-listallhardwareports"], capture_output=True, text=True
)
except Exception:
result = await run_process(["networksetup", "-listallhardwareports"])
except CalledProcessError:
return {}
types: dict[str, InterfaceType] = {}
current_type: InterfaceType = "unknown"
for line in result.stdout.splitlines():
for line in result.stdout.decode().splitlines():
if line.startswith("Hardware Port:"):
port_name = line.split(":", 1)[1].strip()
if "Wi-Fi" in port_name:
@@ -55,12 +53,15 @@ def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
current_type = "unknown"
elif line.startswith("Device:"):
device = line.split(":", 1)[1].strip()
# enX is ethernet adapters or thunderbolt - these must be deprioritised
if device.startswith("en") and device not in ["en0", "en1"]:
current_type = "maybe_ethernet"
types[device] = current_type
return types
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -68,7 +69,7 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
interface_types = _get_interface_types_from_networksetup()
interface_types = await _get_interface_types_from_networksetup()
for iface, services in psutil.net_if_addrs().items():
for service in services:

View File

@@ -0,0 +1,32 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
class KeyedBackoff(Generic[K]):
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):
self._base = base
self._cap = cap
self._attempts: dict[K, int] = {}
self._last_time: dict[K, float] = {}
def should_proceed(self, key: K) -> bool:
"""Returns True if enough time has elapsed since last attempt."""
now = time.monotonic()
last = self._last_time.get(key, 0.0)
attempts = self._attempts.get(key, 0)
delay = min(self._cap, self._base * (2.0**attempts))
return now - last >= delay
def record_attempt(self, key: K) -> None:
"""Record that an attempt was made for this key."""
self._last_time[key] = time.monotonic()
self._attempts[key] = self._attempts.get(key, 0) + 1
def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
self._last_time.pop(key, None)

View File

@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
@@ -140,6 +140,7 @@ class DistributedImageModel:
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)

View File

@@ -75,19 +75,20 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
seed = advanced_params.seed
base_seed = advanced_params.seed
else:
seed = random.randint(0, 2**32 - 1)
base_seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
num_images = task.n or 1
generation_start_time: float = 0.0
@@ -95,7 +96,11 @@ def generate_image(
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = task.partial_images or (3 if task.stream else 0)
partial_images = (
task.partial_images
if task.partial_images is not None and task.stream is not None and task.stream
else 0
)
image_path: Path | None = None
@@ -105,72 +110,81 @@ def generate_image(
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
)
else:
image = result
peak_memory_gb = mx.get_peak_memory() / (1024**3)
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images
seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)

View File

@@ -33,6 +33,7 @@ _ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,

View File

@@ -6,6 +6,11 @@ from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.tracing import (
clear_trace_buffer,
is_tracing_enabled,
trace,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
@@ -324,6 +329,7 @@ class DiffusionRunner:
capture_steps = set()
self._reset_all_caches()
clear_trace_buffer()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -348,6 +354,7 @@ class DiffusionRunner:
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
time_steps=time_steps,
)
mx.eval(latents)
@@ -464,20 +471,22 @@ class DiffusionRunner:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
with trace(name=f"async {t}", rank=self.rank, category="async"):
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
def _single_node_step(
self,
@@ -585,30 +594,41 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name="joint_blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(encoder_hidden_states, hidden_states)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -619,45 +639,63 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
with trace(
name="single blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(hidden_states)
if not self.is_last_stage:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -741,14 +779,20 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
with trace(name="send 0", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, 0, group=self.group
)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
with trace(
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -808,10 +852,13 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -842,10 +889,13 @@ class DiffusionRunner:
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -884,22 +934,28 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
mx.eval(patch)
if patch_idx == 0:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -908,14 +964,22 @@ class DiffusionRunner:
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name=f"joint patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(encoder_hidden_states, patch)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -924,49 +988,70 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name=f"single patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(patch)
if not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
patch = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -19,8 +19,11 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
@@ -145,6 +148,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -194,6 +201,9 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
for layer in layers:
mx.eval(layer) # type: ignore
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
@@ -252,10 +262,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call
@@ -334,15 +340,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -350,8 +348,7 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -367,6 +364,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GLM4MoeLiteModel):
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
@@ -441,7 +446,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -452,7 +457,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
@@ -516,6 +521,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
mx.eval(layer)
return model
@@ -533,6 +540,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GLM4MoeLiteModel, model)
for layer in model.layers: # type: ignore
layer = cast(Glm4MoeLiteDecoderLayer, layer)
eval_with_timeout(
layer.parameters(),
timeout_seconds / len(model.layers), # type: ignore
on_timeout,
)
if layer.self_attn.q_lora_rank is None: # type: ignore
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Logic from upstream mlx
num_heads = layer.self_attn.num_heads
sh = self.group.rank() * num_heads
eh = sh + num_heads
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
return w[sh:eh]
layer.self_attn.embed_q.apply(shard_heads)
layer.self_attn.unembed_out.apply(shard_heads)
if isinstance(layer.mlp, Glm4MoeLiteMLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
else:
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # type: ignore
mx.eval(layer)
return model
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -541,6 +626,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
rank = self.group.rank()
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -550,6 +636,16 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Shard qk_norm weights if present (must match sharded head count)
if getattr(layer.self_attn, "use_qk_norm", False):
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
@@ -566,7 +662,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
@@ -607,6 +703,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
mx.eval(layer)
return model
@@ -661,7 +758,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model

View File

@@ -1,104 +1,234 @@
# type: ignore
# TODO: Fix this file, including types!
import os
from copy import deepcopy
from typing import Callable
from typing import Any, cast
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
trim_prompt_cache,
)
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._last_used.clear()
def add_kv_cache(self, prompt: str, cache: KVCacheType):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
def update_kv_cache(
self,
index: int,
prompt: str,
cache: KVCacheType,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
self.prompts[index] = tokenized_prompt
self.caches[index] = deepcopy(cache)
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
Returns:
Tuple of (cache, remaining_tokens, matched_index) where:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
"""
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
return self.caches[i]
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
return prompt_cache
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
# Evict LRU entries until below threshold or only one entry left
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
)
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]

View File

@@ -4,7 +4,7 @@
KV_GROUP_SIZE: int | None = 32
KV_BITS: int | None = None
ATTENTION_KV_BITS: int | None = 4
MAX_TOKENS: int = 8192
MAX_TOKENS: int = 32168
MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"

View File

@@ -1,48 +1,94 @@
import time
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
CompletionTokensDetails,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> tuple[float, int]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0, 0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec, num_tokens
def warmup_inference(
@@ -120,18 +166,36 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
logger.info(f"{is_bench=}")
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
if task.seed is not None:
mx.random.seed(task.seed)
caches = make_kv_cache(model=model)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
prefix_hit_length = 0
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -144,28 +208,54 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
):
generated_text_parts.append(out.text)
logger.info(out.text)
if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
@@ -177,14 +267,47 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
else:
kv_prefix_cache.add_kv_cache(full_prompt, caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -18,15 +18,12 @@ try:
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TRUST_REMOTE_CODE,
)
@@ -41,6 +38,7 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -55,7 +53,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
@@ -168,12 +165,11 @@ def mlx_distributed_init(
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -262,10 +258,10 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
timeout_seconds = base_timeout + model_size_gb
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
@@ -342,8 +338,35 @@ def load_tokenizer_for_model_id(
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
import importlib.util
import types
sys.path.insert(0, str(model_path))
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # type: ignore[attr-defined] # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@@ -405,7 +428,11 @@ def apply_chat_template(
continue
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
if (
message.content is None
and message.thinking is None
and message.tool_calls is None
):
continue
# Null values are not valid when applying templates in tokenizer
@@ -462,31 +489,6 @@ class NullKVCache(KVCache):
raise NotImplementedError("We should not be setting a NullKVCache.")
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
return [KVCache() for _ in model.layers]
else:
logger.info("Using quantized KV cache")
return [
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
for _ in model.layers
]
else:
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
def mlx_force_oom(size: int = 40000) -> None:
"""
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.

View File

@@ -1,8 +1,9 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -10,7 +11,12 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
@@ -18,7 +24,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -36,23 +41,12 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -62,7 +56,6 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -70,23 +63,22 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.local_event_index = 0
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -101,6 +93,8 @@ class Worker:
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -111,7 +105,6 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -121,6 +114,7 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -179,11 +173,9 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -207,42 +199,26 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -387,104 +363,17 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_with_error_handling() -> None:
try:
await self.shard_downloader.ensure_shard(task.shard_metadata)
except Exception as e:
error_message = str(e)
logger.error(
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
)
failed_status = DownloadFailed(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
error_message=error_message,
)
self.download_status[task.shard_metadata.model_card.model_id] = (
failed_status
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed_status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Failed
)
)
self._tg.start_soon(download_with_error_handling)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
@@ -532,42 +421,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,7 +2,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
@@ -45,9 +44,6 @@ def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
@@ -59,7 +55,7 @@ def plan(
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
@@ -115,9 +111,15 @@ def _create_runner(
def _model_needs_download(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> DownloadModel | None:
local_downloads = global_download_status.get(node_id, [])
download_status = {
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
}
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (

View File

@@ -18,6 +18,7 @@ from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer, is_tracing_enabled
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -27,6 +28,8 @@ from exo.shared.types.events import (
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
)
from exo.shared.types.tasks import (
ChatCompletion,
@@ -37,6 +40,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -70,6 +74,7 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -103,14 +108,19 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
seen = set[TaskId]()
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -161,6 +171,8 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
@@ -170,7 +182,6 @@ def main(
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -238,12 +249,9 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
@@ -257,10 +265,16 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
@@ -271,9 +285,11 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
completion_tokens = 0
for response in mlx_generator:
match response:
case GenerationResponse():
completion_tokens += 1
if (
device_rank == 0
and response.finish_reason == "error"
@@ -301,6 +317,7 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
usage=response.usage,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -314,6 +331,7 @@ def main(
chunk=ToolCallChunk(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
),
)
)
@@ -393,6 +411,10 @@ def main(
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
@@ -451,6 +473,10 @@ def main(
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
@@ -489,9 +515,10 @@ def get_gpt_oss_encoding():
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
@@ -501,17 +528,44 @@ def filter_kimi_tokens(
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
],
usage=response.usage,
)
tool_arg_parts = []
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
@@ -528,13 +582,12 @@ def parse_gpt_oss(
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
def parse_thinking_models(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
) -> Generator[GenerationResponse | ToolCallResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -542,6 +595,9 @@ def parse_thinking_models(
"""
first = True
for response in responses:
if isinstance(response, ToolCallResponse):
yield response
continue
if first:
first = False
yield response.model_copy(
@@ -595,6 +651,36 @@ def _send_image_chunk(
)
def _send_traces_if_enabled(
event_sender: MpSender[Event],
task_id: TaskId,
rank: int,
) -> None:
if not is_tracing_enabled():
return
traces = get_trace_buffer()
if traces:
trace_data = [
TraceEventData(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in traces
]
event_sender.send(
TracesCollected(
task_id=task_id,
rank=rank,
traces=trace_data,
)
)
clear_trace_buffer()
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
@@ -612,7 +698,7 @@ def _process_image_response(
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.partial_index if is_partial else image_index,
image_index=response.image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
@@ -622,7 +708,7 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse],
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -630,6 +716,7 @@ def parse_tool_calls(
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True
@@ -647,7 +734,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
except (
json.JSONDecodeError,

View File

@@ -127,20 +127,25 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
f"Skipping invalid task {task} as it has already been submitted"
)
return
if task.task_id in self.completed:
logger.info(
logger.warning(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
try:
self._task_sender.send(task)
await self._task_sender.send_async(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:

View File

@@ -0,0 +1,545 @@
# type: ignore
import time
from typing import cast
from unittest.mock import patch
import mlx.core as mx
import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
DEFAULT_GPT_OSS_MODEL_ID,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
class TestKVPrefix:
@pytest.fixture
def mock_tokenizer(self):
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
from unittest.mock import MagicMock
tokenizer = MagicMock()
tokenizer.encode.return_value = [1, 2, 3]
return tokenizer
def test_starts_empty(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self, mock_tokenizer):
cache = KVPrefixCache(mock_tokenizer)
cache.clear()
assert len(cache.prompts) == 0
def _load_gpt_oss() -> tuple[Model, object]:
from mlx_lm.utils import load_model
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
model, _ = load_model(model_path, lazy=False)
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
return cast(Model, model), tokenizer
@pytest.mark.slow
@pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
)
class TestKVPrefixCacheWithModel:
@pytest.fixture(scope="class")
def model_and_tokenizer(self):
model, tokenizer = _load_gpt_oss()
return model, tokenizer
def test_prefill_populates_cache(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Test exact")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
model, tokenizer = model_and_tokenizer
short_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hi")],
max_tokens=1,
)
short_prompt = apply_chat_template(tokenizer, short_task)
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(short_prompt, cache)
# Query with longer prompt that shares the chat template prefix
long_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content="Hi there, how are you?")
],
max_tokens=1,
)
long_prompt = apply_chat_template(tokenizer, long_task)
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, long_prompt
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
):
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra_keys, extra_values)
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
):
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra = mx.random.normal((1, num_heads, 30, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# Consume the entire generator so the cache-saving code after yield runs
generated_tokens = 0
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
generated_tokens += 1
assert len(kv_prefix_cache.prompts) == 1
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
assert len(kv_prefix_cache.prompts) == 1
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, prompt
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
repeats = (1200 // len(base_tokens)) + 2
long_content = base_text * repeats
task1 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=long_content)],
max_tokens=5,
)
prompt1 = apply_chat_template(tokenizer, task1)
prompt1_tokens = encode_prompt(tokenizer, prompt1)
assert len(prompt1_tokens) > 1000, (
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
)
# First generation populates the cache (must prefill all tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content=long_content),
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
ChatCompletionMessage(role="user", content="Tell me more."),
],
max_tokens=5,
)
prompt2 = apply_chat_template(tokenizer, task2)
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
):
pass
second_gen_time = time.perf_counter() - t0
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
assert second_gen_time < first_gen_time * 0.5, (
f"Expected prefix cache speedup: "
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
)
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
firstcache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
# The first stored cache must not have been mutated by the second generation
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
for i, content in enumerate(prompts):
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=content)],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
assert len(kv_prefix_cache.prompts) == 3
# Access the third entry to make it most recently used
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
):
# Trigger eviction by adding a new entry
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="New entry")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(prompt, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since fake_active stays above threshold after each eviction (we don't change it),
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,13 +45,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -92,14 +88,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -116,7 +104,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -148,30 +135,26 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
# Global state shows shard is downloaded for NODE_A
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
@@ -202,12 +185,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -220,7 +197,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -245,7 +221,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,

View File

@@ -47,8 +47,7 @@ def test_plan_kills_runner_when_instance_missing():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -87,8 +86,7 @@ def test_plan_kills_runner_when_sibling_failed():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -120,7 +118,6 @@ def test_plan_creates_runner_when_missing_for_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -158,8 +155,7 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
runners=runners, # type: ignore[arg-type]
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -189,7 +185,6 @@ def test_plan_does_not_create_runner_for_unassigned_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,

View File

@@ -65,7 +65,6 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -113,7 +112,6 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -158,7 +156,6 @@ def test_plan_does_not_forward_tasks_for_other_instances():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -221,7 +218,6 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -261,7 +257,6 @@ def test_plan_returns_none_when_nothing_to_do():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -57,7 +57,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -99,7 +98,6 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -140,7 +138,6 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -185,7 +182,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -202,7 +198,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -246,7 +241,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -289,7 +283,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -331,7 +324,6 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
@@ -147,6 +147,14 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -182,6 +190,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)

View File

@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint

View File

@@ -11,7 +11,6 @@ if [[ $# -lt 2 ]]; then
exit 1
fi
kind=$1
shift
@@ -31,14 +30,14 @@ for name in "${hostnames[@]}"; do
weaved+=("$name" "$ip")
done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
@@ -48,9 +47,8 @@ for model_id in "${model_ids[@]}"; do
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

View File

@@ -0,0 +1,18 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}

47
tmp/set_rdma_network_config.sh Executable file
View File

@@ -0,0 +1,47 @@
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports |
awk -F': ' '/Hardware Port: / {print $2}' |
while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*) ;;
"Thunderbolt Bridge") ;;
"Thunderbolt "*)
networksetup -listallnetworkservices |
grep -q "EXO $name" ||
networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null ||
continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices |
grep -q "$name" ||
networksetup -createnetworkservice "$name" "$name" 2>/dev/null ||
continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

2191
uv.lock generated
View File

File diff suppressed because it is too large Load Diff