Compare commits

..

33 Commits

Author SHA1 Message Date
Evan
ff8709e4c2 and add minimax 2026-01-14 11:11:33 +00:00
Evan
c6e42d4918 add glm-47 2026-01-14 11:11:33 +00:00
Jake Hillion
bdb43e1dbb nix: drop noisy echos from devshell
Drop all the printing when entering a devshell. It's annoying, and not a
super accurate description of how to develop exo anyway.
2026-01-14 10:04:57 +00:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input as Swift is currently broken
in nixpkgs on Linux for `swift-format` as we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Evan Quiney
1200a7db64 Add tensor sharding for GPT-OSS (#1144)
## Motivation

GPT OSS did not previously support tensor sharding

## Changes

Add GPT sharding support in tensor_auto_parallel.
Code is mostly @rltakashige's

## Test Plan

### Manual Testing
Tested GPT-OSS - MLX Fast Sync causes issues in Tensor RDMA - this is a general problem at the moment.
2026-01-13 17:25:52 +00:00
Evan Quiney
47ceb54bc1 up the rlimit (#1148)
Fixes #1117 

Manual testing:
Launched 100 instances. worked. yay.
2026-01-13 15:00:54 +00:00
Jake Hillion
f8112fdf25 nix: convert to flake-parts
Preparing to add a flake-parts module for Rust builds. The flake-utils
library doesn't support the module system needed for cleanly separating
the Rust build configuration.

Converted from flake-utils to flake-parts, switching to the treefmt-nix
flakeModule import pattern. The devShell and formatter outputs remain
functionally equivalent.

Test plan:
- Ran `nix flake check` successfully
- Verified `nix develop` provides the same environment
2026-01-13 15:06:44 +01:00
Alex Cheema
e388f59480 docs: add AGENTS.md for AI coding agents guidance (#1132)
## Motivation

Add documentation to help AI coding agents (Claude Code, Cursor, GitHub
Copilot, etc.) understand the exo codebase and contribute effectively.

## Changes

- Add `AGENTS.md` with guidance for AI agents working on the codebase
- Add symlink `CLAUDE.md -> AGENTS.md` for backwards compatibility with
Claude Code

## Why It Works

`AGENTS.md` is becoming a standard convention for AI agent instructions.
The symlink ensures Claude Code (which looks for `CLAUDE.md`) continues
to work while supporting the broader `AGENTS.md` convention.

## Test Plan

### Manual Testing
- Verified symlink works correctly

### Automated Testing
- N/A (documentation only)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 13:05:47 +00:00
Alex Cheema
e5e74e1eef Upgrade mlx-lm to 0.30.2 with transformers 5.x compatibility (#1125)
## Motivation

Upgrade mlx-lm to version 0.30.2 which requires transformers 5.0.0rc2 as
a prerelease dependency. This enables support for newer models like Kimi
K2 Thinking while maintaining compatibility with existing models.

The transformers 5.x release includes breaking changes that affect
custom tokenizers like Kimi's TikTokenTokenizer, requiring compatibility
fixes.

## Changes

### Core Changes
- **mlx-lm upgrade**: Bump to 0.30.2 with locked exact versions for
mlx/mlx-lm to prevent breaking changes
- **transformers 5.x compatibility**: Enable prerelease transformers
dependency

### Kimi K2 Tokenizer Fixes
- Add `bytes_to_unicode` monkey-patch to restore function moved in
transformers 5.0.0rc2
- Load `TikTokenTokenizer` directly instead of via `AutoTokenizer` to
bypass transformers 5.x bug with `auto_map` fallback
- Patch `encode()` to use tiktoken directly with `allowed_special="all"`
to handle special tokens from chat templates

### Other Changes
- Dashboard: Show disk usage for completed model downloads
- CI: Add `workflow_dispatch` trigger to build-app workflow
- Docs: Add basic API documentation

### Testing
- Add comprehensive tokenizer unit tests for all supported models
- Tests verify encode/decode, special token handling, and chat template
encoding

## Why It Works

**bytes_to_unicode issue**: transformers 5.0.0rc2 moved
`bytes_to_unicode` from `transformers.models.gpt2.tokenization_gpt2` to
`transformers.convert_slow_tokenizer`. Kimi's `tokenization_kimi.py`
imports from the old location. The monkey-patch restores it at module
load time.

**AutoTokenizer issue**: transformers 5.x has a bug where
`tokenizer_class_from_name('TikTokenTokenizer')` returns `None` for
custom tokenizers with `auto_map`. Loading the tokenizer directly
bypasses this.

**encode() issue**: transformers 5.x's `pad()` method fails for slow
tokenizers. Using tiktoken's encode directly with
`allowed_special="all"` avoids this path and properly handles special
tokens like `<|im_user|>` from chat templates.

## Test Plan

### Manual Testing
- Hardware: 2x Mac Studios connected via Thunderbolt 5 (mike22 and
james21)
- Tested Kimi K2 Thinking, GPT-OSS-120B, GPT-OSS-20B, LLama-3.1-8B-bf16, qwen3-30B-A3B-8bit model with pipeline parallelism across both
nodes
- Verified warmup inference completes successfully
- Verified chat completions work with special tokens

### Automated Testing
- Added `test_tokenizers.py` with 31 tests covering:
- Basic encode/decode for all model families (deepseek, kimi, llama,
qwen, gpt-oss, glm)
  - Special token encoding (critical for chat templates)
  - Chat template application and encoding
  - Kimi-specific and GLM-specific edge cases
- All tests pass: `uv run pytest
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py`

### Failing Tests
RDMA with all models.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-13 12:06:04 +00:00
Jake Hillion
b968d6f0a0 ci: remove old commented out job 2026-01-13 12:42:04 +01:00
Jake Hillion
3bfffd9b4f ci: build all Nix outputs on all platforms and push to cachix
The CI was only running `nix flake check` on ubuntu-latest, missing
builds for other platforms and not caching packages or devShells.

Added a matrix-based `nix-build` job that runs on macos-26 (aarch64-darwin),
ubuntu-latest (x86_64-linux), and ubuntu-24.04-arm (aarch64-linux). Each
job enumerates all packages and devShells via `nix flake show --json`,
builds them in a single `nix build` call for parallelization, then runs
`nix flake check`. The cachix-action pushes all built outputs automatically.

This ensures all Nix outputs are built and cached for every supported
platform, speeding up local development and CI runs.

Test plan:
- Tested jq enumeration command locally, correctly outputs devShell paths
- Verified xargs pipeline works with the enumerated outputs
2026-01-13 12:37:12 +01:00
Jake Hillion
007eb80029 nix: enable cachix
Enable cachix and push to it in the pipeline.yml workflow. This won't
cache a huge amount yet but will automatically extend our caching as we
build more of the repo with Nix in CI. It can also be used by local
users by accepting our cache to improve the speed of local builds.

Test plan:
- CI
2026-01-12 17:24:59 +01:00
Jake Hillion
8d7b6789b3 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but provided no indication of how much disk space each model or the
total models on a node were using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 16:34:29 +01:00
Jake Hillion
3c5b7ea670 ci: add workflow_dispatch trigger to build-app
Build app is the most convenient way to get a DMG for testing, but
currently it's a bit limited. You have to push to test-app every time
which is far from ideal and requires a bit too much force pushing for my
liking.

Add the workflow_dispatch trigger. This adds a button in the actions UI
to trigger a workflow for a named branch, which means you can use your
normal dev branch instead of having to push to test-app. We'll leave
that behaviour there for now too, though it may change in future.

Filter on `"${{ github.event_name }}" == "workflow_dispatch"` and set
those to alpha as well. Will verify by pushing the first version from
`main` just in case. Unfortunately we do have to merge this before we
can test it.

Test plan:
- Looking really hard.
2026-01-12 12:14:21 +01:00
PG
b74a610537 Add a basic documentation to the api interface (#1122)
## Motivation

Adds basic api documentation

## Changes

- Add docs/api.md
- Modify README.md
2026-01-11 18:44:40 +00:00
Jake Hillion
18c4e49f91 nix: put treefmt in devshell
treefmt is a useful to be able to access directly for some formatters like
`jj fix`. Expose it in the devshell.

Test plan:
- Used with `jj fix` on a large branch. It worked.
2026-01-09 17:53:50 +01:00
Sami Khan
d85b5d3781 feat: uninstall button (#1077)
## Motivation

https://github.com/exo-explore/exo/issues/1075

## Changes

- Added in-app "Uninstall" option under Advanced menu that cleanly
removes all system components
- Added NetworkSetupHelper.uninstall() to remove LaunchDaemon, scripts,
logs, and restore network settings
- Added LaunchAtLoginHelper.disable() to unregister from login items
- Created standalone uninstall-exo.sh script for users who already
deleted the app
- Added uninstall documentation to README

<img width="386" height="577" alt="image"
src="https://github.com/user-attachments/assets/6bbcd18a-992a-409d-8791-ed5e13bbcfe0"
/>
<img width="372" height="432" alt="image"
src="https://github.com/user-attachments/assets/ee76b45d-c111-4807-ab28-3f2f20e01140"
/>


## Why It Works

The in-app uninstaller runs a privileged shell script (via AppleScript)
to launchctl bootout the daemon, remove files, and restore the
"Automatic" network location. The standalone script provides the same
cleanup for users who already deleted the app.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
- Built and ran app, verified LaunchDaemon and network location were
created
- Used in-app Uninstall, verified all components removed and network
restored to Automatic
- Rebuilt app, quit normally, ran sudo ./uninstall-exo.sh, verified same
cleanup

### Automated Testing
N/A

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-09 14:49:08 +00:00
Evan Quiney
caafc48693 Forward tools to the models chat template properly (#1106)
We did not properly forward tools to the chat template before. This is not a full tool calling impl - but it should improve things slightly.

## Changes made

Pass tools to the hf tokenizers chat template
Join message chunks into a larger message (opencode does this sometimes - we were ignoring before)

## Future work

We need to parse the model output and normalise the return format to be compatible with the openai api.
2026-01-09 13:28:41 +00:00
Evan
cca8c9984a cleanup unused dependencies
we have a lot of dependencies we have no intent of using. kill them with
fire!

## testing
exo still launches and does the worst inference known to man on my Qwen3
instance. tests pass too!!
2026-01-09 13:11:58 +00:00
Sami Khan
d1e88def42 scrollbars fixed (#1113)
## Motivation

Fixes https://github.com/exo-explore/exo/issues/1107 - Horizontal
scrollbar always appears in instances section, and vertical scrollbar
appears too early (with just 1-2 instances on large screens).


## Changes

- Added overflow-x-hidden to remove horizontal scrollbar
- Added xl:max-h-96 for responsive vertical height (384px on xl+ screens
vs 288px default)
- Added py-px to accommodate corner accent decorations that extend 1px
outside cards

## Why It Works

- overflow-x-hidden prevents horizontal scroll regardless of content
- Larger max-height on xl screens fits 2 instances without scrollbar;
3rd triggers it
- 1px vertical padding accommodates the -top-px/-bottom-px positioned
corner accents that caused tiny overflow

## Test Plan

### Manual Testing
<img width="1190" height="868" alt="image"
src="https://github.com/user-attachments/assets/2a582328-5b4f-4490-a488-52106f2e85ef"
/>

### Automated Testing
N/A
2026-01-09 12:51:05 +00:00
Sami Khan
59e7594e34 UNKNOWN to PREPARING (#1112)
## Motivation

The "UNKNOWN" status shown when first launching an instance is confusing
and unhelpful. "PREPARING" better describes what's actually happening.

![telegram-cloud-photo-size-4-5981245965962251168-x](https://github.com/user-attachments/assets/65b0802b-fb64-4fa7-bff7-c13757035b3a)


## Changes

- Renamed status from "UNKNOWN" to "PREPARING" in dashboard
(+page.svelte)
- Renamed unknown state to preparing in macOS app
(InstanceViewModel.swift, InstanceRowView.swift)

## Why It Works

The status appears when an instance exists but runners haven't reported
status yet. "PREPARING" accurately describes this transitional state.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
<img width="319" height="200" alt="image"
src="https://github.com/user-attachments/assets/9a1c3caf-026d-47ea-80d1-63c6e41d93aa"
/>

### Automated Testing
N/A
2026-01-09 11:46:51 +00:00
Chris A
c65320acd3 Fix mlx seed (#1094)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-09 01:40:15 +00:00
Jake Hillion
b9a78f6f3a ci: compute CURRENT_PROJECT_VERSION from semver
Previous Sparkle builds were cut from a different repo with different
build numbers, breaking version ordering. Users aren't receiving updates
because CFBundleVersion values don't reflect the actual version sequence.

Added a step to compute the build version deterministically from semver:
PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR).
Release versions use prerelease=999 to ensure they're always higher than
their prereleases (e.g., 1.0.61 > 1.0.61-alpha.3).

This ensures consistent version ordering across repos, allowing Sparkle
to correctly identify and deliver updates to users.

Test plan:
- Verified formula with test script:

```sh
compute_version() {
  VERSION="$1"
  BASE_VERSION="${VERSION%%-*}"
  MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
  MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
  PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)

  if [[ "$VERSION" == *-* ]]; then
    PRERELEASE_PART="${VERSION#*-}"
    PRERELEASE_NUM="${PRERELEASE_PART##*.}"
    if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
      PRERELEASE_NUM=0
    fi
  else
    PRERELEASE_NUM=999
  fi

  BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
  printf "%-20s -> %12s\n" "$VERSION" "$BUILD_VERSION"
}

compute_version "1.0.61-alpha.2"
compute_version "1.0.61-alpha.3"
compute_version "1.0.61"
compute_version "1.0.62-alpha.1"
compute_version "1.1.0-alpha.1"
compute_version "2.0.0-alpha.1"
compute_version "0.0.0-alpha.0"
compute_version "0.0.1-alpha.1"
compute_version "1.2.3"
compute_version "1.2.3-beta.5"
```

- Output:

```sh
Version              -> Build Number
----------------------------------------
1.0.61-alpha.2       ->   1000061002
1.0.61-alpha.3       ->   1000061003
1.0.61               ->   1000061999
1.0.62-alpha.1       ->   1000062001
1.1.0-alpha.1        ->   1001000001
2.0.0-alpha.1        ->   2000000001
0.0.0-alpha.0        ->            0
0.0.1-alpha.1        ->         1001
1.2.3                ->   1002003999
1.2.3-beta.5         ->   1002003005
```

- Confirmed ordering: alpha.2 < alpha.3 < release < next-alpha
2026-01-08 19:52:33 +01:00
Jake Hillion
8f7f0e893a ci: avoid uploading alpha appcasts
Currently alpha appcasts get uploaded. It turns out these overwrite the
standard appcast, so even though no one will update to the alpha
channel, everyone will miss regular updates while the latest build was
an alpha one.

Ideally we should combine the source of truth for both the alpha and
release channels, but as no one is using the alpha channel for yet let's
stop uploading it for now.

Test plan:

![eyes](https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExeGNwdDk0dmdscjlkZnd6eGxhcjJzdDBsYndmc2t2cnlpZDNxZnZhYSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/gKHGnB1ml0moQdjhEJ/giphy.gif)
2026-01-08 18:52:10 +01:00
Alex Cheema
4759b09d4c Use presigned URLs for bug report uploads (#1109)
## Motivation

Previously we hardcoded AWS credentials into the app.
This is not good practice.

## Changes

Use presigned URLs instead.

## Why It Works

Presigned URLs are an S3 feature for this kind of thing. They provide an
expiring presigned URL with certain permissions. In this case we have a
presigned URL with `s3:PutObject` permission that expires after 5
minutes. The client uses this presigned URL to upload a bug report
instead of using its own credentials to sign a request. This also
simplifies a lot of the Swift code.

## Test Plan

### Manual Testing
On a single MacBook, I downloaded the app and sent a bug report. It
worked and appeared in the bucket.
2026-01-08 17:17:48 +00:00
Alex Cheema
ca680185f3 Display RDMA debug info in macOS app. (#1072)
## Motivation

Often users are running into issues with RDMA. See
https://github.com/exo-explore/exo/issues?q=is%3Aissue%20rdma
Having some debug info in the macOS app will help to debug these issues.

## Changes

Displays output of the following commands in the debug info section of
the macOS app:

1. `rdma_ctl status`
2. `ibv_devices`
3. `ibv_devinfo`

## Why It Works

It displays RDMA debug info in the debug info section of the macOS app.

## Test Plan

### Manual Testing
We need to make a new build of the macOS app and check the output under
the following conditions:

1. No RDMA enabled.
2. RDMA enabled but no devices connected over TB5.
3. RDMA enabled and devices connected over TB5.
2026-01-08 15:17:00 +00:00
Jake Hillion
383309e24e fmt: add typescript formatting
Add typescript auto formatting with Prettier and treefmt-nix. Added a
.prettierrc to useTabs, which isn't the default, to reduce churn. The
rest looks okay and will be checked by CI.

Test plan:
- CI
2026-01-08 13:47:27 +00:00
Jake Hillion
55463a9806 fmt: add swift formatting
Swift code currently has no auto formatting. Add `swift-format` to the
`treefmt-nix` config to get this formatted.

As our existing Swift code uses 4-space formatting instead of the
default 2-space, also adds a custom `.swift-format

Test plan:
- CI
2026-01-08 13:34:45 +00:00
Evan Quiney
56af61fac9 add a server for distributed testing in /tests until we work out a stable solution. (#1098)
## Motivation

Testing multiple devices simultaneously requires coordination, and we
don't necessarily want to run a full EXO to test single components. We
need a mid-scale integration testing framework for distributed tests.

## Changes

Add a simple python server + bash query that runs Jaccl and Ring tests
without constructing a worker/master/networking. The query relies on all
devices being accessible over tailscale, currently.

## Test Plan

Manually tested RDMA + Ring inference on 2 nodes.
2026-01-08 12:50:04 +00:00
Evan Quiney
f76d543d98 We shouldn't fail on an HTTPException in the tier-2 discovery system. (#1104)
## Motivation

Fixed a crash we found

## Changes

try/catch return None if we get an exception instead of crashing exo

## Test Plan

### Manual Testing
Exo launches. Couldn't repro the original case this arose.
2026-01-08 12:43:34 +00:00
Sami Khan
ea841aca37 local network check (#1103)
## Motivation

After machine restart, macOS local network permission can appear enabled
in System Settings but not actually work. EXO fails to discover other
machines, and the only fix is manually toggling the permission off/on
and relaunching. Users had no way to know this was happening.

## Changes

- Added LocalNetworkChecker service that detects if local network access
is actually functional
- Added warning banner with instructions and "Open Settings" button when
blocked
- Added NSLocalNetworkUsageDescription and NSBonjourServices to
Info.plist (required by macOS)

<img width="386" height="712" alt="image"
src="https://github.com/user-attachments/assets/c6fc873d-2c6a-4c9b-89cb-f7bc7322e25b"
/>

## Why It Works

Uses NWConnection to UDP multicast address 224.0.0.251:5353 (mDNS),
which is subject to the app's actual TCC permission state. Other
approaches (NWBrowser, dns-sd subprocess) either require additional
entitlements or run with their own permissions, giving false results.

## Test Plan

### Manual Testing
Hardware: MacBook Pro
  - Toggle local network OFF in System Settings → warning banner appears
  - Toggle local network ON → warning disappears
  - Verified detection correctly reflects actual permission state

### Automated Testing
N/A
2026-01-08 12:24:46 +00:00
rltakashige
077b1bc732 exo-bench (Benchmark model pp & tg speed) (#1099)
## Motivation

This PR implements benchmarking in the style of llama-bench. The main
difficulty here is the fact that exo is not a library - it exposes an
endpoint. This means that benchmarking numbers will be inaccurate if the
API is measured.

The solution assumes nodes are set up with uv run exo (or via the app),
and then hits the new endpoint /bench/chat/completions to retrieve
generation statistics directly from mlx_lm.
<!-- Why is this change needed? What problem does it solve? -->

This will allow us to release benchmarks for models and perform
regression tests.

TODO: Performance benchmarking.
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->
- Adds /bench/chat/completions endpoint
- Adds BenchChatCompletion/Response
- Adds a logits processor to prevent response from ending early
- Adds a "Prompt Sizer" which downloads the tokenizer and dynamically
adjusts the prompt of "a" to fit the desired prompt size.
- Reduce prefill step size to 2048 for now (in future, dynamically
adjust this value)

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->
Benchmarked Llama, Qwen, DeepSeek and Kimi models. Will require several
fixes to run consistently on all configurations (to be done in the
future).
Manually tested the normal API to verify chat requests complete as
expected.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
Not really possible. Type checker passes.
2026-01-06 17:39:09 +00:00
Alex Cheema
4963c33162 Fix Discord link in README.md. Fixes #1096 (#1097)
## Motivation

Discord link expired.

## Changes

Replace discord invite link with permanent link.

## Why It Works

It's permanent now.

## Test Plan

Clicked the link. It works.
2026-01-06 14:05:09 +00:00
104 changed files with 6373 additions and 14235 deletions

View File

@@ -1,159 +0,0 @@
# EXO Benchmark Dashboard
A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.
## Features
- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
-**Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser
## Quick Start
### Option 1: Direct File Access (Simplest)
Just open the HTML file directly in your browser:
```bash
open .github/benchmark-dashboard/index.html
```
Then click "Configure AWS Credentials" and enter your keys.
### Option 2: URL Parameters (For Quick Setup)
```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```
The credentials will be saved to localStorage and removed from the URL immediately.
### Option 3: Simple HTTP Server
```bash
# From repo root
python3 -m http.server 8080
# Then open: http://localhost:8080/.github/benchmark-dashboard/
```
## AWS Credentials
The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.
### Required IAM Permissions
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::exo-benchmark-results",
"arn:aws:s3:::exo-benchmark-results/*"
]
}
]
}
```
### Security Notes
- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user
## TV/Kiosk Mode
For permanent display on a TV:
### macOS
```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```
### Linux
```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```
### Auto-start on Boot
Create a simple startup script:
```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh
cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```
## Data Displayed
### Summary Cards
- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs
### Charts
1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures
## How It Works
1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
## Customization
To modify the dashboard:
1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing
## Troubleshooting
**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys
**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors
**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix
**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,186 +0,0 @@
# EXO Benchmark Configurations
This directory contains configuration files for the EXO staged benchmark system.
## Overview
The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:
- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage
Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
## Configuration Files
### `bench_simple.yaml`
A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens
This is useful for quick smoke tests.
### `bench_config.yaml`
A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down
This tests the cluster's behavior under varying load patterns.
## Configuration Schema
```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Model instances to run concurrently
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Benchmark stages
stages:
- name: "stage_name" # Human-readable name for this stage
prompt_length: 100 # Target prompt length in tokens
generation_length: 200 # Max tokens to generate
time_between_requests: 2.0 # Seconds between firing requests
iterations: 10 # Number of requests in this stage
```
## Running Benchmarks
### Via GitHub Actions
**Automatic (every commit):**
- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file
**Manual (on-demand):**
1. Go to **Actions****bench** workflow
2. Click **Run workflow**
3. Configure:
- **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
- `.github/configs/bench_simple.yaml` for quick tests
- `.github/configs/bench_config.yaml` for complex multi-stage tests
All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.
### Via Command Line
```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000
# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_simple.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_config.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
```
## Output Metrics
For each stage, the benchmark reports:
- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request
A JSON summary is also printed for easy parsing and storage.
## Creating Custom Benchmarks
To create a custom benchmark:
1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line
### Example: Sustained Load Test
```yaml
hardware_plan:
M3ULTRA_GPU80_512GB: 2
environment:
OVERRIDE_MEMORY_MB: 1024
timeout_seconds: 600
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "sustained_load"
prompt_length: 200
generation_length: 150
time_between_requests: 0.5 # Very fast - 2 requests/second
iterations: 100 # Run for ~50 seconds
```
### Example: Varying Prompt Sizes
```yaml
hardware_plan:
M4PRO_GPU16_24GB: 3
timeout_seconds: 900
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "tiny_prompts"
prompt_length: 10
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "medium_prompts"
prompt_length: 200
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "large_prompts"
prompt_length: 1000
generation_length: 100
time_between_requests: 1.0
iterations: 10
```
## Tips
- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits

View File

@@ -1,49 +0,0 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Multiple instances run concurrently on the cluster
model_ids:
- "mlx-community/Qwen3-0.6B-4bit"
- "mlx-community/Qwen3-0.6B-4bit"
# Stages run sequentially, each with its own characteristics
stages:
# Stage 1: Light load with short prompts
- name: "warmup"
prompt_length: 50 # Number of tokens in prompt
generation_length: 100 # Max tokens to generate
time_between_requests: 5.0 # Seconds between firing requests
iterations: 10 # Number of requests to send in this stage
# Stage 2: Medium load with medium prompts
- name: "medium_load"
prompt_length: 200
generation_length: 150
time_between_requests: 3.0
iterations: 20
# Stage 3: Heavy load with long prompts - requests will overlap
- name: "stress_test"
prompt_length: 500
generation_length: 200
time_between_requests: 1.0 # Fast firing - will definitely overlap
iterations: 30
# Stage 4: Cool down with simple prompts
- name: "cooldown"
prompt_length: 50
generation_length: 50
time_between_requests: 10.0
iterations: 5

View File

@@ -1,125 +0,0 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
puffin4: 1
puffin8: 1
# Environment variables to set on each node
environment:
PLACEHOLDER: "placeholder"
# OVERRIDE_MEMORY_MB: 50000
MLX_METAL_FAST_SYNCH: 1
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800
# Model instances to run concurrently
model_ids:
# - "mlx-community/DeepSeek-V3.1-8bit"
# - "mlx-community/Kimi-K2-Instruct-4bit"
- "mlx-community/Kimi-K2-Thinking"
# - "mlx-community/Qwen3-235B-A22B-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-8bit"
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"
# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true
# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
# - name: "simple"
# prompt_length: 512
# generation_length: 10
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g512"
# prompt_length: 64
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp256_g64"
# prompt_length: 256
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
- name: "pp256_g64"
prompt_length: 256
generation_length: 64
time_between_requests: 2.0
iterations: 5
# - name: "pp256_g512"
# prompt_length: 256
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp1024_g64"
# prompt_length: 1024
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp1024_g512"
# prompt_length: 1024
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp2048_g64"
# prompt_length: 2048
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp2048_g512"
# prompt_length: 2048
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp4096_g64"
# prompt_length: 4096
# generation_length: 64
# time_between_requests: 2.0
# iterations: 4
# - name: "pp4096_g512"
# prompt_length: 4096
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp8192_g64"
# prompt_length: 8192
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp8192_g512"
# prompt_length: 8192
# generation_length: 512
# time_between_requests: 2.0
# iterations: 5
# - name: "pp16384_g64"
# prompt_length: 16384
# generation_length: 64
# time_between_requests: 2.0
# iterations: 10
# - name: "pp16384_g512"
# prompt_length: 16384
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10

1399
.github/scripts/bench.py vendored
View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,70 +0,0 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast
import yaml
class MatrixEntry(TypedDict):
label: str
index: int
class MatrixInclude(TypedDict):
label: str
index: int
is_primary: bool
expected_nodes: int
class Config(TypedDict):
hardware_plan: dict[str, int]
timeout_seconds: NotRequired[int]
environment: NotRequired[dict[str, str]]
# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
config: Config = cast(Config, yaml.safe_load(f))
# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
raise ValueError(f"No hardware_plan found in {config_file}")
# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
for idx in range(count):
entries.append({"label": label, "index": idx})
total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
"include": [
{
"label": e["label"],
"index": e["index"],
"is_primary": (i == 0),
"expected_nodes": total_nodes,
}
for i, e in enumerate(entries)
]
}
# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})
# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write(f"matrix={json.dumps(matrix)}\n")
f.write(f"config_file={config_file}\n")
f.write(f"timeout_seconds={timeout_seconds}\n")
f.write(f"environment={json.dumps(environment)}\n")
print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")

View File

@@ -1,156 +0,0 @@
# Benchmark Workflow Usage
## Overview
The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.
## Workflow Inputs
| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |
## Hardware Plan Format
The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:
```json
{
"M4PRO_GPU16_24GB": 2,
"M3ULTRA_GPU80_512GB": 1
}
```
This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance
## How It Works
1. **Planning Job** (`plan`)
- Runs on `ubuntu-latest`
- Parses the `hardware_plan` JSON
- Generates a dynamic matrix with one entry per runner
- Only the first runner (index 0) is marked as `is_primary`
2. **Benchmark Worker Jobs** (`bench_worker`)
- Each job runs on a self-hosted macOS runner with the specified label
- All runners start EXO in parallel
- The primary runner creates the model instance
- All runners wait for their assigned runner to be ready (Loaded/Running status)
- The primary runner executes the benchmark and prints results
- The primary runner deletes the instance
## Example Usage
### Single Machine Benchmark
```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```
### Multi-Machine Distributed Benchmark
```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```
## Benchmark Output
The primary runner outputs a JSON object with benchmark results:
```json
{
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"instance_id": "abc-123-def",
"tokens": 42,
"elapsed_s": 2.451,
"tps": 17.136
}
```
Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
## Runner Requirements
Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job
## Architecture
```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml) │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────┐ │
│ │ Plan Job │ │
│ │ (ubuntu) │──┬─► Matrix: [{label, index, primary}] │
│ └────────────────┘ │ │
│ │ │
│ ┌───────────────────▼──────────────────────────────────┐ │
│ │ Bench Worker Jobs (Matrix) │ │
│ ├──────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ Runner 0 (Primary) Runner 1 Runner 2 │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ │
│ │ │ Start EXO │ │ Start EXO │ │ Start EXO│ │ │
│ │ │ Create Inst │ │ Wait... │ │ Wait... │ │ │
│ │ │ Wait Ready │ │ Wait Ready │ │ Wait... │ │ │
│ │ │ Run Bench │ │ (idle) │ │ (idle) │ │ │
│ │ │ Print TPS │ │ │ │ │ │ │
│ │ │ Delete Inst │ │ │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └──────────┘ │ │
│ └───────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## Implementation Details
### `scripts/bench.py`
A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)
### Key Functions
- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
## Troubleshooting
### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`
### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)
### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts

View File

@@ -1,305 +0,0 @@
name: bench
on: [push]
jobs:
plan:
if: contains(github.event.head_commit.message, '/bench')
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.build.outputs.matrix }}
config_file: ${{ steps.build.outputs.config_file }}
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
environment: ${{ steps.build.outputs.environment }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Build matrix from config file
id: build
shell: bash
run: |
set -euo pipefail
CONFIG_FILE='.github/configs/bench_simple.yaml'
export CONFIG_FILE
echo "Config file: $CONFIG_FILE"
python3 .github/scripts/build_matrix.py
bench_worker:
needs: plan
strategy:
fail-fast: false
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
# TODO: this is mega hacky and I'd like a simpler solution.
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix is already available
if command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
# Try sourcing profile scripts to set up environment properly
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Sourcing multi-user nix-daemon profile script"
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
echo "Sourcing single-user nix profile script"
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
echo "Sourcing per-user nix profile script"
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
elif [ -f /etc/profile.d/nix.sh ]; then
echo "Sourcing system-wide nix profile script"
source /etc/profile.d/nix.sh
# Fallback: manually add nix to PATH if binary exists
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary, manually adding to PATH"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
echo "Found nix binary in user profile, manually adding to PATH"
export PATH="$HOME/.nix-profile/bin:$PATH"
else
echo "Nix not found. Debugging info:"
echo "USER: $USER"
echo "HOME: $HOME"
echo "Current PATH: $PATH"
echo ""
echo "Checking common Nix locations:"
echo " /nix/var/nix/profiles/default/bin/nix:"
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo ""
echo "/nix directory structure:"
ls -la /nix 2>/dev/null || echo " /nix directory not found"
echo ""
echo "/nix/var:"
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
echo ""
echo "/nix/store:"
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
echo ""
echo "GitHub Actions runner is running as user '$USER'."
echo "If Nix is installed for a different user, either:"
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
echo " 2. Configure the runner service to run as the user with Nix installed"
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
exit 1
fi
# Verify nix is available and persist to GITHUB_ENV
if command -v nix >/dev/null 2>&1; then
echo "✓ Nix is available"
nix --version
echo "PATH=$PATH" >> $GITHUB_ENV
if [ -n "$NIX_PATH" ]; then
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
fi
else
echo "ERROR: Failed to set up Nix"
echo "PATH after setup attempt: $PATH"
exit 1
fi
shell: bash
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
EXO_MODELS_DIR="$HOME/.exo/models"
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
echo "Using models from: $EXO_MODELS_DIR"
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
shell: bash
- name: Configure local MLX if available
run: |
echo "=== DEBUG: Checking for local MLX configuration ==="
MODIFIED=false
echo "Checking for /Users/Shared/mlx directory..."
if [ -d "/Users/Shared/mlx" ]; then
echo "✓ Found /Users/Shared/mlx"
ls -la /Users/Shared/mlx | head -5
echo "Enabling local mlx path in pyproject.toml"
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
fi
echo "Checking for /Users/Shared/mlx-lm directory..."
if [ -d "/Users/Shared/mlx-lm" ]; then
echo "✓ Found /Users/Shared/mlx-lm"
ls -la /Users/Shared/mlx-lm | head -5
echo "Enabling local mlx-lm path in pyproject.toml"
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
fi
if [ "$MODIFIED" = true ]; then
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
echo "=== Regenerating uv.lock with local MLX paths... ==="
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
echo "✓ Lock file regenerated"
else
echo "⚠ No local MLX directories found, using PyPI packages"
fi
echo "=== DEBUG: Local MLX configuration complete ==="
shell: bash
- name: Sync dependencies
run: |
if [ -d "/Users/Shared/test" ]; then
pushd /Users/Shared/test
uv sync --reinstall
popd
fi
echo "Running just sync to ensure clean dependencies..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
shell: bash
- name: Start EXO and run bench script
shell: bash
env:
IS_PRIMARY: ${{ matrix.is_primary }}
EXPECTED_NODES: ${{ matrix.expected_nodes }}
HARDWARE_LABEL: ${{ matrix.label }}
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
run: |
set -euo pipefail
# Parse environment variables from config
ENV_VARS=""
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
fi
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
echo "Environment variables from config: $ENV_VARS"
LOG_FILE=/tmp/exo.log
: > "$LOG_FILE"
MASTER_FLAG=""
if [ "$IS_PRIMARY" = "true" ]; then
MASTER_FLAG="-m"
fi
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
>> "$LOG_FILE" 2>&1 &
EXO_PID=$!
echo "Started EXO in background with PID: $EXO_PID"
echo "Log file: $LOG_FILE"
cleanup() {
echo '=== EXO log (tail) ==='
tail -n 300 "$LOG_FILE" || true
if ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "Killing EXO (PID $EXO_PID)"
kill "$EXO_PID" || true
fi
}
trap cleanup EXIT
for i in $(seq 1 60); do
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
echo "EXO API ready"
break
fi
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
fi
sleep 1
done
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
echo "Results will be saved to: $RESULTS_FILE"
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
--api-port $API_PORT \
--config $CONFIG_FILE \
--expected-nodes ${EXPECTED_NODES} \
--is-primary ${IS_PRIMARY} \
--timeout-seconds ${TIMEOUT_SECONDS} \
--output $RESULTS_FILE \
--git-commit ${GITHUB_SHA} \
--hardware-labels ${HARDWARE_LABEL}"
- name: Install AWS CLI
if: always() && env.RESULTS_FILE && matrix.is_primary
run: |
if ! command -v aws &> /dev/null; then
echo "AWS CLI not found, installing..."
brew install awscli
else
echo "AWS CLI already installed"
fi
shell: bash
- name: Upload results to S3
if: always() && env.RESULTS_FILE && matrix.is_primary
env:
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: us-east-1
run: |
echo "Checking for results file: $RESULTS_FILE"
echo "Is primary: ${{ matrix.is_primary }}"
if [ -f "$RESULTS_FILE" ]; then
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
--content-type application/json \
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
echo "Results uploaded successfully"
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
else
echo "Results file not found at: $RESULTS_FILE"
echo "Skipping upload"
fi
shell: bash
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()

View File

@@ -1,6 +1,7 @@
name: Build EXO macOS DMG
on:
workflow_dispatch:
push:
tags:
- "v*"
@@ -18,6 +19,7 @@ jobs:
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
AWS_REGION: ${{ secrets.AWS_REGION }}
EXO_BUILD_NUMBER: ${{ github.run_number }}
EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -34,7 +36,7 @@ jobs:
- name: Derive release version from tag
run: |
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
VERSION="0.0.0-alpha.0"
echo "IS_ALPHA=true" >> $GITHUB_ENV
else
@@ -47,6 +49,32 @@ jobs:
fi
echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
- name: Compute build version from semver
run: |
VERSION="$RELEASE_VERSION"
# Extract major.minor.patch (strip prerelease suffix)
BASE_VERSION="${VERSION%%-*}"
MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
# Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
if [[ "$VERSION" == *-* ]]; then
PRERELEASE_PART="${VERSION#*-}"
PRERELEASE_NUM="${PRERELEASE_PART##*.}"
# Default to 0 if not a number
if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
PRERELEASE_NUM=0
fi
else
PRERELEASE_NUM=999
fi
# Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
echo "Computed build version: $BUILD_VERSION from $VERSION"
- name: Ensure tag commit is on main
if: github.ref_type == 'tag'
run: |
@@ -162,11 +190,12 @@ jobs:
-configuration Release \
-derivedDataPath build \
MARKETING_VERSION="$RELEASE_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
EXO_BUILD_TAG="$RELEASE_VERSION" \
EXO_BUILD_COMMIT="$GITHUB_SHA" \
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
mkdir -p ../../output
@@ -294,5 +323,5 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
if [[ "$IS_ALPHA" != "true" ]]; then
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

View File

@@ -20,6 +20,12 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
@@ -88,9 +94,19 @@ jobs:
- uses: ./.github/actions/typecheck
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -101,83 +117,20 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Run nix flake check
run: |
nix flake check
shell: bash
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
- name: Run nix flake check
run: nix flake check

1
.gitignore vendored
View File

@@ -16,6 +16,7 @@ digest.txt
*.xcuserdatad/
**/.DS_Store
app/EXO/build/
dist/
# rust

View File

@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -103,37 +104,55 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
def load(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -165,4 +177,7 @@ def load_tokenizer(
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...

3
.prettierrc Normal file
View File

@@ -0,0 +1,3 @@
{
"useTabs": true
}

6
.swift-format Normal file
View File

@@ -0,0 +1,6 @@
{
"version": 1,
"indentation": {
"spaces": 4
}
}

96
AGENTS.md Normal file
View File

@@ -0,0 +1,96 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

41
MISSED_THINGS.md Normal file
View File

@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

View File

@@ -8,7 +8,7 @@
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
@@ -166,6 +166,24 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
#### Uninstalling the macOS App
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
If you've already deleted the app, you can run the standalone uninstaller script:
```bash
sudo ./app/EXO/uninstall-exo.sh
```
This removes:
- Network setup LaunchDaemon
- Network configuration script
- Log files
- The "exo" network location
**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.
---
### Enabling RDMA on macOS
@@ -287,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
For further details, see:
- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---

View File

@@ -12,20 +12,25 @@ struct ContentView: View {
@EnvironmentObject private var controller: ExoProcessController
@EnvironmentObject private var stateService: ClusterStateService
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@State private var showAllInstances = false
@State private var showAdvanced = false
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
statusSection
if shouldShowLocalNetworkWarning {
localNetworkWarningBanner
}
if shouldShowClusterDetails {
Divider()
overviewSection
@@ -40,6 +45,7 @@ struct ContentView: View {
}
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
.padding()
.frame(width: 340)
.onAppear {
@@ -49,9 +55,62 @@ struct ContentView: View {
}
}
private var shouldShowLocalNetworkWarning: Bool {
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}
return false
}
private var localNetworkWarningBanner: some View {
VStack(alignment: .leading, spacing: 6) {
HStack(spacing: 6) {
Image(systemName: "exclamationmark.triangle.fill")
.foregroundColor(.orange)
Text("Local Network Access Issue")
.font(.caption)
.fontWeight(.semibold)
}
Text(
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openLocalNetworkSettings()
} label: {
Text("Open Settings")
.font(.caption2)
}
.buttonStyle(.bordered)
.controlSize(.small)
}
.padding(8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(Color.orange.opacity(0.1))
)
.overlay(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
)
}
private func openLocalNetworkSettings() {
// Open Privacy & Security settings - Local Network section
if let url = URL(
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
{
NSWorkspace.shared.open(url)
}
}
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
{
TopologyMiniView(topology: topology)
}
}
@@ -85,8 +144,10 @@ struct ContentView: View {
VStack(alignment: .leading, spacing: 4) {
HStack {
VStack(alignment: .leading) {
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
.font(.headline)
Text(
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
)
.font(.headline)
Text("Memory")
.font(.caption)
.foregroundColor(.secondary)
@@ -195,13 +256,7 @@ struct ContentView: View {
Divider()
.padding(.vertical, 4)
}
controlButton(title: "Check for Updates") {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
advancedSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
controller.stop()
@@ -210,7 +265,57 @@ struct ContentView: View {
}
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
private var advancedSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvanced)
}
.animation(nil, value: showAdvanced)
if showAdvanced {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
debugSection
HoverButton(title: "Uninstall", tint: .red, small: true) {
showUninstallConfirmationAlert()
}
.disabled(uninstallInProgress)
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
-> some View
{
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
}
@@ -241,9 +346,12 @@ struct ContentView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show All",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -331,57 +439,16 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Debug Info")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showDebugInfo)
VStack(alignment: .leading, spacing: 4) {
HoverButton(
title: "Debug Info",
tint: .primary,
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
small: true
) {
showDebugInfo.toggle()
}
.animation(nil, value: showDebugInfo)
if showDebugInfo {
VStack(alignment: .leading, spacing: 4) {
Text("Version: \(buildTag)")
@@ -394,15 +461,63 @@ struct ContentView: View {
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
.padding(.leading, 8)
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}
private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {
@@ -492,6 +607,88 @@ struct ContentView: View {
bugReportInFlight = false
}
private func showUninstallConfirmationAlert() {
let alert = NSAlert()
alert.messageText = "Uninstall EXO"
alert.informativeText = """
This will remove EXO and all its system components:
• Network configuration daemon
• Launch at login registration
• EXO network location
The app will be moved to Trash.
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Uninstall")
alert.addButton(withTitle: "Cancel")
// Style the Uninstall button as destructive
if let uninstallButton = alert.buttons.first {
uninstallButton.hasDestructiveAction = true
}
let response = alert.runModal()
if response == .alertFirstButtonReturn {
performUninstall()
}
}
private func performUninstall() {
uninstallInProgress = true
// Stop EXO process first
controller.cancelPendingLaunch()
controller.stop()
stateService.stopPolling()
// Run the privileged uninstall on a background thread
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
DispatchQueue.global(qos: .utility).async {
do {
// Remove network setup daemon and components (requires admin privileges)
try NetworkSetupHelper.uninstall()
DispatchQueue.main.async {
// Unregister from launch at login
LaunchAtLoginHelper.disable()
// Move app to trash
self.moveAppToTrash()
// Quit the app
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
NSApplication.shared.terminate(nil)
}
}
} catch {
DispatchQueue.main.async {
self.showErrorAlert(message: error.localizedDescription)
self.uninstallInProgress = false
}
}
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Uninstall Failed"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
private func moveAppToTrash() {
guard let appURL = Bundle.main.bundleURL as URL? else { return }
do {
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
} catch {
// If we can't trash the app, that's OK - user can do it manually
// The important system components have already been cleaned up
}
}
private var buildTag: String {
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
}
@@ -505,14 +702,27 @@ private struct HoverButton: View {
let title: String
let tint: Color
let trailingSystemImage: String?
let small: Bool
let action: () -> Void
init(
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
small: Bool = false, action: @escaping () -> Void
) {
self.title = title
self.tint = tint
self.trailingSystemImage = trailingSystemImage
self.small = small
self.action = action
}
@State private var isHovering = false
var body: some View {
Button(action: action) {
HStack {
Text(title)
.font(small ? .caption : nil)
Spacer()
if let systemName = trailingSystemImage {
Image(systemName: systemName)
@@ -520,8 +730,8 @@ private struct HoverButton: View {
}
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding(.vertical, 6)
.padding(.horizontal, 8)
.padding(.vertical, small ? 4 : 6)
.padding(.horizontal, small ? 6 : 8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(
@@ -536,4 +746,3 @@ private struct HoverButton: View {
.onHover { isHovering = $0 }
}
}

View File

@@ -8,9 +8,9 @@
import AppKit
import CoreImage
import CoreImage.CIFilterBuiltins
import ServiceManagement
import Sparkle
import SwiftUI
import ServiceManagement
import UserNotifications
import os.log
@@ -19,6 +19,7 @@ struct EXOApp: App {
@StateObject private var controller: ExoProcessController
@StateObject private var stateService: ClusterStateService
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -37,9 +38,13 @@ struct EXOApp: App {
_stateService = StateObject(wrappedValue: service)
let networkStatus = NetworkStatusService()
_networkStatusService = StateObject(wrappedValue: networkStatus)
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -51,6 +56,7 @@ struct EXOApp: App {
.environmentObject(controller)
.environmentObject(stateService)
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
} label: {
menuBarIcon
@@ -107,7 +113,7 @@ struct EXOApp: App {
filter.contrast = 0.9
guard let output = filter.outputImage,
let rendered = ciContext.createCGImage(output, from: output.extent)
let rendered = ciContext.createCGImage(output, from: output.extent)
else {
return nil
}
@@ -120,7 +126,26 @@ struct EXOApp: App {
do {
try SMAppService.mainApp.register()
} catch {
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
Logger().error(
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}
/// Helper for managing EXO's launch-at-login registration
enum LaunchAtLoginHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")
/// Unregisters EXO from launching at login
static func disable() {
guard SMAppService.mainApp.status == .enabled else { return }
do {
try SMAppService.mainApp.unregister()
logger.info("Unregistered EXO from launch at login")
} catch {
logger.error(
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
)
}
}
}
@@ -145,7 +170,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
controller.updater.automaticallyChecksForUpdates = true
controller.updater.automaticallyDownloadsUpdates = false
controller.updater.updateCheckInterval = 900 // 15 minutes
controller.updater.updateCheckInterval = 900 // 15 minutes
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
controller?.updater.checkForUpdatesInBackground()
}
@@ -212,7 +237,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
func userNotificationCenter(
_ center: UNUserNotificationCenter,
willPresent notification: UNNotification,
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
Void
) {
completionHandler([.banner, .list, .sound])
}

View File

@@ -31,7 +31,8 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
}()
{
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
@@ -221,7 +222,9 @@ final class ExoProcessController: ObservableObject {
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
return tag
}
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
!short.isEmpty
{
return short
}
return "dev"

View File

@@ -8,5 +8,15 @@
<string>$(EXO_BUILD_TAG)</string>
<key>EXOBuildCommit</key>
<string>$(EXO_BUILD_COMMIT)</string>
<key>EXOBugReportPresignedUrlEndpoint</key>
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
<key>NSBonjourServices</key>
<array>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
let rawDownloads =
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
}
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
let payloads = try container.decode([String: ClusterInstancePayload].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
)
}
self.instance = ClusterInstance(
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
let payloads = try container.decode([String: RunnerStatusDetail].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
)
}
self.status = entry.key
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
func promptPreview() -> String? {
guard let messages else { return nil }
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
if let userMessage = messages.last(where: {
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
}) {
return userMessage.content
}
return messages.last?.content
@@ -365,5 +372,3 @@ extension ClusterState {
func availableModels() -> [ModelOption] { [] }
}

View File

@@ -1,4 +1,3 @@
import CryptoKit
import Foundation
struct BugReportOutcome: Equatable {
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
}
enum BugReportError: LocalizedError {
case missingCredentials
case invalidEndpoint
case presignedUrlFailed(String)
case uploadFailed(String)
case collectFailed(String)
var errorDescription: String? {
switch self {
case .missingCredentials:
return "Bug report upload credentials are not set."
case .invalidEndpoint:
return "Bug report endpoint is invalid."
case .presignedUrlFailed(let message):
return "Failed to get presigned URLs: \(message)"
case .uploadFailed(let message):
return "Bug report upload failed: \(message)"
case .collectFailed(let message):
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
}
struct BugReportService {
struct AWSConfig {
let accessKey: String
let secretKey: String
let region: String
let bucket: String
private struct PresignedUrlsRequest: Codable {
let keys: [String]
}
private struct PresignedUrlsResponse: Codable {
let urls: [String: String]
let expiresIn: Int?
}
func sendReport(
@@ -39,9 +40,9 @@ struct BugReportService {
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {
let credentials = try loadCredentials()
let timestamp = ISO8601DateFormatter().string(from: now)
let prefix = "reports/\(timestamp)/"
let timestamp = Self.runTimestampString(now)
let dayPrefix = Self.dayPrefixString(now)
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
let logData = readLog()
let ifconfigText = try await captureIfconfig()
@@ -66,28 +67,82 @@ struct BugReportService {
("\(prefix)exo.log", logData),
("\(prefix)state.json", stateData),
("\(prefix)events.json", eventsData),
("\(prefix)report.json", reportJSON)
("\(prefix)report.json", reportJSON),
]
let uploader = try S3Uploader(config: credentials)
for item in uploads {
guard let data = item.data else { continue }
try await uploader.upload(
objectPath: item.path,
body: data
)
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
guard let body = item.data else { return nil }
return (key: item.path, body: body)
}
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
guard !uploadItems.isEmpty else {
return BugReportOutcome(success: false, message: "No data to upload")
}
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
for item in uploadItems {
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
}
try await uploadToPresignedUrl(url: url, body: item.body)
}
return BugReportOutcome(
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
}
private func loadCredentials() throws -> AWSConfig {
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",
region: "us-east-1",
bucket: "exo-bug-reports"
)
private static func dayPrefixString(_ date: Date) -> String {
var calendar = Calendar(identifier: .gregorian)
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
let components = calendar.dateComponents([.year, .month, .day], from: date)
let year = components.year ?? 0
let month = components.month ?? 0
let day = components.day ?? 0
return String(format: "%04d/%02d/%02d", year, month, day)
}
private static func runTimestampString(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.locale = Locale(identifier: "en_US_POSIX")
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
return formatter.string(from: date)
}
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
-> [String: String]
{
guard
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
as? String
else {
throw BugReportError.invalidEndpoint
}
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
else {
throw BugReportError.invalidEndpoint
}
var request = URLRequest(url: endpoint)
request.httpMethod = "POST"
request.timeoutInterval = 10
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
let encoder = JSONEncoder()
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.presignedUrlFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
}
let decoder = JSONDecoder()
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
return decoded.urls
}
private func readLog() -> Data? {
@@ -100,7 +155,8 @@ struct BugReportService {
private func captureIfconfig() async throws -> String {
let result = runCommand(["/sbin/ifconfig"])
guard result.exitCode == 0 else {
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
throw BugReportError.collectFailed(
result.error.isEmpty ? "ifconfig failed" : result.error)
}
return result.output
}
@@ -108,12 +164,23 @@ struct BugReportService {
private func readDebugInfo() -> DebugInfo {
DebugInfo(
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
interfaces: readInterfaces()
interfaces: readInterfaces(),
rdma: readRDMADebugInfo()
)
}
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
DebugInfo.RDMADebugInfo(
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
)
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
@@ -156,7 +223,8 @@ struct BugReportService {
request.timeoutInterval = 5
do {
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
else {
return nil
}
return data
@@ -165,6 +233,36 @@ struct BugReportService {
}
}
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
let maxAttempts = 2
var lastError: Error?
for attempt in 1...maxAttempts {
do {
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.timeoutInterval = 30
let (_, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.uploadFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
}
return
} catch {
lastError = error
if attempt < maxAttempts {
try await Task.sleep(nanoseconds: 400_000_000)
}
}
}
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
}
private func makeReportJson(
timestamp: String,
hostName: String,
@@ -182,7 +280,7 @@ struct BugReportService {
"system": system,
"exo_version": exo.version as Any,
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated"
"report_type": isManual ? "manual" : "automated",
]
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
@@ -213,10 +311,13 @@ struct BugReportService {
let user = safeRunCommand(["/usr/bin/whoami"])
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
let uptime = safeRunCommand(["/usr/bin/uptime"])
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
let diskRoot = safeRunCommand([
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
])
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
let interfacesAndIPs = interfacesList?
let interfacesAndIPs =
interfacesList?
.split(whereSeparator: { $0 == " " || $0 == "\n" })
.compactMap { iface -> [String: Any]? in
let name = String(iface)
@@ -227,7 +328,8 @@ struct BugReportService {
} ?? []
let wifiSSID: String?
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
let airportPath =
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
if FileManager.default.isExecutableFile(atPath: airportPath) {
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
} else {
@@ -255,7 +357,7 @@ struct BugReportService {
"disk_root": diskRoot as Any,
"interfaces_and_ips": interfacesAndIPs,
"ipconfig_getiflist": interfacesList as Any,
"wifi_ssid": wifiSSID as Any
"wifi_ssid": wifiSSID as Any,
]
}
@@ -313,7 +415,8 @@ struct BugReportService {
for line in airportOutput.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("SSID:") {
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
in: .whitespaces)
}
}
return nil
@@ -350,6 +453,7 @@ struct BugReportService {
private struct DebugInfo {
let thunderboltBridgeDisabled: Bool?
let interfaces: [InterfaceStatus]
let rdma: RDMADebugInfo
struct InterfaceStatus {
let name: String
@@ -358,7 +462,21 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"name": name,
"ip": ip as Any
"ip": ip as Any,
]
}
}
struct RDMADebugInfo {
let rdmaCtlStatus: String?
let ibvDevices: String?
let ibvDevinfo: String?
func toDictionary() -> [String: Any] {
[
"rdma_ctl_status": rdmaCtlStatus as Any,
"ibv_devices": ibvDevices as Any,
"ibv_devinfo": ibvDevinfo as Any,
]
}
}
@@ -366,7 +484,8 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
"interfaces": interfaces.map { $0.toDictionary() }
"interfaces": interfaces.map { $0.toDictionary() },
"rdma": rdma.toDictionary(),
]
}
}
@@ -376,163 +495,3 @@ private struct CommandResult {
let output: String
let error: String
}
private struct S3Uploader {
let config: BugReportService.AWSConfig
init(config: BugReportService.AWSConfig) throws {
self.config = config
}
func upload(objectPath: String, body: Data) async throws {
let host = "\(config.bucket).s3.amazonaws.com"
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
throw BugReportError.invalidEndpoint
}
let now = Date()
let amzDate = awsTimestamp(now)
let dateStamp = dateStamp(now)
let payloadHash = sha256Hex(body)
let headers = [
"host": host,
"x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate
]
let canonicalRequest = buildCanonicalRequest(
method: "PUT",
url: url,
headers: headers,
payloadHash: payloadHash
)
let stringToSign = buildStringToSign(
amzDate: amzDate,
dateStamp: dateStamp,
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
)
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
let authorization = """
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
"""
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
request.setValue(host, forHTTPHeaderField: "Host")
request.setValue(authorization, forHTTPHeaderField: "Authorization")
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
_ = data // ignore response body for UX
throw BugReportError.uploadFailed("HTTP status \(statusText)")
}
}
private func buildCanonicalRequest(
method: String,
url: URL,
headers: [String: String],
payloadHash: String
) -> String {
let canonicalURI = encodePath(url.path)
let canonicalQuery = url.query ?? ""
let sortedHeaders = headers.sorted { $0.key < $1.key }
let canonicalHeaders = sortedHeaders
.map { "\($0.key.lowercased()):\($0.value)\n" }
.joined()
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
return [
method,
canonicalURI,
canonicalQuery,
canonicalHeaders,
signedHeaders,
payloadHash
].joined(separator: "\n")
}
private func encodePath(_ path: String) -> String {
return path
.split(separator: "/")
.map { segment in
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
}
.joined(separator: "/")
.prependSlashIfNeeded()
}
private func buildStringToSign(
amzDate: String,
dateStamp: String,
canonicalRequestHash: String
) -> String {
"""
AWS4-HMAC-SHA256
\(amzDate)
\(dateStamp)/\(config.region)/s3/aws4_request
\(canonicalRequestHash)
"""
}
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
let kRegion = hmac(key: kDate, data: Data(region.utf8))
let kService = hmac(key: kRegion, data: Data(service.utf8))
return hmac(key: kService, data: Data("aws4_request".utf8))
}
private func hmac(key: Data, data: Data) -> Data {
let keySym = SymmetricKey(data: key)
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
return Data(mac)
}
private func hmacHex(key: Data, data: Data) -> String {
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
}
private func sha256Hex(_ data: Data) -> String {
let digest = SHA256.hash(data: data)
return digest.compactMap { String(format: "%02x", $0) }.joined()
}
private func awsTimestamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private func dateStamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private static let rfc3986: CharacterSet = {
var set = CharacterSet.alphanumerics
set.insert(charactersIn: "-._~")
return set
}()
}
private extension String {
func prependSlashIfNeeded() -> String {
if hasPrefix("/") {
return self
}
return "/" + self
}
}

View File

@@ -57,7 +57,9 @@ final class ClusterStateService: ObservableObject {
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
@@ -113,7 +115,9 @@ final class ClusterStateService: ObservableObject {
}
}
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
async
{
do {
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
request.httpMethod = "POST"
@@ -122,7 +126,7 @@ final class ClusterStateService: ObservableObject {
"model_id": modelId,
"sharding": sharding,
"instance_meta": instanceMeta,
"min_nodes": minNodes
"min_nodes": minNodes,
]
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
let (_, response) = try await session.data(for: request)
@@ -143,7 +147,9 @@ final class ClusterStateService: ObservableObject {
do {
let url = baseURL.appendingPathComponent("models")
let (data, response) = try await session.data(from: url)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw URLError(.badServerResponse)
}
let list = try decoder.decode(ModelListResponse.self, from: data)

View File

@@ -0,0 +1,150 @@
import Foundation
import Network
import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
case unknown
case checking
case working
case notWorking(reason: String)
var isHealthy: Bool {
if case .working = self { return true }
return false
}
var displayText: String {
switch self {
case .unknown:
return "Unknown"
case .checking:
return "Checking..."
case .working:
return "Working"
case .notWorking(let reason):
return reason
}
}
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
self.status = result
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
// mDNS multicast address - same as libp2p uses for peer discovery
let host = NWEndpoint.Host("224.0.0.251")
let port = NWEndpoint.Port(integerLiteral: 5353)
let params = NWParameters.udp
params.allowLocalEndpointReuse = true
let conn = NWConnection(host: host, port: port, using: params)
connection = conn
return await withCheckedContinuation { continuation in
var hasResumed = false
let lock = NSLock()
let resumeOnce: (Status) -> Void = { status in
lock.lock()
defer { lock.unlock() }
guard !hasResumed else { return }
hasResumed = true
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
case .waiting(let error):
let errorStr = "\(error)"
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|| errorStr.contains("permission") || errorStr.contains("denied")
{
resumeOnce(.notWorking(reason: "Permission denied"))
} else {
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
}
case .cancelled, .setup, .preparing:
break
@unknown default:
break
}
}
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:
resumeOnce(.working)
case .waiting, .preparing, .setup:
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
default:
resumeOnce(.notWorking(reason: "Timeout"))
}
}
}
}
func stop() {
checkTask?.cancel()
checkTask = nil
connection?.cancel()
connection = nil
}
}

View File

@@ -5,64 +5,66 @@ import os.log
enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
#!/usr/bin/env bash
set -euo pipefail
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
Task.detached {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
do {
if daemonAlreadyInstalled() {
return
@@ -70,11 +72,70 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
/// Requires admin privileges.
static func uninstall() throws {
let uninstallScript = makeUninstallScript()
try runShellAsAdmin(uninstallScript)
logger.info("EXO network setup components removed successfully")
}
/// Checks if there are any EXO network components installed that need cleanup
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
}
private static func makeUninstallScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
# Unload the LaunchDaemon if running
launchctl bootout system/"$LABEL" 2>/dev/null || true
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
rm -f "$SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable Thunderbolt Bridge if it exists
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
} || true
echo "EXO network components removed successfully"
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
@@ -82,7 +143,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
guard scriptExists, plistExists else { return false }
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
@@ -92,7 +154,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
@@ -105,58 +169,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript = script
let escapedScript =
script
.replacingOccurrences(of: "\\", with: "\\\\")
.replacingOccurrences(of: "\"", with: "\\\"")
let appleScriptSource = """
do shell script "\(escapedScript)" with administrator privileges
"""
do shell script "\(escapedScript)" with administrator privileges
"""
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
throw NetworkSetupError.scriptCreationFailed

View File

@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
let thunderboltBridgeState: ThunderboltState?
let bridgeInactive: Bool?
let interfaceStatuses: [InterfaceIpStatus]
let rdmaStatus: RDMAStatus
static let empty = NetworkStatus(
thunderboltBridgeState: nil,
bridgeInactive: nil,
interfaceStatuses: []
interfaceStatuses: [],
rdmaStatus: .empty
)
}
struct RDMAStatus: Equatable {
let rdmaCtlEnabled: Bool?
let devices: [String]
let activePorts: [RDMAPort]
var isAvailable: Bool {
rdmaCtlEnabled == true || !devices.isEmpty
}
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}
struct RDMAPort: Equatable {
let device: String
let port: String
let state: String
}
struct InterfaceIpStatus: Equatable {
let interfaceName: String
let ipAddress: String?
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
NetworkStatus(
thunderboltBridgeState: readThunderboltBridgeState(),
bridgeInactive: readBridgeInactive(),
interfaceStatuses: readInterfaceStatuses()
interfaceStatuses: readInterfaceStatuses(),
rdmaStatus: readRDMAStatus()
)
}
private func readRDMAStatus() -> RDMAStatus {
let rdmaCtlEnabled = readRDMACtlEnabled()
let devices = readRDMADevices()
let activePorts = readRDMAActivePorts()
return RDMAStatus(
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
}
private func readRDMACtlEnabled() -> Bool? {
let result = runCommand(["rdma_ctl", "status"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private func readRDMADevices() -> [String] {
let result = runCommand(["ibv_devices"])
guard result.exitCode == 0 else { return [] }
var devices: [String] = []
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|| trimmed.isEmpty
{
continue
}
let parts = trimmed.split(separator: " ", maxSplits: 1)
if let deviceName = parts.first {
devices.append(String(deviceName))
}
}
return devices
}
private func readRDMAActivePorts() -> [RDMAPort] {
let result = runCommand(["ibv_devinfo"])
guard result.exitCode == 0 else { return [] }
var ports: [RDMAPort] = []
var currentDevice: String?
var currentPort: String?
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("hca_id:") {
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("port:") {
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("state:") {
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
in: .whitespaces)
if let device = currentDevice, let port = currentPort {
if state.lowercased().contains("active") {
ports.append(RDMAPort(device: device, port: port, state: state))
}
}
}
}
return ports
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
private func readBridgeInactive() -> Bool? {
let result = runCommand(["ifconfig", "bridge0"])
guard result.exitCode == 0 else { return nil }
guard let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
guard
let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
else {
return nil
}
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
)
}
}

View File

@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case waiting
case failed
case idle
case unknown
case preparing
var label: String {
switch self {
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case .waiting: return "Waiting"
case .failed: return "Failed"
case .idle: return "Idle"
case .unknown: return "Unknown"
case .preparing: return "Preparing"
}
}
}
@@ -107,10 +107,13 @@ extension ClusterState {
let nodeToRunner = instance.shardAssignments.nodeToRunner
let nodeIds = Array(nodeToRunner.keys)
let runnerIds = Array(nodeToRunner.values)
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
let nodeNames = nodeIds.compactMap {
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
}
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
let state = InstanceViewModel.State(
statuses: statuses, hasActiveDownload: downloadProgress != nil)
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
.sorted(by: { $0.sortPriority < $1.sortPriority })
.map { InstanceTaskViewModel(task: $0) }
@@ -165,8 +168,8 @@ extension ClusterState {
}
}
private extension InstanceViewModel.State {
init(statuses: [String], hasActiveDownload: Bool = false) {
extension InstanceViewModel.State {
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
if statuses.contains(where: { $0.contains("failed") }) {
self = .failed
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
@@ -182,7 +185,7 @@ private extension InstanceViewModel.State {
} else if statuses.isEmpty {
self = .idle
} else {
self = .unknown
self = .preparing
}
}
}
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
self.parameters = task.parameters
}
}

View File

@@ -87,7 +87,9 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
guard !allNodes.isEmpty else { return nil }
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
@@ -106,18 +108,24 @@ extension ClusterState {
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
if let localId = localNodeId,
let index = orderedNodes.firstIndex(where: { $0.id == localId })
{
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }
return TopologyEdgeViewModel(
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edges = Set(edgesArray)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
return TopologyViewModel(
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}

View File

@@ -20,8 +20,8 @@ struct InstanceRowView: View {
if let progress = instance.downloadProgress {
downloadStatusView(progress: progress)
} else {
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
}
if let progress = instance.downloadProgress {
GeometryReader { geometry in
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
case .ready: return .teal
case .waiting, .idle: return .gray
case .failed: return .red
case .unknown: return .secondary
case .preparing: return .secondary
}
}
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
.font(.caption)
.fontWeight(.semibold)
if let subtitle = task.subtitle,
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
{
Text(subtitle)
.font(.caption2)
.foregroundColor(.secondary)
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
}
@ViewBuilder
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
private func detailRow(
icon: String? = nil, title: String, value: String, tint: Color = .secondary
) -> some View {
HStack(alignment: .firstTextBaseline, spacing: 6) {
if let icon {
Image(systemName: icon)
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
}
}
}

View File

@@ -32,4 +32,3 @@ struct NodeDetailView: View {
}
}
}

View File

@@ -28,4 +28,3 @@ struct NodeRowView: View {
.padding(.vertical, 4)
}
}

View File

@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
private func connectionLines(in size: CGSize) -> some View {
let positions = positionedNodes(in: size)
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
let positionById = Dictionary(
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
return Canvas { context, _ in
guard !topology.edges.isEmpty else { return }
let nodeRadius: CGFloat = 32
let arrowLength: CGFloat = 10
let arrowSpread: CGFloat = .pi / 7
for edge in topology.edges {
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
else { continue }
let dx = end.x - start.x
let dy = end.y - start.y
let distance = max(CGFloat(hypot(dx, dy)), 1)
let ux = dx / distance
let uy = dy / distance
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedStart = CGPoint(
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
var linePath = Path()
linePath.move(to: adjustedStart)
linePath.addLine(to: adjustedEnd)
context.stroke(
context.stroke(
linePath,
with: .color(.secondary.opacity(0.3)),
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
let angle = atan2(uy, ux)
let tip = adjustedEnd
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
.frame(width: 95)
}
}

View File

@@ -6,6 +6,7 @@
//
import Testing
@testable import EXO
struct EXOTests {

154
app/EXO/uninstall-exo.sh Executable file
View File

@@ -0,0 +1,154 @@
#!/usr/bin/env bash
#
# EXO Uninstaller Script
#
# This script removes all EXO system components that persist after deleting the app.
# Run with: sudo ./uninstall-exo.sh
#
# Components removed:
# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
# - Network script: /Library/Application Support/EXO/
# - Log files: /var/log/io.exo.networksetup.*
# - Network location: "exo"
# - Launch at login registration
#
set -euo pipefail
LABEL="io.exo.networksetup"
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
LOG_OUT="/var/log/${LABEL}.log"
LOG_ERR="/var/log/${LABEL}.err.log"
APP_BUNDLE_ID="io.exo.EXO"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
echo "========================================"
echo " EXO Uninstaller"
echo "========================================"
echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
# SMAppService-based login items cannot be removed from a shell script.
# They can only be unregistered from within the app itself or manually via System Settings.
echo_warn "Launch at login must be removed manually:"
echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
echo "========================================"
echo_info "EXO uninstall complete!"
echo "========================================"
echo ""
echo "The following have been removed:"
echo " • Network setup LaunchDaemon"
echo " • Network configuration script"
echo " • Log files"
echo " • 'exo' network location"
echo ""
echo "Your network has been restored to use the 'Automatic' location."
echo "Thunderbolt Bridge has been re-enabled (if present)."
echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

526
bench/exo_bench.py Normal file
View File

@@ -0,0 +1,526 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from collections.abc import Callable
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
time.sleep(0.1)
continue
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
return f"{b:.2f}{unit}"
b /= 1024.0
raise ValueError("You're using petabytes of memory. Something went wrong...")
def parse_int_list(values: list[str]) -> list[int]:
items: list[int] = []
for v in values:
for part in v.split(","):
part = part.strip()
if part:
items.append(int(part))
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("id") == model_arg:
short_id = str(m["id"])
full_id = str(m.get("hugging_face_id") or m["id"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["id"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
content, pp_tokens = prompt_sizer.build(pp_hint)
payload: dict[str, Any] = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": tg,
}
t0 = time.perf_counter()
out = client.post_bench_chat_completions(payload)
elapsed = time.perf_counter() - t0
stats = out.get("generation_stats")
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
"output_text_preview": preview,
"stats": stats,
}, pp_tokens
class PromptSizer:
def __init__(self, tokenizer: Any, atom: str = "a "):
self.tokenizer = tokenizer
self.atom = atom
self.count_fn = PromptSizer._make_counter(tokenizer)
self.base_tokens = self.count_fn("")
@staticmethod
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
def count_fn(user_content: str) -> int:
messages = [{"role": "user", "content": user_content}]
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
return int(len(ids))
return count_fn
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
target = int(target_prompt_tokens)
if target < self.base_tokens:
raise RuntimeError(
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
content = ""
tok = self.count_fn(content)
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
f"Overshot: got {tok} tokens (target {target}). "
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
)
return content, tok
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--pp",
nargs="+",
required=True,
help="Prompt-size hints (ints). Accepts commas.",
)
ap.add_argument(
"--tg",
nargs="+",
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Pipeline jaccl is often pointless, skip by default",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
ap.add_argument(
"--warmup",
type=int,
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
help="Write raw per-run results JSON to this path.",
)
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
tg_list = parse_int_list(args.tg)
if not pp_list or not tg_list:
logger.error("pp and tg lists must be non-empty")
return 2
if args.repeat <= 0:
logger.error("--repeat must be >= 1")
return 2
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": short_id}
)
previews = previews_resp.get("previews") or []
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
try:
prompt_sizer = PromptSizer(tokenizer)
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
except Exception:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if 0 < n <= args.max_nodes:
selected.append(p)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
-nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
)
if args.dry_run:
return 0
all_rows: list[dict[str, Any]] = []
for preview in selected:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
time.sleep(1)
try:
for i in range(args.warmup):
run_one_completion(
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
if (
pp * n_nodes > 2048
and "ring" in instance_meta.lower()
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
time.sleep(5)
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump(all_rows, f, indent=2, ensure_ascii=False)
logger.debug(f"\nWrote results JSON: {args.json_out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -11,4 +11,3 @@ declare global {
}
export {};

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,7 +10,6 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -18,8 +17,7 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false,
modelTasks = {}
showModelSelector = false
}: Props = $props();
let message = $state('');
@@ -50,29 +48,13 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
const models: Array<{id: string, label: string}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
}
}
return models;
@@ -178,12 +160,7 @@
uploadedFiles = [];
resetTextareaHeight();
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
sendMessage(content, files);
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -320,14 +297,7 @@
{:else}
<span class="w-3"></span>
{/if}
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
<span class="truncate">{model.label}</span>
</button>
{/each}
</div>
@@ -387,7 +357,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
{placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -401,23 +371,14 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={isImageModel() ? "Generate image" : "Send message"}
aria-label="Send message"
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,58 +365,10 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
</div>
</div>

View File

@@ -1,8 +1,7 @@
export { default as TopologyGraph } from './TopologyGraph.svelte';
export { default as ChatForm } from './ChatForm.svelte';
export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';
export { default as TopologyGraph } from "./TopologyGraph.svelte";
export { default as ChatForm } from "./ChatForm.svelte";
export { default as ChatMessages } from "./ChatMessages.svelte";
export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";

View File

File diff suppressed because it is too large Load Diff

View File

@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
}
export interface ChatAttachment {
type: 'image' | 'text' | 'pdf' | 'audio';
type: "image" | "text" | "pdf" | "audio";
name: string;
content?: string;
base64Url?: string;
mimeType?: string;
}
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
export const IMAGE_EXTENSIONS = [
".jpg",
".jpeg",
".png",
".gif",
".webp",
".svg",
];
export const IMAGE_MIME_TYPES = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"image/svg+xml",
];
export const TEXT_EXTENSIONS = [
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
".txt",
".md",
".json",
".xml",
".yaml",
".yml",
".csv",
".log",
".js",
".ts",
".jsx",
".tsx",
".py",
".java",
".cpp",
".c",
".h",
".css",
".html",
".htm",
".sql",
".sh",
".bat",
".rs",
".go",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".dart",
".vue",
".svelte",
];
export const TEXT_MIME_TYPES = [
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
'application/json', 'application/xml', 'text/xml', 'application/javascript',
'text/javascript', 'application/typescript'
"text/plain",
"text/markdown",
"text/csv",
"text/html",
"text/css",
"application/json",
"application/xml",
"text/xml",
"application/javascript",
"text/javascript",
"application/typescript",
];
export const PDF_EXTENSIONS = ['.pdf'];
export const PDF_MIME_TYPES = ['application/pdf'];
export const PDF_EXTENSIONS = [".pdf"];
export const PDF_MIME_TYPES = ["application/pdf"];
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
export const AUDIO_MIME_TYPES = [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/mp4",
];
/**
* Get file category based on MIME type and extension
*/
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
return 'image';
export function getFileCategory(
mimeType: string,
fileName: string,
): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
if (
IMAGE_MIME_TYPES.includes(mimeType) ||
IMAGE_EXTENSIONS.includes(extension)
) {
return "image";
}
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
return 'pdf';
return "pdf";
}
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
return 'audio';
if (
AUDIO_MIME_TYPES.includes(mimeType) ||
AUDIO_EXTENSIONS.includes(extension)
) {
return "audio";
}
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
return 'text';
if (
TEXT_MIME_TYPES.includes(mimeType) ||
TEXT_EXTENSIONS.includes(extension) ||
mimeType.startsWith("text/")
) {
return "text";
}
return 'unknown';
return "unknown";
}
/**
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
*/
export function getAcceptString(categories: FileCategory[]): string {
const accepts: string[] = [];
for (const category of categories) {
switch (category) {
case 'image':
case "image":
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
break;
case 'text':
case "text":
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
break;
case 'pdf':
case "pdf":
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
break;
case 'audio':
case "audio":
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
break;
}
}
return accepts.join(',');
return accepts.join(",");
}
/**
* Format file size for display
*/
export function formatFileSize(bytes: number): string {
if (bytes === 0) return '0 B';
if (bytes === 0) return "0 B";
const k = 1024;
const sizes = ['B', 'KB', 'MB', 'GB'];
const sizes = ["B", "KB", "MB", "GB"];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
}
/**
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
/**
* Process uploaded files into ChatUploadedFile format
*/
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
export async function processUploadedFiles(
files: File[],
): Promise<ChatUploadedFile[]> {
const results: ChatUploadedFile[] = [];
for (const file of files) {
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
const id =
Date.now().toString() + Math.random().toString(36).substring(2, 9);
const category = getFileCategory(file.type, file.name);
const base: ChatUploadedFile = {
id,
name: file.name,
size: file.size,
type: file.type,
file
file,
};
try {
if (category === 'image') {
if (category === "image") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else if (category === 'text' || category === 'unknown') {
} else if (category === "text" || category === "unknown") {
const textContent = await readFileAsText(file);
results.push({ ...base, textContent });
} else if (category === 'pdf') {
} else if (category === "pdf") {
results.push(base);
} else if (category === 'audio') {
} else if (category === "audio") {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else {
results.push(base);
}
} catch (error) {
console.error('Error processing file:', file.name, error);
console.error("Error processing file:", file.name, error);
results.push(base);
}
}
return results;
}

View File

@@ -47,30 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -616,7 +593,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
// Unwrap the instance
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { isDownloading: false, progress: null, statusText: 'UNKNOWN', perNode: [] };
return { isDownloading: false, progress: null, statusText: 'PREPARING', perNode: [] };
}
const inst = instance as { shardAssignments?: { nodeToRunner?: Record<string, string>; runnerToShard?: Record<string, unknown>; modelId?: string } };
@@ -729,7 +706,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string } {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { statusText: 'UNKNOWN', statusClass: 'inactive' };
return { statusText: 'PREPARING', statusClass: 'inactive' };
}
const inst = instance as { shardAssignments?: { runnerToShard?: Record<string, unknown> } };
@@ -758,7 +735,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const has = (s: string) => statuses.includes(s);
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (statuses.length === 0) return { statusText: 'PREPARING', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
@@ -1273,7 +1250,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1291,9 +1267,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div
<div
bind:this={instancesContainerRef}
class="max-h-72 space-y-3 overflow-y-auto"
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
>
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
@@ -1495,18 +1471,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1551,7 +1517,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1571,16 +1536,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1777,7 +1733,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
<ChatForm placeholder="Ask anything" showModelSelector={true} />
</div>
</div>
</div>
@@ -1817,7 +1773,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
<div class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1">
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
{@const statusText = downloadInfo.statusText}

View File

@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

View File

@@ -1,16 +1,15 @@
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite';
import tailwindcss from "@tailwindcss/vite";
import { sveltekit } from "@sveltejs/kit/vite";
import { defineConfig } from "vite";
export default defineConfig({
plugins: [tailwindcss(), sveltekit()],
server: {
proxy: {
'/v1': 'http://localhost:52415',
'/state': 'http://localhost:52415',
'/models': 'http://localhost:52415',
'/instance': 'http://localhost:52415'
}
}
"/v1": "http://localhost:52415",
"/state": "http://localhost:52415",
"/models": "http://localhost:52415",
"/instance": "http://localhost:52415",
},
},
});

212
docs/api.md Normal file
View File

@@ -0,0 +1,212 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO ** service, as implemented in:
`src/exo/master/api.py`
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
Base URL example:
```
http://localhost:52415
```
## 1. General / Meta Endpoints
### Get Master Node ID
**GET** `/node_id`
Returns the identifier of the current master node.
**Response (example):**
```json
{
"node_id": "node-1234"
}
```
### Get Cluster State
**GET** `/state`
Returns the current state of the cluster, including nodes and active instances.
**Response:**
JSON object describing topology, nodes, and instances.
### Get Events
**GET** `/events`
Returns the list of internal events recorded by the master (mainly for debugging and observability).
**Response:**
Array of event objects.
## 2. Model Instance Management
### Create Instance
**POST** `/instance`
Creates a new model instance in the cluster.
**Request body (example):**
```json
{
"instance": {
"model_id": "llama-3.2-1b",
"placement": { }
}
}
```
**Response:**
JSON description of the created instance.
### Delete Instance
**DELETE** `/instance/{instance_id}`
Deletes an existing instance by ID.
**Path parameters:**
* `instance_id`: string, ID of the instance to delete
**Response:**
Status / confirmation JSON.
### Get Instance
**GET** `/instance/{instance_id}`
Returns details of a specific instance.
**Path parameters:**
* `instance_id`: string
**Response:**
JSON description of the instance.
### Preview Placements
**GET** `/instance/previews?model_id=...`
Returns possible placement previews for a given model.
**Query parameters:**
* `model_id`: string, required
**Response:**
Array of placement preview objects.
### Compute Placement
**GET** `/instance/placement`
Computes a placement for a potential instance without creating it.
**Query parameters (typical):**
* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer
**Response:**
JSON object describing the proposed placement / instance configuration.
### Place Instance (Dry Operation)
**POST** `/place_instance`
Performs a placement operation for an instance (planning step), without necessarily creating it.
**Request body:**
JSON describing the instance to be placed.
**Response:**
Placement result.
## 3. Models
### List Models
**GET** `/models`
**GET** `/v1/models` (alias)
Returns the list of available models and their metadata.
**Response:**
Array of model descriptors.
## 4. Inference / Chat Completions
### OpenAI-Compatible Chat Completions
**POST** `/v1/chat/completions`
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "system", "content": "You are a helpful assistant." },
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Response:**
OpenAI-compatible chat completion response.
### Benchmarked Chat Completions
**POST** `/bench/chat/completions`
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
**Request body:**
Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
```
GET /node_id
GET /state
GET /events
POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}
GET /instance/previews
GET /instance/placement
POST /place_instance
GET /models
GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
```
## 6. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.

74
flake.lock generated
View File

@@ -8,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -21,25 +21,43 @@
"type": "github"
}
},
"flake-utils": {
"flake-parts": {
"inputs": {
"systems": "systems"
"nixpkgs-lib": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -50,27 +68,28 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"root": {
"inputs": {
"fenix": "fenix",
"flake-utils": "flake-utils",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -80,21 +99,6 @@
"type": "github"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"treefmt-nix": {
"inputs": {
"nixpkgs": [
@@ -102,11 +106,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

198
flake.nix
View File

@@ -3,118 +3,136 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
nixConfig = {
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
outputs =
inputs:
let
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs.ruff-format.enable = true;
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
programs.rustfmt.enable = true;
programs.rustfmt.package = (fenixToolchain system).rustfmt;
programs.nixpkgs-fmt.enable = true;
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# PYTHON
python313
uv
ruff
basedpyright
imports = [
inputs.treefmt-nix.flakeModule
];
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
perSystem =
{ config, inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = fenixToolchain.rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
};
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
(fenixToolchain.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
'';
};
};
}
);
};
}

View File

@@ -8,35 +8,21 @@ dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"typeguard>=4.4.4",
"pydantic>=2.11.7",
"base58>=2.1.1",
"cryptography>=45.0.5",
"fastapi>=0.116.1",
"filelock>=3.18.0",
"aiosqlite>=0.21.0",
"networkx>=3.5",
"protobuf>=6.32.0",
"rich>=14.1.0",
"rustworkx>=0.17.1",
"sqlmodel>=0.0.24",
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"psutil>=7.0.0",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.12.1",
]
[project.scripts]
@@ -47,6 +33,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -84,7 +71,7 @@ build-backend = "uv_build"
###
[tool.basedpyright]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
typeCheckingMode = "strict"
failOnWarnings = true
@@ -112,6 +99,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",

View File

@@ -1,6 +1,7 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,6 +196,8 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system

View File

@@ -1,13 +1,11 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from typing import Literal, cast
from typing import cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -24,12 +22,13 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
@@ -37,10 +36,7 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
GenerationStats,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -48,17 +44,14 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -94,23 +87,12 @@ def chunk_to_response(
)
def get_model_card(model_id: str) -> ModelCard | None:
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
return await get_model_meta(model_id)
else:
return await get_model_meta(model_id)
class API:
@@ -154,7 +136,6 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -163,7 +144,6 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -195,10 +175,7 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -517,6 +494,45 @@ class API:
],
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
resp = BenchChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
),
finish_reason=finish_reason,
)
],
generation_stats=stats,
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -552,15 +568,11 @@ class API:
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -572,304 +584,16 @@ class API:
status_code=404, detail=f"No instance found for model {payload.model}"
)
command = ImageGeneration(
request_params=payload,
)
payload.stream = False
command = ChatCompletion(request_params=payload)
await self._send(command)
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
# Skip partial images in non-streaming mode
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
# Check if this image is complete
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
# Reassemble images in order
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: bool = Form(False),
partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
# Read and base64 encode the uploaded image
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
# Map input_fidelity to image_strength
image_strength = 0.7 if input_fidelity == "high" else 0.3
# Split image into chunks to stay under gossipsub message size limit
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
# Create command first to get command_id
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="", # Empty - will be assembled at worker from chunks
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
),
)
# Send input chunks BEFORE the command
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
# Now send the main command
await self._send(command)
num_images = n
# Check if streaming is requested
if stream and partial_images and partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
),
media_type="text/event-stream",
)
# Track chunks per image: {image_index: {chunk_index: data}}
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command.command_id], recv = channel[
ImageChunk
]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
# Send TaskFinished command
await self._send(TaskFinished(finished_command_id=command.command_id))
del self._image_generation_queues[command.command_id]
return response
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -893,7 +617,6 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -931,17 +654,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -2,7 +2,6 @@ from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
from fastapi.routing import request_response
from loguru import logger
from exo.master.placement import (
@@ -12,17 +11,13 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -31,7 +26,6 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeTimedOut,
TaskCreated,
@@ -41,12 +35,6 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -111,7 +99,6 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
match command:
@@ -159,92 +146,6 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -272,13 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -9,7 +9,6 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
@@ -41,8 +40,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
return state
case InstanceCreated():
return apply_instance_created(event, state)

View File

@@ -44,5 +44,3 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

View File

@@ -1,5 +1,5 @@
from exo.shared.types.memory import Memory
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,7 +8,6 @@ class ModelCard(CamelCaseModel):
model_id: ModelId
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
@@ -46,7 +45,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -62,7 +60,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -85,6 +82,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# supports_tensor=True,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
@@ -99,6 +97,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# supports_tensor=True,
# ),
# ),
# deepseek r1
@@ -136,7 +135,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
@@ -152,7 +150,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -169,7 +166,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -185,7 +181,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -201,7 +196,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -217,7 +211,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -234,7 +227,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
@@ -250,7 +242,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -266,7 +257,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -283,7 +273,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
@@ -299,7 +288,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -315,7 +303,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -332,7 +319,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
@@ -348,7 +334,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -364,7 +349,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -380,7 +364,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -396,7 +379,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -412,7 +394,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -428,7 +409,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -444,7 +424,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -460,7 +439,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -476,7 +454,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -492,7 +469,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -508,7 +484,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -525,7 +500,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
@@ -541,7 +515,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
@@ -558,7 +531,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
@@ -574,7 +546,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -585,6 +556,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
@@ -600,188 +646,4 @@ MODEL_CARDS: dict[str, ModelCard] = {
# supports_tensor=True,
# ),
# ),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image-edit-2509": ModelCard(
short_id="qwen-image-edit-2509",
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
name="Qwen Image Edit 2509",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
pretty_name="Qwen Image Edit 2509",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}

View File

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -1,12 +1,11 @@
import time
from collections.abc import Generator
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -29,7 +28,6 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -54,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
function_call: dict[str, Any] | None = None
class BenchChatCompletionMessage(ChatCompletionMessage):
pass
class TopLogprobItem(BaseModel):
token: str
logprob: float
@@ -116,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
service_tier: str | None = None
class GenerationStats(BaseModel):
prompt_tps: float
generation_tps: float
prompt_tokens: int
generation_tokens: int
peak_memory_usage: Memory
class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
frequency_penalty: float | None = None
@@ -138,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
user: str | None = None
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
pass
class PlaceInstanceParams(BaseModel):
model_id: str
sharding: Sharding = Sharding.Pipeline
@@ -184,75 +202,3 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
prompt: str
# background: str | None = None
model: str
# moderation: str | None = None
n: int | None = 1
# output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
# style: str | None = "vivid"
# user: str | None = None
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
input_fidelity: float = 0.7
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
# user: str | None = None
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float = 0.7
stream: bool = False
partial_images: int | None = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and value is not None:
yield name, f"<{len(value)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]

View File

@@ -1,11 +1,9 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
from .models import ModelId
@@ -23,37 +21,11 @@ class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
class ImageChunk(BaseChunk):
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
data: bytes
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,11 +1,6 @@
from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -25,14 +20,6 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
@@ -52,12 +39,6 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -66,13 +47,10 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -106,11 +106,6 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
edge: Connection
@@ -136,7 +131,6 @@ Event = (
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -1,5 +1,3 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -11,21 +9,6 @@ class ModelId(Id):
pass
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -33,4 +16,3 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -2,11 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -60,22 +56,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -87,7 +67,5 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
pass
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):

View File

@@ -1,7 +1,4 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -18,32 +15,7 @@ class GenerationResponse(BaseRunnerResponse):
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class PartialImageResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -9,7 +9,6 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -442,31 +441,12 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
)
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
weight_map: dict[str, str] = {}
for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map
return weight_map
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -571,6 +551,8 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
)

View File

@@ -100,68 +100,26 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)

View File

@@ -1,10 +0,0 @@
from exo.worker.engines.image.base import ImageGenerator
from exo.worker.engines.image.distributed_model import initialize_image_model
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"ImageGenerator",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -1,50 +0,0 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Protocol, runtime_checkable
from PIL import Image
@runtime_checkable
class ImageGenerator(Protocol):
@property
def rank(self) -> int: ...
@property
def is_first_stage(self) -> bool: ...
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"],
seed: int,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
"""Generate an image from a text prompt, or edit an existing image.
For distributed inference, only the last stage returns images.
Other stages yield nothing after participating in the pipeline.
When partial_images > 0, yields intermediate images during diffusion
as tuples of (image, partial_index, total_partials), then yields
the final image.
When partial_images = 0 (default), only yields the final image.
Args:
prompt: Text description of the image to generate
height: Image height in pixels
width: Image width in pixels
quality: Generation quality level
seed: Random seed for reproducibility
image_path: Optional path to input image for image editing
partial_images: Number of intermediate images to yield (0 for none)
Yields:
Intermediate images as (Image, partial_index, total_partials) tuples
Final PIL Image (last stage) or nothing (other stages)
"""
...

View File

@@ -1,74 +0,0 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
JOINT = "joint" # Separate image/text streams
SINGLE = "single" # Concatenated streams
class TransformerBlockConfig(BaseModel):
model_config = {"frozen": True}
block_type: BlockType
count: int
has_separate_text_output: bool # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
model_config = {"frozen": True}
# Model identification
model_family: str # "flux", "fibo", "qwen"
model_variant: str # "schnell", "dev", etc.
# Architecture parameters
hidden_dim: int
num_heads: int
head_dim: int
# Block configuration - ordered sequence of block types
block_configs: tuple[TransformerBlockConfig, ...]
# Tokenization parameters
patch_size: int # 2 for Flux/Qwen
vae_scale_factor: int # 8 for Flux, 16 for others
# Inference parameters
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
# Feature flags
uses_attention_mask: bool # True for Fibo
# CFG (Classifier-Free Guidance) parameters
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@property
def total_blocks(self) -> int:
"""Total number of transformer blocks."""
return sum(bc.count for bc in self.block_configs)
@property
def joint_block_count(self) -> int:
"""Number of joint transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
)
@property
def single_block_count(self) -> int:
"""Number of single transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
)
def get_steps_for_quality(self, quality: str) -> int:
"""Get inference steps for a quality level."""
return self.default_steps[quality]
def get_num_sync_steps(self, quality: str) -> int:
"""Get number of synchronous steps based on quality."""
return ceil(self.default_steps[quality] * self.num_sync_steps_factor)

View File

@@ -1,228 +0,0 @@
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional
import mlx.core as mx
from mflux.config.config import Config
from PIL import Image
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
__slots__ = (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
)
_config: ImageModelConfig
_adapter: BaseModelAdapter
_group: Optional[mx.distributed.Group]
_shard_metadata: PipelineShardMetadata
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
# Get model config and create adapter (adapter owns the model)
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
total_joint_blocks=config.joint_block_count,
total_single_blocks=config.single_block_count,
)
# Create diffusion runner (handles both single-node and distributed modes)
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
runner = DiffusionRunner(
config=config,
adapter=adapter,
group=group,
shard_metadata=shard_metadata,
num_sync_steps=num_sync_steps,
)
if group is not None:
logger.info("Initialized distributed diffusion runner")
mx.eval(adapter.model.parameters())
# TODO(ciaran): Do we need this?
mx.eval(adapter.model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
logger.info(f"Transformer sharded for rank {group.rank()}")
else:
logger.info("Single-node initialization")
object.__setattr__(self, "_config", config)
object.__setattr__(self, "_adapter", adapter)
object.__setattr__(self, "_group", group)
object.__setattr__(self, "_shard_metadata", shard_metadata)
object.__setattr__(self, "_runner", runner)
@classmethod
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_meta.model_id
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
)
if is_distributed:
logger.info("Starting distributed init for image model")
group = mlx_distributed_init(bound_instance)
else:
group = None
return cls(
model_id=model_id,
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
)
@property
def model(self) -> Any:
"""Return the underlying mflux model via the adapter."""
return self._adapter.model
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def adapter(self) -> BaseModelAdapter:
return self._adapter
@property
def group(self) -> Optional[mx.distributed.Group]:
return self._group
@property
def shard_metadata(self) -> PipelineShardMetadata:
return self._shard_metadata
@property
def rank(self) -> int:
return self._shard_metadata.device_rank
@property
def world_size(self) -> int:
return self._shard_metadata.world_size
@property
def is_first_stage(self) -> bool:
return self._shard_metadata.device_rank == 0
@property
def is_last_stage(self) -> bool:
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
@property
def is_distributed(self) -> bool:
return self._shard_metadata.world_size > 1
@property
def runner(self) -> DiffusionRunner:
return self._runner
# Delegate attribute access to the underlying model via the adapter.
# Guarded with TYPE_CHECKING to prevent type checker complaints
# while still providing full delegation at runtime.
if not TYPE_CHECKING:
def __getattr__(self, name: str) -> Any:
return getattr(self._adapter.model, name)
def __setattr__(self, name: str, value: Any) -> None:
if name in (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
):
object.__setattr__(self, name, value)
else:
setattr(self._adapter.model, name, value)
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"] = "medium",
seed: int = 2,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
# Determine number of inference steps based on quality
steps = self._config.get_steps_for_quality(quality)
# For edit mode: compute dimensions from input image
# This also stores image_paths in the adapter for encode_prompt()
if image_path is not None:
computed_dims = self._adapter.set_image_dimensions(image_path)
if computed_dims is not None:
# Override user-provided dimensions with computed ones
width, height = computed_dims
config = Config(
num_inference_steps=steps,
height=height,
width=width,
image_path=image_path,
)
# Generate images via the runner
for result in self._runner.generate_image(
settings=config,
prompt=prompt,
seed=seed,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)
generated_image, partial_idx, total_partials = result
yield (generated_image.image, partial_idx, total_partials)
else:
# Final image: GeneratedImage
logger.info("generated image")
yield result.image
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
"""Initialize DistributedImageModel from a BoundInstance."""
return DistributedImageModel.from_bound_instance(bound_instance)

View File

@@ -1,120 +0,0 @@
import base64
import io
import tempfile
from pathlib import Path
from typing import Generator, Literal
from PIL import Image
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.base import ImageGenerator
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
"""Warmup the image generator with a small image."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a small dummy image for warmup (needed for edit models)
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
dummy_path = Path(tmpdir) / "warmup.png"
dummy_image.save(dummy_path)
for result in model.generate(
prompt="Warmup",
height=256,
width=256,
quality="low",
seed=2,
image_path=dummy_path,
):
if not isinstance(result, tuple):
return result
return None
def generate_image(
model: ImageGenerator,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
seed = 2 # TODO(ciaran): Randomise when not testing anymore
# Handle streaming params for both generation and edit tasks
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
# Final image
image = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
)

View File

@@ -1,84 +0,0 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
"""Get configuration for a model ID.
Args:
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
The model configuration
Raises:
ValueError: If no configuration found for model ID
"""
model_id_lower = model_id.lower()
for pattern, config in _CONFIG_REGISTRY.items():
if pattern in model_id_lower:
return config
raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
) -> ModelAdapter:
"""Create a model adapter for the given configuration.
Args:
config: The model configuration
model_id: The model identifier
local_path: Path to the model weights
quantize: Optional quantization bits
Returns:
A ModelAdapter instance
Raises:
ValueError: If no adapter found for model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config, model_id, local_path, quantize)

View File

@@ -1,103 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.array_util import ArrayUtil
from mflux.utils.image_util import ImageUtil
class BaseModelAdapter(ABC):
"""Base class for model adapters with shared utilities.
Provides common implementations for latent creation and decoding.
Subclasses implement model-specific prompt encoding and noise computation.
"""
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
"""Create initial latents. Uses model-specific latent creator."""
return LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
img2img=Img2Img(
vae=self.model.vae,
latent_creator=self._get_latent_creator(),
sigmas=runtime_config.scheduler.sigmas,
init_time_step=runtime_config.init_time_step,
image_path=runtime_config.image_path,
),
)
def decode_latents(
self,
latents: mx.array,
runtime_config: RuntimeConfig,
seed: int,
prompt: str,
) -> Any:
"""Decode latents to image. Shared implementation."""
latents = ArrayUtil.unpack_latents(
latents=latents,
height=runtime_config.height,
width=runtime_config.width,
)
decoded = self.model.vae.decode(latents)
return ImageUtil.to_image(
decoded_latents=decoded,
config=runtime_config,
seed=seed,
prompt=prompt,
quantization=self.model.bits,
lora_paths=self.model.lora_paths,
lora_scales=self.model.lora_scales,
image_path=runtime_config.image_path,
image_strength=runtime_config.image_strength,
generation_time=0,
)
# Abstract methods - subclasses must implement
@property
@abstractmethod
def model(self) -> Any:
"""Return the underlying mflux model."""
...
@abstractmethod
def _get_latent_creator(self) -> type:
"""Return the latent creator class for this model."""
...
@abstractmethod
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
total_joint_blocks: Total number of joint blocks in the model
total_single_blocks: Total number of single blocks in the model
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Default implementation: no dimension computation needed.
Override in edit adapters to compute dimensions from input image.
Returns:
None (use user-specified dimensions)
"""
return None

View File

@@ -1,11 +0,0 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -1,680 +0,0 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class FluxPromptData:
"""Container for Flux prompt encoding results."""
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Flux has no extra forward kwargs."""
return {}
@property
def conditioning_latents(self) -> mx.array | None:
"""Flux does not use conditioning latents."""
return None
class FluxModelAdapter(BaseModelAdapter):
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
local_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> Flux1:
return self._model
@property
def transformer(self) -> Transformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0]
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def encode_prompt(self, prompt: str) -> FluxPromptData:
"""Encode prompt into FluxPromptData."""
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self._model.prompt_cache,
t5_tokenizer=self._model.t5_tokenizer,
clip_tokenizer=self._model.clip_tokenizer,
t5_text_encoder=self._model.t5_text_encoder,
clip_text_encoder=self._model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
@property
def needs_cfg(self) -> bool:
return False
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux does not use classifier-free guidance")
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: RuntimeConfig,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None, # Ignored by Flux
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux text embeddings"
)
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: RuntimeConfig,
**kwargs: Any,
) -> mx.array:
kontext_image_ids = kwargs.get("kontext_image_ids")
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # mx.array for Flux
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
if mode == BlockWrapperMode.CACHING:
return self._apply_single_block_caching(
block=block,
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_single_block_patched(
block=block,
patch_hidden=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
return cast(
list[SingleBlockInterface],
list(self._transformer.single_transformer_blocks),
)
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
all_joint = list(self._transformer.transformer_blocks)
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
all_single = list(self._transformer.single_transformer_blocks)
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
) -> tuple[mx.array, mx.array]:
num_img_tokens = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# 1. Compute norms
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
)
# 2. Compute Q, K, V for full image
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 4. Concatenate Q, K, V: [text, image]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
# 6. Store IMAGE K/V in cache for async pipeline
if kv_cache is not None:
kv_cache.update_image_patch(
patch_start=0,
patch_end=num_img_tokens,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 7. Compute full attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 8. Extract and project outputs
context_attn_output = attn_output[:, :text_seq_len, :]
attn_output = attn_output[:, text_seq_len:, :]
attn_output = attn.to_out[0](attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 9. Apply norm and feed forward
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=hidden_states,
attn_output=attn_output,
gate_mlp=gate_mlp,
gate_msa=gate_msa,
scale_mlp=scale_mlp,
shift_mlp=shift_mlp,
norm_layer=block.norm2,
ff_layer=block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=c_gate_mlp,
gate_msa=c_gate_msa,
scale_mlp=c_scale_mlp,
shift_mlp=c_shift_mlp,
norm_layer=block.norm2_context,
ff_layer=block.ff_context,
)
return encoder_hidden_states, hidden_states
def _apply_joint_block_patched(
self,
block: JointBlockInterface,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# 1. Compute norms
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
hidden_states=patch_hidden,
text_embeddings=text_embeddings,
)
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
)
# 2. Compute Q, K, V for image patch
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 4. Concatenate Q, K, V for patch: [text, patch]
query = mx.concatenate([txt_query, img_query], axis=2)
patch_key = mx.concatenate([txt_key, img_key], axis=2)
patch_value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 6. Apply RoPE
query, patch_key = AttentionUtils.apply_rope(
xq=query, xk=patch_key, freqs_cis=patch_rope
)
# 7. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=patch_key[:, :, text_seq_len:, :],
value=patch_value[:, :, text_seq_len:, :],
)
# 8. Get full K, V from cache
full_key, full_value = kv_cache.get_full_kv(
text_key=patch_key[:, :, :text_seq_len, :],
text_value=patch_value[:, :, :text_seq_len, :],
)
# 9. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=full_key,
value=full_value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 10. Extract and project outputs
context_attn_output = attn_output[:, :text_seq_len, :]
hidden_attn_output = attn_output[:, text_seq_len:, :]
hidden_attn_output = attn.to_out[0](hidden_attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 11. Apply norm and feed forward
patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=patch_hidden,
attn_output=hidden_attn_output,
gate_mlp=gate_mlp,
gate_msa=gate_msa,
scale_mlp=scale_mlp,
shift_mlp=shift_mlp,
norm_layer=block.norm2,
ff_layer=block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=c_gate_mlp,
gate_msa=c_gate_msa,
scale_mlp=c_scale_mlp,
shift_mlp=c_shift_mlp,
norm_layer=block.norm2_context,
ff_layer=block.ff_context,
)
return encoder_hidden_states, patch_hidden
def _apply_single_block_caching(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
) -> mx.array:
total_seq_len = hidden_states.shape[1]
num_img_tokens = total_seq_len - text_seq_len
batch_size = hidden_states.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# Residual connection
residual = hidden_states
# 1. Compute norm
norm_hidden, gate = block.norm(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
# 4. Store IMAGE K/V in cache
if kv_cache is not None:
kv_cache.update_image_patch(
patch_start=0,
patch_end=num_img_tokens,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 5. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 6. Apply feed forward and projection
hidden_states = block._apply_feed_forward_and_projection(
norm_hidden_states=norm_hidden,
attn_output=attn_output,
gate=gate,
)
return residual + hidden_states
def _apply_single_block_patched(
self,
block: SingleBlockInterface,
patch_hidden: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
) -> mx.array:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# Residual connection
residual = patch_hidden
# 1. Compute norm
norm_hidden, gate = block.norm(
hidden_states=patch_hidden,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 4. Apply RoPE
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
# 5. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 6. Get full K, V from cache
full_key, full_value = kv_cache.get_full_kv(
text_key=key[:, :, :text_seq_len, :],
text_value=value[:, :, :text_seq_len, :],
)
# 7. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=full_key,
value=full_value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 8. Apply feed forward and projection
hidden_states = block._apply_feed_forward_and_projection(
norm_hidden_states=norm_hidden,
attn_output=attn_output,
gate=gate,
)
return residual + hidden_states

View File

@@ -1,48 +0,0 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
FLUX_SCHNELL_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="schnell",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
uses_attention_mask=False,
)
FLUX_DEV_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="dev",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
uses_attention_mask=False,
)

View File

@@ -1,13 +0,0 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
__all__ = [
"QwenModelAdapter",
"QwenEditModelAdapter",
"QWEN_IMAGE_CONFIG",
"QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -1,519 +0,0 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenPromptData:
"""Container for Qwen prompt encoding results.
Implements PromptData protocol with additional Qwen-specific attributes.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
):
self._prompt_embeds = prompt_embeds
self.prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self.negative_prompt_mask = negative_prompt_mask
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds # Use prompt_embeds as placeholder
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return {"encoder_hidden_states_mask": self.prompt_mask}
else:
return {"encoder_hidden_states_mask": self.negative_prompt_mask}
@property
def conditioning_latents(self) -> mx.array | None:
"""Standard Qwen does not use conditioning latents."""
return None
class QwenModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image model.
Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImage(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
local_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImage:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def encode_prompt(self, prompt: str) -> QwenPromptData:
"""Encode prompt into QwenPromptData.
Qwen uses classifier-free guidance with explicit negative prompts.
Returns a QwenPromptData container with all 4 tensors.
"""
# TODO(ciaran): empty string as default negative prompt
negative_prompt = ""
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
QwenPromptEncoder.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_cache=self._model.prompt_cache,
qwen_tokenizer=self._model.qwen_tokenizer,
qwen_text_encoder=self._model.text_encoder,
)
)
return QwenPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_mask=neg_mask,
)
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
return self._model.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
# Image embedding
embedded_hidden = self._transformer.img_in(hidden_states)
# Text embedding: first normalize, then project
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: RuntimeConfig,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings.
For Qwen, the time_text_embed only uses hidden_states for:
- batch_size (shape[0])
- dtype
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
when embedded hidden_states are not yet available.
"""
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
# (which for Qwen is the same as prompt_embeds)
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: RuntimeConfig,
**kwargs: Any,
) -> Any:
"""Compute 3D rotary embeddings for Qwen.
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
Returns:
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
((img_cos, img_sin), (txt_cos, txt_sin))
"""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
cond_image_grid = kwargs.get("cond_image_grid")
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply Qwen joint block.
For caching mode, we run the full block and optionally populate the KV cache.
For patched mode, we use the cached KV values (not yet implemented).
"""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
block_idx = kwargs.get("block_idx")
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
else:
# mode == BlockWrapperMode.PATCHED
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Qwen has no single blocks."""
raise NotImplementedError("Qwen does not have single blocks")
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final normalization and projection."""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Return all 60 transformer blocks."""
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
all_blocks = list(self._transformer.transformer_blocks)
assigned_blocks = all_blocks[start_layer:end_layer]
self._transformer.transformer_blocks = assigned_blocks
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams.
For Qwen, this is called before final projection.
The streams remain separate through all blocks.
"""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: Any, # QwenTransformerBlock
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
"""Apply joint block in caching mode (full attention, optionally populate cache).
Delegates to the QwenTransformerBlock's forward pass.
"""
# Call the block directly - it handles all the modulation and attention internally
return block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
text_embeddings=text_embeddings,
image_rotary_emb=rotary_embeddings,
block_idx=block_idx,
)
def _apply_joint_block_patched(
self,
block: Any, # QwenTransformerBlock
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dim
# 1. Compute modulation parameters
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# 2. Apply normalization and modulation
img_normed = block.img_norm1(patch_hidden)
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
# 3. Compute Q, K, V for image patch
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# 4. Compute Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# 5. Reshape to [B, S, H, D]
patch_len = patch_hidden.shape[1]
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
txt_query = mx.reshape(
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
)
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
txt_value = mx.reshape(
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
)
# 6. Apply RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
patch_img_cos = img_cos[patch_start:patch_end]
patch_img_sin = img_sin[patch_start:patch_end]
# 8. Apply RoPE to Q, K
img_query = QwenAttention._apply_rope_qwen(
img_query, patch_img_cos, patch_img_sin
)
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# 9. Transpose to [B, H, S, D] for cache operations
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
# 10. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=img_key_bhsd,
value=img_value_bhsd,
)
# 11. Get full K, V from cache (text + full image)
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
full_key, full_value = kv_cache.get_full_kv(
text_key=txt_key_bhsd,
text_value=txt_value_bhsd,
)
# 12. Build query: [text, patch]
joint_query = mx.concatenate([txt_query, img_query], axis=1)
# 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
mask = QwenAttention._convert_mask_for_qwen(
mask=encoder_hidden_states_mask,
joint_seq_len=full_key.shape[2], # text + full_image
txt_seq_len=text_seq_len,
)
# 14. Compute attention
hidden_states = attn._compute_attention_qwen(
query=joint_query,
key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
value=mx.transpose(full_value, (0, 2, 1, 3)),
mask=mask,
block_idx=block_idx,
)
# 15. Extract text and image attention outputs
txt_attn_output = hidden_states[:, :text_seq_len, :]
img_attn_output = hidden_states[:, text_seq_len:, :]
# 16. Project outputs
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# 17. Apply residual + gate for attention
patch_hidden = patch_hidden + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# 18. Apply feed-forward for image
img_normed2 = block.img_norm2(patch_hidden)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, img_mod2
)
img_mlp_output = block.img_ff(img_modulated2)
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
# 19. Apply feed-forward for text
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, txt_mod2
)
txt_mlp_output = block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, patch_hidden

View File

@@ -1,49 +0,0 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
model_family="qwen",
model_variant="image",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
# Qwen has no single blocks - all blocks process image and text separately
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
guidance_scale=None, # Set to None or < 1.0 to disable CFG
)
# Qwen-Image-Edit uses the same architecture but different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
model_family="qwen-edit",
model_variant="image-edit",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,
uses_attention_mask=True,
guidance_scale=None,
)

View File

@@ -1,671 +0,0 @@
import math
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from mflux.models.qwen.variants.edit.utils.qwen_edit_util import QwenEditUtil
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenEditPromptData:
"""Container for Qwen edit prompt encoding results.
Includes vision-language encoded embeddings and edit-specific conditioning.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
conditioning_latents: mx.array,
qwen_image_ids: mx.array,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
):
self._prompt_embeds = prompt_embeds
self.prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self.negative_prompt_mask = negative_prompt_mask
self._conditioning_latents = conditioning_latents
self._qwen_image_ids = qwen_image_ids
self._cond_image_grid = cond_image_grid
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from vision-language encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
@property
def conditioning_latents(self) -> mx.array:
"""Static image conditioning latents to concatenate with generated latents."""
return self._conditioning_latents
@property
def qwen_image_ids(self) -> mx.array:
"""Spatial position IDs for conditioning images."""
return self._qwen_image_ids
@property
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
"""Conditioning image grid dimensions."""
return self._cond_image_grid
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return encoder_hidden_states_mask and edit-specific params."""
if positive:
return {
"encoder_hidden_states_mask": self.prompt_mask,
"qwen_image_ids": self._qwen_image_ids,
"cond_image_grid": self._cond_image_grid,
}
else:
return {
"encoder_hidden_states_mask": self.negative_prompt_mask,
"qwen_image_ids": self._qwen_image_ids,
"cond_image_grid": self._cond_image_grid,
}
@property
def is_edit_mode(self) -> bool:
"""Indicates this is edit mode with conditioning latents."""
return True
class QwenEditModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
local_path=str(local_path),
)
self._transformer = self._model.transformer
# Store dimensions and image paths (set via set_image_dimensions)
self._vl_width: int | None = None
self._vl_height: int | None = None
self._vae_width: int | None = None
self._vae_height: int | None = None
self._image_paths: list[str] | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
"""Compute VL and VAE dimensions from input image.
Returns:
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
"""Encode prompt with input images using vision-language encoder.
Uses stored image_paths from set_image_dimensions() for VL encoding.
Args:
prompt: Text prompt for editing
Returns:
QwenEditPromptData with VL embeddings and conditioning latents
"""
# Ensure image_paths and dimensions were set via set_image_dimensions()
if self._image_paths is None:
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
negative_prompt = ""
image_paths = self._image_paths
# Use stored dimensions (computed from input image)
vl_width = self._vl_width
vl_height = self._vl_height
vae_width = self._vae_width
vae_height = self._vae_height
# Encode prompts with images via vision-language components
tokenizer = self._model.qwen_vl_tokenizer
pos_input_ids, pos_attention_mask, pos_pixel_values, pos_image_grid_thw = (
tokenizer.tokenize_with_image(
prompt, image_paths, vl_width=vl_width, vl_height=vl_height
)
)
pos_hidden_states = self._model.qwen_vl_encoder(
input_ids=pos_input_ids,
attention_mask=pos_attention_mask,
pixel_values=pos_pixel_values,
image_grid_thw=pos_image_grid_thw,
)
mx.eval(pos_hidden_states[0])
mx.eval(pos_hidden_states[1])
# Encode negative prompt with images
neg_input_ids, neg_attention_mask, neg_pixel_values, neg_image_grid_thw = (
tokenizer.tokenize_with_image(
negative_prompt, image_paths, vl_width=vl_width, vl_height=vl_height
)
)
neg_hidden_states = self._model.qwen_vl_encoder(
input_ids=neg_input_ids,
attention_mask=neg_attention_mask,
pixel_values=neg_pixel_values,
image_grid_thw=neg_image_grid_thw,
)
mx.eval(neg_hidden_states[0])
mx.eval(neg_hidden_states[1])
# Create conditioning latents from input images
# Ensure dimensions are set (should have been set via set_image_dimensions)
assert vl_width is not None and vl_height is not None
assert vae_width is not None and vae_height is not None
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=vae_height,
width=vae_width,
image_paths=image_paths,
vl_width=vl_width,
vl_height=vl_height,
)
# Build cond_image_grid
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=pos_hidden_states[0].astype(mx.float16),
prompt_mask=pos_hidden_states[1].astype(mx.float16),
negative_prompt_embeds=neg_hidden_states[0].astype(mx.float16),
negative_prompt_mask=neg_hidden_states[1].astype(mx.float16),
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._vl_width = vl_w
self._vl_height = vl_h
self._vae_width = vae_w
self._vae_height = vae_h
self._image_paths = [str(image_path)]
return out_w, out_h
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
return QwenImage.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: RuntimeConfig,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings."""
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: RuntimeConfig,
**kwargs: Any,
) -> Any:
"""Compute 3D rotary embeddings for Qwen edit."""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
cond_image_grid = kwargs.get("cond_image_grid")
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply Qwen joint block."""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
block_idx = kwargs.get("block_idx")
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Qwen has no single blocks."""
raise NotImplementedError("Qwen does not have single blocks")
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final normalization and projection."""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Return all 60 transformer blocks."""
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
all_blocks = list(self._transformer.transformer_blocks)
assigned_blocks = all_blocks[start_layer:end_layer]
self._transformer.transformer_blocks = assigned_blocks
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams."""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: Any,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
"""Apply joint block in caching mode."""
return block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
text_embeddings=text_embeddings,
image_rotary_emb=rotary_embeddings,
block_idx=block_idx,
)
def _apply_joint_block_patched(
self,
block: Any,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dim
# Modulation parameters
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# Normalization and modulation
img_normed = block.img_norm1(patch_hidden)
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
# Q, K, V for image patch
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# Reshape to [B, S, H, D]
patch_len = patch_hidden.shape[1]
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
txt_query = mx.reshape(
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
)
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
txt_value = mx.reshape(
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
)
# RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# Extract RoPE for patch
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
patch_img_cos = img_cos[patch_start:patch_end]
patch_img_sin = img_sin[patch_start:patch_end]
# Apply RoPE
img_query = QwenAttention._apply_rope_qwen(
img_query, patch_img_cos, patch_img_sin
)
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# Transpose to [B, H, S, D]
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
# Update cache
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=img_key_bhsd,
value=img_value_bhsd,
)
# Get full K, V from cache
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
full_key, full_value = kv_cache.get_full_kv(
text_key=txt_key_bhsd,
text_value=txt_value_bhsd,
)
# Build query
joint_query = mx.concatenate([txt_query, img_query], axis=1)
# Build attention mask
mask = QwenAttention._convert_mask_for_qwen(
mask=encoder_hidden_states_mask,
joint_seq_len=full_key.shape[2],
txt_seq_len=text_seq_len,
)
# Compute attention
hidden_states = attn._compute_attention_qwen(
query=joint_query,
key=mx.transpose(full_key, (0, 2, 1, 3)),
value=mx.transpose(full_value, (0, 2, 1, 3)),
mask=mask,
block_idx=block_idx,
)
# Extract outputs
txt_attn_output = hidden_states[:, :text_seq_len, :]
img_attn_output = hidden_states[:, text_seq_len:, :]
# Project
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# Residual + gate
patch_hidden = patch_hidden + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Feed-forward for image
img_normed2 = block.img_norm2(patch_hidden)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, img_mod2
)
img_mlp_output = block.img_ff(img_modulated2)
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
# Feed-forward for text
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, txt_mod2
)
txt_mlp_output = block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, patch_hidden

View File

@@ -1,23 +0,0 @@
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
"BlockWrapperMode",
"DiffusionRunner",
"ImagePatchKVCache",
"JointBlockInterface",
"JointBlockWrapper",
"ModelAdapter",
"SingleBlockInterface",
"SingleBlockWrapper",
]

View File

@@ -1,402 +0,0 @@
from enum import Enum
from pathlib import Path
from typing import Any, Protocol
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class AttentionInterface(Protocol):
num_heads: int
head_dimension: int
to_q: Any
to_k: Any
to_v: Any
norm_q: Any
norm_k: Any
to_out: list[Any]
class JointAttentionInterface(AttentionInterface, Protocol):
add_q_proj: Any
add_k_proj: Any
add_v_proj: Any
norm_added_q: Any
norm_added_k: Any
to_add_out: Any
class JointBlockInterface(Protocol):
attn: JointAttentionInterface
norm1: Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
norm1_context: (
Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
)
norm2: Any
norm2_context: Any
ff: Any
ff_context: Any
class SingleBlockInterface(Protocol):
attn: AttentionInterface
norm: Any # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]
def _apply_feed_forward_and_projection(
self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
) -> mx.array:
"""Apply feed forward network and projection."""
...
class BlockWrapperMode(Enum):
CACHING = "caching" # Sync mode: compute full attention, populate cache
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
class PromptData(Protocol):
"""Protocol for encoded prompt data.
All adapters must return prompt data that conforms to this protocol.
Model-specific prompt data classes can add additional attributes
(e.g., attention masks for Qwen).
"""
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
...
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
...
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Negative prompt embeddings for CFG (None if not using CFG)."""
...
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Negative pooled embeddings for CFG (None if not using CFG)."""
...
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return model-specific kwargs for forward pass.
Args:
positive: If True, return kwargs for positive prompt pass.
If False, return kwargs for negative prompt pass.
Returns:
Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
"""
...
@property
def conditioning_latents(self) -> mx.array | None:
"""Conditioning latents for edit mode.
Returns:
Conditioning latents array for image editing, None for standard generation.
"""
...
class ModelAdapter(Protocol):
@property
def config(self) -> ImageModelConfig:
"""Return the model configuration."""
...
@property
def model(self) -> Any:
"""Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
...
@property
def transformer(self) -> Any:
"""Return the transformer component of the model."""
...
@property
def hidden_dim(self) -> int:
"""Return the hidden dimension of the transformer."""
...
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute x_embedder and context_embedder outputs.
Args:
hidden_states: Input latent states
prompt_embeds: Text embeddings from encoder
Returns:
Tuple of (embedded_hidden_states, embedded_encoder_states)
"""
...
def compute_text_embeddings(
self,
t: int,
runtime_config: RuntimeConfig,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings for conditioning.
Args:
t: Current timestep
runtime_config: Runtime configuration
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
hidden_states: Image hidden states
Returns:
Text embeddings tensor
"""
...
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: RuntimeConfig,
**kwargs: Any,
) -> Any:
"""Compute rotary position embeddings.
Args:
prompt_embeds: Text embeddings
runtime_config: Runtime configuration
**kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)
Returns:
Flux: mx.array
Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
"""
...
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # Format varies: mx.array (Flux) or nested tuple (Qwen)
kv_cache: ImagePatchKVCache | None,
mode: "BlockWrapperMode",
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply a joint transformer block.
Args:
block: The joint transformer block
hidden_states: Image hidden states
encoder_hidden_states: Text hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings (format varies by model)
kv_cache: KV cache (None if not using cache)
mode: CACHING or PATCHED mode
text_seq_len: Text sequence length
patch_start: Start index for patched mode
patch_end: End index for patched mode
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
block_idx for Qwen)
Returns:
Tuple of (encoder_hidden_states, hidden_states)
"""
...
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: "BlockWrapperMode",
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Apply a single transformer block.
Args:
block: The single transformer block
hidden_states: Concatenated [text + image] hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
kv_cache: KV cache (None if not using cache)
mode: CACHING or PATCHED mode
text_seq_len: Text sequence length
patch_start: Start index for patched mode
patch_end: End index for patched mode
Returns:
Output hidden states
"""
...
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final norm and projection.
Args:
hidden_states: Hidden states (image only, text already removed)
text_embeddings: Conditioning embeddings
Returns:
Projected output
"""
...
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Get the list of joint transformer blocks from the model."""
...
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Get the list of single transformer blocks from the model."""
...
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
total_joint_blocks: Total number of joint blocks in the model
total_single_blocks: Total number of single blocks in the model
"""
...
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams for transition to single blocks.
This is called at the transition point from joint blocks (which process
image and text separately) to single blocks (which process them
together). Override to customize the merge strategy.
Args:
hidden_states: Image hidden states
encoder_hidden_states: Text hidden states
Returns:
Merged hidden states (default: concatenate [text, image])
"""
...
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
"""Create initial noise latents for generation.
Args:
seed: Random seed
runtime_config: Runtime configuration with dimensions
Returns:
Initial latent tensor
"""
...
def encode_prompt(self, prompt: str) -> PromptData:
"""Encode prompt into model-specific prompt data.
Args:
prompt: Text prompt
Returns:
PromptData containing embeddings (and model-specific extras)
"""
...
@property
def needs_cfg(self) -> bool:
"""Whether this model uses classifier-free guidance.
Returns:
True if model requires two forward passes with guidance (e.g., Qwen)
False if model uses a single forward pass (e.g., Flux)
"""
...
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
"""Apply classifier-free guidance to combine positive/negative predictions.
Only called when needs_cfg is True.
Args:
noise_positive: Noise prediction from positive prompt
noise_negative: Noise prediction from negative prompt
guidance_scale: Guidance strength
Returns:
Guided noise prediction
"""
...
def decode_latents(
self,
latents: mx.array,
runtime_config: RuntimeConfig,
seed: int,
prompt: str,
) -> Any:
"""Decode latents to final image.
Args:
latents: Final denoised latents
runtime_config: Runtime configuration
seed: Random seed (for metadata)
prompt: Text prompt (for metadata)
Returns:
GeneratedImage result
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Compute and store dimensions from input image for edit mode.
For edit adapters: computes dimensions from input image aspect ratio,
stores image paths internally for encode_prompt(), returns (width, height).
For standard adapters: returns None (use user-specified dimensions).
Args:
image_path: Path to the input image
Returns:
Tuple of (width, height) if dimensions were computed, None otherwise.
"""
...

View File

@@ -1,146 +0,0 @@
from typing import Any
import mlx.core as mx
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class JointBlockWrapper:
"""Unified wrapper for joint transformer blocks.
Handles both CACHING (sync) and PATCHED (async) modes by delegating
to the model adapter for model-specific attention computation.
The wrapper is created once at initialization and reused across calls.
Mode and KV cache are passed at call time to support switching between
sync and async pipelines.
"""
def __init__(
self,
block: JointBlockInterface,
adapter: ModelAdapter,
):
"""Initialize the joint block wrapper.
Args:
block: The joint transformer block to wrap
adapter: Model adapter for model-specific operations
"""
self.block = block
self.adapter = adapter
def __call__(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
text_seq_len: int,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply the joint block.
Args:
hidden_states: Image hidden states (full or patch depending on mode)
encoder_hidden_states: Text hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
text_seq_len: Text sequence length
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
mode: CACHING (populate cache) or PATCHED (use cached K/V)
patch_start: Start index for patched mode (required if mode=PATCHED)
patch_end: End index for patched mode (required if mode=PATCHED)
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
block_idx for Qwen)
Returns:
Tuple of (encoder_hidden_states, hidden_states)
"""
return self.adapter.apply_joint_block(
block=self.block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
mode=mode,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
**kwargs,
)
class SingleBlockWrapper:
"""Unified wrapper for single transformer blocks.
Handles both CACHING (sync) and PATCHED (async) modes by delegating
to the model adapter for model-specific attention computation.
The wrapper is created once at initialization and reused across calls.
Mode and KV cache are passed at call time to support switching between
sync and async pipelines.
"""
def __init__(
self,
block: SingleBlockInterface,
adapter: ModelAdapter,
):
"""Initialize the single block wrapper.
Args:
block: The single transformer block to wrap
adapter: Model adapter for model-specific operations
"""
self.block = block
self.adapter = adapter
def __call__(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
text_seq_len: int,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Apply the single block.
Args:
hidden_states: [text + image] hidden states (full or patch depending on mode)
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
text_seq_len: Text sequence length
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
mode: CACHING (populate cache) or PATCHED (use cached K/V)
patch_start: Start index for patched mode (required if mode=PATCHED)
patch_end: End index for patched mode (required if mode=PATCHED)
Returns:
Output hidden states
"""
return self.adapter.apply_single_block(
block=self.block,
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
mode=mode,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)

View File

@@ -1,72 +0,0 @@
import mlx.core as mx
class ImagePatchKVCache:
"""KV cache that stores only IMAGE K/V with patch-level updates.
Only caches image K/V since:
- Text K/V is always computed fresh (same for all patches)
- Only image portion needs stale/fresh cache management across patches
"""
def __init__(
self,
batch_size: int,
num_heads: int,
image_seq_len: int,
head_dim: int,
dtype: mx.Dtype = mx.float32,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.image_seq_len = image_seq_len
self.head_dim = head_dim
self._dtype = dtype
self.key_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
self.value_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
def update_image_patch(
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
) -> None:
"""Update cache with fresh K/V for an image patch slice.
Args:
patch_start: Start token index within image portion (0-indexed)
patch_end: End token index within image portion
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
"""
self.key_cache[:, :, patch_start:patch_end, :] = key
self.value_cache[:, :, patch_start:patch_end, :] = value
def get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
Args:
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
Returns:
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
"""
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
return full_key, full_value
def reset(self) -> None:
"""Reset cache to zeros."""
self.key_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
self.value_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)

View File

@@ -1,975 +0,0 @@
from math import ceil
from typing import Any, Optional
import mlx.core as mx
from mflux.callbacks.callbacks import Callbacks
from mflux.config.config import Config
from mflux.config.runtime_config import RuntimeConfig
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
ModelAdapter,
PromptData,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
def calculate_patch_heights(latent_height: int, num_patches: int):
patch_height = ceil(latent_height / num_patches)
actual_num_patches = ceil(latent_height / patch_height)
patch_heights = [patch_height] * (actual_num_patches - 1)
last_height = latent_height - patch_height * (actual_num_patches - 1)
patch_heights.append(last_height)
return patch_heights, actual_num_patches
def calculate_token_indices(patch_heights: list[int], latent_width: int):
tokens_per_row = latent_width
token_ranges = []
cumulative_height = 0
for h in patch_heights:
start_token = tokens_per_row * cumulative_height
end_token = tokens_per_row * (cumulative_height + h)
token_ranges.append((start_token, end_token))
cumulative_height += h
return token_ranges
class DiffusionRunner:
"""Orchestrates the diffusion loop for image generation.
This class owns the entire diffusion process, handling both single-node
and distributed (PipeFusion) modes.
In distributed mode, it implements PipeFusion with:
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
- Async pipeline for later timesteps (patches processed independently)
"""
def __init__(
self,
config: ImageModelConfig,
adapter: ModelAdapter,
group: Optional[mx.distributed.Group],
shard_metadata: PipelineShardMetadata,
num_sync_steps: int = 1,
num_patches: Optional[int] = None,
):
"""Initialize the diffusion runner.
Args:
config: Model configuration (architecture, block counts, etc.)
adapter: Model adapter for model-specific operations
group: MLX distributed group (None for single-node mode)
shard_metadata: Pipeline shard metadata with layer assignments
num_sync_steps: Number of synchronous timesteps before async mode
num_patches: Number of patches for async mode (defaults to world_size)
"""
self.config = config
self.adapter = adapter
self.group = group
# Handle single-node vs distributed mode
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_sync_steps = num_sync_steps
self.num_patches = num_patches if num_patches else max(1, self.world_size)
# Persistent KV caches (initialized on first async timestep, reused across timesteps)
self.joint_kv_caches: list[ImagePatchKVCache] | None = None
self.single_kv_caches: list[ImagePatchKVCache] | None = None
# Get block counts from config (model-agnostic)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
self.total_layers = config.total_blocks
self._compute_assigned_blocks()
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
end = self.end_layer
if end <= self.total_joint:
# All assigned blocks are joint blocks
self.joint_start = start
self.joint_end = end
self.single_start = 0
self.single_end = 0
elif start >= self.total_joint:
# All assigned blocks are single blocks
self.joint_start = 0
self.joint_end = 0
self.single_start = start - self.total_joint
self.single_end = end - self.total_joint
else:
# Stage spans joint→single transition
self.joint_start = start
self.joint_end = self.total_joint
self.single_start = 0
self.single_end = end - self.total_joint
self.has_joint_blocks = self.joint_end > self.joint_start
self.has_single_blocks = self.single_end > self.single_start
self.owns_concat_stage = self.has_joint_blocks and (
self.has_single_blocks or self.end_layer == self.total_joint
)
joint_blocks = self.adapter.get_joint_blocks()
single_blocks = self.adapter.get_single_blocks()
# Wrap blocks at initialization (reused across all calls)
self.joint_block_wrappers = [
JointBlockWrapper(block=block, adapter=self.adapter)
for block in joint_blocks
]
self.single_block_wrappers = [
SingleBlockWrapper(block=block, adapter=self.adapter)
for block in single_blocks
]
@property
def is_first_stage(self) -> bool:
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
return self.group is not None
def _calculate_capture_steps(
self,
partial_images: int,
init_time_step: int,
num_inference_steps: int,
) -> set[int]:
"""Calculate which timesteps should produce partial images.
Evenly spaces `partial_images` captures across the diffusion loop.
Does NOT include the final timestep (that's the complete image).
Args:
partial_images: Number of partial images to capture
init_time_step: Starting timestep (for img2img this may not be 0)
num_inference_steps: Total inference steps
Returns:
Set of timestep indices to capture
"""
if partial_images <= 0:
return set()
total_steps = num_inference_steps - init_time_step
if total_steps <= 1:
return set()
if partial_images >= total_steps - 1:
# Capture every step except final
return set(range(init_time_step, num_inference_steps - 1))
# Evenly space partial captures
step_interval = total_steps / (partial_images + 1)
capture_steps: set[int] = set()
for i in range(1, partial_images + 1):
step_idx = int(init_time_step + i * step_interval)
# Ensure we don't capture the final step
if step_idx < num_inference_steps - 1:
capture_steps.add(step_idx)
return capture_steps
def generate_image(
self,
settings: Config,
prompt: str,
seed: int,
partial_images: int = 0,
):
"""Primary entry point for image generation.
Orchestrates the full generation flow:
1. Create runtime config
2. Create initial latents
3. Encode prompt
4. Run diffusion loop (yielding partials if requested)
5. Decode to image
When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
tuples for intermediate images, then yields the final GeneratedImage.
Args:
settings: Generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt)
# Calculate which steps to capture
capture_steps = self._calculate_capture_steps(
partial_images=partial_images,
init_time_step=runtime_config.init_time_step,
num_inference_steps=runtime_config.num_inference_steps,
)
# Run diffusion loop - may yield partial latents
diffusion_gen = self._run_diffusion_loop(
latents=latents,
prompt_data=prompt_data,
runtime_config=runtime_config,
seed=seed,
prompt=prompt,
capture_steps=capture_steps,
)
# Process partial yields and get final latents
partial_index = 0
total_partials = len(capture_steps)
if capture_steps:
# Generator mode - iterate to get partials and final latents
try:
while True:
partial_latents, _step = next(diffusion_gen)
if self.is_last_stage:
partial_image = self.adapter.decode_latents(
partial_latents, runtime_config, seed, prompt
)
yield (partial_image, partial_index, total_partials)
partial_index += 1
except StopIteration as e:
latents = e.value
else:
# No partials - just consume generator to get final latents
try:
while True:
next(diffusion_gen)
except StopIteration as e:
latents = e.value
# Yield final image (only on last stage)
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
def _run_diffusion_loop(
self,
latents: mx.array,
prompt_data: PromptData,
runtime_config: RuntimeConfig,
seed: int,
prompt: str,
capture_steps: set[int] | None = None,
):
"""Execute the diffusion loop, optionally yielding at capture steps.
When capture_steps is provided and non-empty, this becomes a generator
that yields (latents, step_index) tuples at the specified timesteps.
Only the last stage yields (others have incomplete latents).
Args:
latents: Initial noise latents
prompt_data: Encoded prompt data
runtime_config: RuntimeConfig with scheduler, steps, dimensions
seed: Random seed (for callbacks)
prompt: Text prompt (for callbacks)
capture_steps: Set of timestep indices to capture (None = no captures)
Yields:
(latents, step_index) tuples at capture steps (last stage only)
Returns:
Final denoised latents ready for VAE decoding
"""
if capture_steps is None:
capture_steps = set()
time_steps = tqdm(range(runtime_config.num_inference_steps))
# Call subscribers for beginning of loop
Callbacks.before_loop(
seed=seed,
prompt=prompt,
latents=latents,
config=runtime_config,
)
for t in time_steps:
try:
latents = self._diffusion_step(
t=t,
config=runtime_config,
latents=latents,
prompt_data=prompt_data,
)
# Call subscribers in-loop
Callbacks.in_loop(
t=t,
seed=seed,
prompt=prompt,
latents=latents,
config=runtime_config,
time_steps=time_steps,
)
mx.eval(latents)
# Yield partial latents at capture steps (only on last stage)
if t in capture_steps and self.is_last_stage:
yield (latents, t)
except KeyboardInterrupt: # noqa: PERF203
Callbacks.interruption(
t=t,
seed=seed,
prompt=prompt,
latents=latents,
config=runtime_config,
time_steps=time_steps,
)
raise StopImageGenerationException(
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
) from None
# Call subscribers after loop
Callbacks.after_loop(
seed=seed,
prompt=prompt,
latents=latents,
config=runtime_config,
)
return latents
def _forward_pass(
self,
latents: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
kwargs: dict[str, Any],
) -> mx.array:
"""Run a single forward pass through the transformer.
This is the internal method called by adapters via compute_step_noise.
Returns noise prediction without applying scheduler step.
For edit mode, concatenates conditioning latents with generated latents
before the transformer, and extracts only the generated portion after.
Args:
latents: Input latents (already scaled by caller)
prompt_embeds: Text embeddings
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)
Returns:
Noise prediction tensor
"""
t = kwargs.get("t", 0)
config = kwargs.get("config")
if config is None:
raise ValueError("config must be provided in kwargs")
scaled_latents = config.scheduler.scale_model_input(latents, t)
# For edit mode: concatenate with conditioning latents
conditioning_latents = kwargs.get("conditioning_latents")
original_latent_tokens = scaled_latents.shape[1]
if conditioning_latents is not None:
scaled_latents = mx.concatenate(
[scaled_latents, conditioning_latents], axis=1
)
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
scaled_latents, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
)
rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds, config, **kwargs
)
text_seq_len = prompt_embeds.shape[1]
# Run through all joint blocks
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=None,
mode=BlockWrapperMode.CACHING,
block_idx=block_idx,
**kwargs,
)
# Merge streams
if self.joint_block_wrappers:
hidden_states = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
# Run through single blocks
for wrapper in self.single_block_wrappers:
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=None,
mode=BlockWrapperMode.CACHING,
)
# Extract image portion and project
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
return self.adapter.final_projection(hidden_states, text_embeddings)
def _diffusion_step(
self,
t: int,
config: RuntimeConfig,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
"""Execute a single diffusion step.
Routes to single-node, sync pipeline, or async pipeline based on
configuration and current timestep.
"""
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + self.num_sync_steps:
return self._sync_pipeline(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
)
def _single_node_step(
self,
t: int,
config: RuntimeConfig,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
"""Execute a single diffusion step on a single node (no distribution)."""
base_kwargs = {"t": t, "config": config}
# For edit mode: include conditioning latents
if prompt_data.conditioning_latents is not None:
base_kwargs["conditioning_latents"] = prompt_data.conditioning_latents
if self.adapter.needs_cfg:
# Two forward passes + guidance for CFG models (e.g., Qwen)
pos_kwargs = {
**base_kwargs,
**prompt_data.get_extra_forward_kwargs(positive=True),
}
noise_pos = self._forward_pass(
latents,
prompt_data.prompt_embeds,
prompt_data.pooled_prompt_embeds,
pos_kwargs,
)
neg_kwargs = {
**base_kwargs,
**prompt_data.get_extra_forward_kwargs(positive=False),
}
noise_neg = self._forward_pass(
latents,
prompt_data.negative_prompt_embeds,
prompt_data.negative_pooled_prompt_embeds,
neg_kwargs,
)
assert self.config.guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
)
else:
# Single forward pass for non-CFG models (e.g., Flux)
kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
noise = self._forward_pass(
latents,
prompt_data.prompt_embeds,
prompt_data.pooled_prompt_embeds,
kwargs,
)
return config.scheduler.step(model_output=noise, timestep=t, sample=latents)
def _initialize_kv_caches(
self,
batch_size: int,
num_img_tokens: int,
dtype: mx.Dtype,
) -> None:
"""Initialize KV caches for both sync and async pipelines.
Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
computed fresh and doesn't need caching (it's the same for all patches).
"""
self.joint_kv_caches = [
ImagePatchKVCache(
batch_size=batch_size,
num_heads=self.config.num_heads,
image_seq_len=num_img_tokens,
head_dim=self.config.head_dim,
dtype=dtype,
)
for _ in range(len(self.joint_block_wrappers))
]
self.single_kv_caches = [
ImagePatchKVCache(
batch_size=batch_size,
num_heads=self.config.num_heads,
image_seq_len=num_img_tokens,
head_dim=self.config.head_dim,
dtype=dtype,
)
for _ in range(len(self.single_block_wrappers))
]
def _create_patches(
self,
latents: mx.array,
config: RuntimeConfig,
) -> tuple[list[mx.array], list[tuple[int, int]]]:
"""Split latents into patches for async pipeline."""
# Use 16 to match FluxLatentCreator.create_noise formula
latent_height = config.height // 16
latent_width = config.width // 16
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
token_indices = calculate_token_indices(patch_heights, latent_width)
# Split latents into patches
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
return patch_latents, token_indices
def _sync_pipeline(
self,
t: int,
config: RuntimeConfig,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
prompt_embeds = prompt_data.prompt_embeds
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
extra_kwargs = prompt_data.get_extra_forward_kwargs()
hidden_states = config.scheduler.scale_model_input(hidden_states, t)
# For edit mode: handle conditioning latents
# All stages need to know the total token count for correct recv templates
conditioning_latents = prompt_data.conditioning_latents
original_latent_tokens = hidden_states.shape[1]
if conditioning_latents is not None:
num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
# First stage: concatenate conditioning latents before embedding
if self.is_first_stage and conditioning_latents is not None:
hidden_states = mx.concatenate(
[hidden_states, conditioning_latents], axis=1
)
# === PHASE 1: Embeddings ===
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
hidden_states, prompt_embeds
)
# All stages need these for their blocks
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
kontext_image_ids=kontext_image_ids,
**extra_kwargs,
)
# === Initialize KV caches to populate during sync for async warmstart ===
batch_size = prev_latents.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
if t == config.init_time_step:
self._initialize_kv_caches(
batch_size=batch_size,
num_img_tokens=num_img_tokens,
dtype=prev_latents.dtype,
)
# === PHASE 2: Joint Blocks with Communication and Caching ===
if self.has_joint_blocks:
# Receive from previous stage (if not first stage)
if not self.is_first_stage:
recv_template = mx.zeros(
(batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
)
hidden_states = mx.distributed.recv_like(
recv_template, self.prev_rank, group=self.group
)
enc_template = mx.zeros(
(batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
)
encoder_hidden_states = mx.distributed.recv_like(
enc_template, self.prev_rank, group=self.group
)
mx.eval(hidden_states, encoder_hidden_states)
# Run assigned joint blocks with caching mode
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.joint_kv_caches[block_idx],
mode=BlockWrapperMode.CACHING,
**extra_kwargs,
)
# === PHASE 3: Joint→Single Transition ===
if self.owns_concat_stage:
# Merge encoder and hidden states using adapter hook
concatenated = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
if self.has_single_blocks or self.is_last_stage:
# Keep locally: either for single blocks or final projection
hidden_states = concatenated
else:
# Send concatenated state to next stage (which has single blocks)
mx.eval(
mx.distributed.send(concatenated, self.next_rank, group=self.group)
)
elif self.has_joint_blocks and not self.is_last_stage:
# Send joint block outputs to next stage (which has more joint blocks)
mx.eval(
mx.distributed.send(hidden_states, self.next_rank, group=self.group),
mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
),
)
# === PHASE 4: Single Blocks with Communication and Caching ===
if self.has_single_blocks:
# Receive from previous stage if we didn't do concatenation
if not self.owns_concat_stage and not self.is_first_stage:
recv_template = mx.zeros(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype=prev_latents.dtype,
)
hidden_states = mx.distributed.recv_like(
recv_template, self.prev_rank, group=self.group
)
mx.eval(hidden_states)
# Run assigned single blocks with caching mode
for block_idx, wrapper in enumerate(self.single_block_wrappers):
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.single_kv_caches[block_idx],
mode=BlockWrapperMode.CACHING,
)
# Send to next stage if not last
if not self.is_last_stage:
mx.eval(
mx.distributed.send(hidden_states, self.next_rank, group=self.group)
)
# === PHASE 5: Last Stage - Final Projection + Scheduler ===
# Extract image portion (remove text embeddings prefix)
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
if self.is_last_stage:
hidden_states = self.adapter.final_projection(
hidden_states, text_embeddings
)
hidden_states = config.scheduler.step(
model_output=hidden_states,
timestep=t,
sample=prev_latents,
)
if not self.is_first_stage:
mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
# For shape correctness
hidden_states = prev_latents
return hidden_states
def _async_pipeline_step(
self,
t: int,
config: RuntimeConfig,
latents: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
patch_latents = self._async_pipeline(
t,
config,
patch_latents,
token_indices,
prompt_data,
kontext_image_ids,
)
return mx.concatenate(patch_latents, axis=1)
def _async_pipeline(
self,
t: int,
config: RuntimeConfig,
patch_latents: list[mx.array],
token_indices: list[tuple[int, int]],
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> list[mx.array]:
"""Execute async pipeline for all patches."""
assert self.joint_kv_caches is not None
assert self.single_kv_caches is not None
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
prompt_embeds = prompt_data.prompt_embeds
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
extra_kwargs = prompt_data.get_extra_forward_kwargs()
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
kontext_image_ids=kontext_image_ids,
**extra_kwargs,
)
batch_size = patch_latents[0].shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
for patch_idx, patch in enumerate(patch_latents):
patch_prev = patch
start_token, end_token = token_indices[patch_idx]
if self.has_joint_blocks:
if (
not self.is_first_stage
or t != config.init_time_step + self.num_sync_steps
):
if self.is_first_stage:
# First stage receives latent-space from last stage (scheduler output)
recv_template = patch
else:
# Other stages receive hidden-space from previous stage
patch_len = patch.shape[1]
recv_template = mx.zeros(
(batch_size, patch_len, hidden_dim),
dtype=patch.dtype,
)
patch = mx.distributed.recv_like(
recv_template, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch_latents[patch_idx] = patch
if not self.is_first_stage and patch_idx == 0:
enc_template = mx.zeros(
(batch_size, text_seq_len, hidden_dim),
dtype=patch_latents[0].dtype,
)
encoder_hidden_states = mx.distributed.recv_like(
enc_template, src=self.prev_rank, group=self.group
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
patch, prompt_embeds
)
# Run assigned joint blocks with patched mode
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.joint_kv_caches[block_idx],
mode=BlockWrapperMode.PATCHED,
patch_start=start_token,
patch_end=end_token,
**extra_kwargs,
)
if self.owns_concat_stage:
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
if self.has_single_blocks or self.is_last_stage:
# Keep locally: either for single blocks or final projection
patch = patch_concat
else:
mx.eval(
mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
)
elif self.has_joint_blocks and not self.is_last_stage:
mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))
if patch_idx == 0:
mx.eval(
mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
recv_template = mx.zeros(
[
batch_size,
text_seq_len + patch_latents[patch_idx].shape[1],
hidden_dim,
],
dtype=patch_latents[0].dtype,
)
patch = mx.distributed.recv_like(
recv_template, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch_latents[patch_idx] = patch
# Run assigned single blocks with patched mode
for block_idx, wrapper in enumerate(self.single_block_wrappers):
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.single_kv_caches[block_idx],
mode=BlockWrapperMode.PATCHED,
patch_start=start_token,
patch_end=end_token,
)
if not self.is_last_stage:
mx.eval(
mx.distributed.send(patch, self.next_rank, group=self.group)
)
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
patch_img_only = self.adapter.final_projection(
patch_img_only, text_embeddings
)
patch = config.scheduler.step(
model_output=patch_img_only,
timestep=t,
sample=patch_prev,
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
mx.eval(
mx.distributed.send(patch, self.next_rank, group=self.group)
)
patch_latents[patch_idx] = patch
return patch_latents

View File

@@ -10,18 +10,23 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.cache import (
_BaseCache, # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
class _LayerCallable(Protocol):
@@ -91,8 +96,6 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
@@ -100,10 +103,8 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output
@@ -133,24 +134,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(model, DeepseekV3Model) and hasattr(
inner_model_instance, "num_layers"
):
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
@@ -166,8 +149,7 @@ def pipeline_auto_parallel(
"""
inner_model_instance: nn.Module = _inner_model(model)
# Handle both model.layers and model.h cases
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -181,6 +163,17 @@ def pipeline_auto_parallel(
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -205,18 +198,44 @@ def tensor_auto_parallel(
group=group,
)
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
shard_inplace,
sharding="all-to-sharded",
group=group,
)
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding="sharded-to-all",
sharding=_all_to_sharded, # type: ignore
group=group,
)
if isinstance(model, LlamaModel):
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
return -1, segments
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding=_sharded_to_all, # type: ignore
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -224,7 +243,8 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, DeepseekV3Model):
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -232,7 +252,7 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Qwen3MoeModel):
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -240,6 +260,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -285,13 +314,38 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif isinstance(model, Qwen3MoeModel):
logger.info(
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.num_hidden_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
# Unfortunately, q_lora_rank can be None despite typing hints.
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
@@ -306,7 +360,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, DeepseekV3MLP):
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -354,7 +408,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
if isinstance(
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -382,3 +438,50 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -1,12 +1,19 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import ChatCompletionMessage, FinishReason
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
@@ -41,7 +48,6 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -64,6 +70,9 @@ def warmup_inference(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
logger.info("Generating warmup tokens")
for _r in stream_generate(
model=model,
@@ -72,7 +81,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=65536,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -80,20 +89,47 @@ def warmup_inference(
tokens_generated += 1
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
token_ids = [int(t) for t in token_ids]
def proc(_history: mx.array, logits: mx.array) -> mx.array:
for tid in token_ids:
logits[..., tid] = -1e9
return logits
return proc
def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
if eos is None:
return []
return eos
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
task: ChatCompletionTaskParams,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.info(f"task_params: {task}")
if task.seed is not None:
mx.random.seed(task.seed)
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
@@ -101,6 +137,17 @@ def mlx_generate(
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
eos_ids = eos_ids_from_tokenizer(tokenizer)
logits_processors = [ban_token_ids(eos_ids)]
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -108,26 +155,40 @@ def mlx_generate(
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
prefill_step_size=65536,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
if out.finish_reason is not None and out.finish_reason not in get_args(
FinishReason
):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -1,13 +1,25 @@
import json
import os
import resource
import sys
import time
from pathlib import Path
from typing import Any, Callable, cast
from typing import Any, cast
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -19,7 +31,7 @@ from exo.worker.engines.mlx.constants import (
try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
from mlx_lm.tokenizer_utils import load as load_tokenizer
import contextlib
import mlx.core as mx
@@ -176,11 +188,7 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
logger.info("Created a sampler")
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
@@ -201,7 +209,7 @@ def load_mlx_items(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
return cast(Model, model), tokenizer, sampler
return cast(Model, model), tokenizer
def shard_and_load(
@@ -257,26 +265,70 @@ def shard_and_load(
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
# TODO: Let's move away from this custom logic to mlx_lm.load()
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [163586]
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
elif "glm" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [151336, 151329, 151338]
else:
eos_token_ids = None
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
),
Some models require explicit EOS token configuration that isn't in their
tokenizer config. This function returns the known EOS token IDs for such models.
Args:
model_id: The HuggingFace model ID
Returns:
List of EOS token IDs, or None if the model uses standard tokenizer config
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
This is the core tokenizer loading logic, handling special cases for different
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
Args:
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
model_path: Local path where the model/tokenizer files are stored
Returns:
TokenizerWrapper instance configured for the model
"""
model_id_lower = model_id.lower()
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
sys.path.insert(0, str(model_path))
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)
assert isinstance(tokenizer, TokenizerWrapper)
return tokenizer
@@ -289,15 +341,15 @@ def apply_chat_template(
messages = chat_task_data.messages
formatted_messages: list[dict[str, Any]] = []
for _, message in enumerate(messages):
for message in messages:
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) != 1:
logger.warning("Received malformed prompt")
if len(message.content) == 0:
logger.warning("Received prompt with no content, skipping")
continue
message.content = message.content[0].text
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
continue
@@ -306,13 +358,14 @@ def apply_chat_template(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
prompt: str = tokenizer.apply_chat_template( # type: ignore
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
return prompt # type: ignore
return prompt
class NullKVCache(KVCache):
@@ -397,3 +450,13 @@ def set_wired_limit_for_model(model_size: Memory):
)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")
def mlx_cleanup(
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
) -> None:
del model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()

View File

@@ -8,15 +8,13 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
@@ -32,7 +30,6 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
@@ -98,10 +95,6 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
@@ -180,17 +173,6 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -203,8 +185,6 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -237,7 +217,9 @@ class Worker:
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
@@ -266,42 +248,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -420,7 +366,11 @@ class Worker:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
@@ -513,7 +463,9 @@ class Worker:
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:

View File

@@ -2,15 +2,13 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -51,8 +49,6 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
@@ -62,7 +58,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -266,28 +262,14 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
# TODO(ciaran): do this better!
if (
not isinstance(task, ChatCompletion)
and not isinstance(task, ImageGeneration)
and not isinstance(task, ImageEdits)
):
if not isinstance(task, ChatCompletion):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

View File

@@ -6,7 +6,7 @@ from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
@@ -31,6 +31,8 @@ def entrypoint(
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
@@ -42,8 +44,10 @@ def entrypoint(
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")
try:
event_sender.close()
task_receiver.close()
finally:
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

View File

@@ -1,10 +1,9 @@
import base64
import time
from exo.master.api import get_model_card
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
import mlx.core as mx
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,12 +11,9 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelTask
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -27,8 +23,6 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -44,14 +38,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.image import (
ImageGenerator,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -71,368 +58,153 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
try:
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
if timeout := getattr(shard_metadata, "should_timeout", 0):
time.sleep(timeout)
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
if timeout := getattr(shard_metadata, "should_timeout", 0):
time.sleep(timeout)
setup_start_time = time.time()
setup_start_time = time.time()
model = None
tokenizer = None
sampler = None
group = None
model = None
tokenizer = None
group = None
model_card = get_model_card(shard_metadata.model_meta.model_id)
assert model_card
model_tasks = model_card.tasks
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# TODO(ciaran): switch
if ModelTask.TextGeneration in model_tasks:
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in model_tasks:
assert model and isinstance(model, Model)
assert tokenizer
assert sampler
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
assert isinstance(model, ImageGenerator)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(
f"warmed up by generating {image.size} image"
)
else:
logger.info("warmup completed (non-primary node)")
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert model and isinstance(model, Model)
assert tokenizer
assert sampler
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
sampler=sampler,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, ImageGenerator)
logger.info(
f"received image generation request: {str(task)[:500]}"
)
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(
model=model,
task=task_params,
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=response.partial_index,
is_partial=True,
partial_index=response.partial_index,
total_partials=response.total_partials,
),
)
)
case ImageGenerationResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending final ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=False,
),
)
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, ImageGenerator)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
image_index = 0
for response in generate_image(
model=model,
task=task_params,
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case ImageGenerationResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=False,
),
)
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if isinstance(current_status, RunnerShutdown):
break
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(error_message=str(e)),
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected) and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

View File

@@ -0,0 +1,386 @@
"""
Unit tests for tokenizer loading and functionality across all supported models.
This test downloads only tokenizer-related files (not full model weights) to verify
that tokenizers can be loaded and used correctly for encoding/decoding.
"""
import asyncio
import contextlib
from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
)
# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"vocab.txt",
"merges.txt",
"tiktoken.model",
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
]
def is_tokenizer_file(filename: str) -> bool:
"""Check if a file is needed for tokenizer functionality."""
for pattern in TOKENIZER_FILE_PATTERNS:
if "*" in pattern:
prefix = pattern.split("*")[0]
suffix = pattern.split("*")[1]
if filename.startswith(prefix) and filename.endswith(suffix):
return True
elif filename == pattern:
return True
return False
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
if not tokenizer_files:
pytest.skip(f"No tokenizer files found for {model_id}")
for file_entry in tokenizer_files:
with contextlib.suppress(FileNotFoundError):
await download_file_with_retry(
model_id, "main", file_entry.path, target_dir
)
return target_dir
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
@pytest.fixture(scope="module")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
# Verify required files exist
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
# Load tokenizer
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Test basic encoding
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Test decoding
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
assert test_text in decoded or decoded.strip() == test_text.strip(), (
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
)
# Test with longer text
long_text = "The quick brown fox jumps over the lazy dog. " * 10
long_encoded = tokenizer.encode(long_text)
assert len(long_encoded) > len(encoded), (
f"Longer text should produce more tokens for {model_id}"
)
# Test empty string
empty_encoded = tokenizer.encode("")
assert isinstance(empty_encoded, list), (
f"encode('') should return a list for {model_id}"
)
# Test special characters
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
special_encoded = tokenizer.encode(special_text)
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
# Test unicode
unicode_text = "Hello 世界 🌍"
unicode_encoded = tokenizer.encode(unicode_text)
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Check for vocabulary size
empty_vocab: dict[str, int] = {}
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
)
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
# Check for EOS token (either from tokenizer or explicitly provided)
has_eos = (
eos_token_ids is not None
or getattr(tokenizer, "eos_token_id", None) is not None
or getattr(tokenizer, "eos_token", None) is not None
)
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Get special tokens from the tokenizer
special_tokens: list[str] = []
# Try to get special tokens from various sources
if hasattr(tokenizer, "all_special_tokens"):
special_tokens.extend(tokenizer.all_special_tokens)
elif hasattr(tokenizer, "_tokenizer") and hasattr(
tokenizer._tokenizer,
"all_special_tokens",
):
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
# Also check for common special token attributes
for attr in [
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"sep_token",
"cls_token",
]:
token = getattr(tokenizer, attr, None)
if token is None and hasattr(tokenizer, "_tokenizer"):
token = getattr(tokenizer._tokenizer, attr, None)
if token and isinstance(token, str) and token not in special_tokens:
special_tokens.append(token)
# If we found special tokens, test encoding text that contains them
if special_tokens:
# Create text with special tokens interspersed
test_with_special = f"{special_tokens[0]}Hello world"
if len(special_tokens) > 1:
test_with_special += f"{special_tokens[1]}"
encoded = tokenizer.encode(test_with_special)
assert isinstance(encoded, list), (
f"encode() with special tokens should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with special tokens should return non-empty list for {model_id}"
)
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Verify we can decode
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
# Test with angle-bracket tokens (common format for special tokens)
# These should not raise errors even if they're not actual special tokens
angle_bracket_text = "<|test|>Hello<|end|>"
encoded = tokenizer.encode(angle_bracket_text)
assert isinstance(encoded, list), (
f"encode() with angle brackets should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with angle brackets should be non-empty for {model_id}"
)
# Specifically test Kimi tokenizer since it has special handling
@pytest.mark.asyncio
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
# Ensure the custom tokenizer file exists
if not (model_path / "tokenization_kimi.py").exists():
pytest.skip("tokenization_kimi.py not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode cycle
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "Kimi tokenizer should encode text"
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
# Test that the patched encode works (returns list of ints)
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
# Test encoding text with special tokens (like from chat templates)
# This is critical - the warmup inference uses prompts with special tokens
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
special_encoded = tokenizer.encode(special_token_text)
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
assert all(isinstance(t, int) for t in special_encoded), (
"Special token encoding should return integers"
)
# Verify EOS token is set
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (model_path / "tokenizer.json").exists() or (
model_path / "tokenizer_config.json"
).exists()
if not has_tokenizer:
pytest.skip("GLM tokenizer files not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "GLM tokenizer should encode text"
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
# Verify EOS tokens
assert eos_token_ids == [
151336,
151329,
151338,
], "GLM EOS tokens should be correct"

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}
result = plan_mod.plan(
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)

View File

@@ -32,6 +32,8 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()

Some files were not shown because too many files have changed in this diff Show More