Compare commits

..

40 Commits

Author SHA1 Message Date
Alex Cheema
e86df5431c feat: add uncertainty visualization with token-level logprobs
- Add TokenHeatmap component for visualizing token confidence
- Collect and stream logprobs in generation pipeline
- Add regenerate-from-token feature with continue_from_prefix
- Add AbortController for request cancellation
- Support continue_final_message for seamless prefix continuation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 01:58:00 +00:00
Alex Cheema
aa447d3a4a style: fix formatting issues caught by treefmt
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 01:38:44 +00:00
Alex Cheema
967b040e1f refactor: use ResponsesRequest as canonical internal type
- Extend ResponsesRequest with fields: top_k, seed, stop, tools
- Remove redundant InternalTaskParams and InputMessage types
- Update all adapters to convert to ResponsesRequest
- Simplify Responses API (no conversion needed - native passthrough)
- Update all imports across codebase and tests

This eliminates type duplication and makes the Responses API
relationship explicit throughout the codebase.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 01:26:32 +00:00
Alex Cheema
a49eb90c36 refactor: make Responses API the canonical internal format
Restructure the API layer so that OpenAI Responses API is the native
format, with Chat Completions and Claude Messages as adapters on top.

Changes:
- Add new chat_completions.py adapter with streaming/non-streaming support
- Update responses.py with collect_responses_response() for non-streaming
- Update claude.py with collect_claude_response() for non-streaming
- Refactor api.py so all endpoints use adapters uniformly
- Rename _chat_chunk_stream to _token_chunk_stream (generic internal format)
- Remove unused chat_response_to_* converter functions
- Update tests to remove tests for deleted functions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 01:26:32 +00:00
Alex Cheema
7426ac3745 feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 01:26:32 +00:00
Alex Cheema
c5158bee53 Add pre-commit checks documentation to AGENTS.md (#1184)
## Motivation

CI failures can be avoided by running checks locally before committing.
This adds clear documentation to AGENTS.md so that AI agents (and
humans) know exactly which checks must pass before pushing code.

## Changes

Added a new "Pre-Commit Checks (REQUIRED)" section to AGENTS.md that:
- Lists all 4 required checks (basedpyright, ruff, nix fmt, pytest)
- Provides a one-liner to run all checks in sequence
- Notes that `nix fmt` changes must be staged before committing
- Explains that CI runs `nix flake check` which verifies everything

## Why It Works

Clear documentation prevents CI failures by ensuring contributors run
checks locally first. The one-liner command makes it easy to run all
checks before committing.

## Test Plan

### Manual Testing
- Verified the documented commands work correctly

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 21:50:24 +00:00
rltakashige
5c8a237940 Handle model timeouts (#1177)
- Add eval with a timeout.
- Add fast synch flag

## Motivation

Because of the experimental FAST SYNCH flag, some models may not work.
This PR catches when this occurs and allows users to specify a run
without fast synch

## Changes

- Adds a flag to enable or disable fast synch (--fast-synch and
--no-fast-synch)
- Adds a heuristic timeout
- Reduces exo_bench default timeout to 10 minutes.

## Why It Works

Heuristic timeout assumes normal loading times on Mac devices (60 +
model size in gb / 5: e.g. DeepSeek takes up to 120 seconds to load on
tensor parallel, and timeout is set to 60 + 120 = 180s.

We could raise this value if necessary.

## Test Plan

### Manual Testing
Catches that GPT OSS fails to load in Tensor RDMA
Can launch with --no-fast-synch flag to launch GPT OSS.

**GPT OSS 20B**
TP with fast synch
<img width="3064" height="456" alt="image"
src="https://github.com/user-attachments/assets/f6e25cd8-8621-4e99-99fe-292ee05c4035"
/>

TP without fast synch
<img width="3098" height="496" alt="image"
src="https://github.com/user-attachments/assets/d36453d9-6686-4cfe-aa7c-a7d458369d4d"
/>
[Note: the performance is really not great as fast synch is off]

(As a sanity check)
PP with fast synch
<img width="3124" height="496" alt="image"
src="https://github.com/user-attachments/assets/e97d4547-c6fa-483d-badb-4b371b900b4c"
/>

PP without fast synch
<img width="3078" height="508" alt="image"
src="https://github.com/user-attachments/assets/b2e20dfd-4b0e-4295-8a92-417dfe745c28"
/>

PP without RDMA
<img width="3070" height="498" alt="image"
src="https://github.com/user-attachments/assets/a8509d68-0aef-4cda-bca5-a67d39a0801e"
/>

TP without RDMA
<img width="3068" height="496" alt="image"
src="https://github.com/user-attachments/assets/b5691429-89f4-4369-bcf2-8fde2ad7154a"
/>
2026-01-16 20:25:12 +00:00
rltakashige
745343c705 Return error responses for Chat Completions (#1173)
- Error chunks
- Use error handling in exo_bench.py

## Motivation

Return when an error occurs so that generation stops. Adding timeouts is
a separate TODO for model loading and chat completions.

## Changes

- Return HTTP exceptions as JSON responses in an OpenAI compatible
format.
- Context manager for generation to catch and return error messages.
- Use error handling in exo_bench.py.

## Test Plan

### Manual Testing
Manually tested that exo_bench returns on failures within and outside
generation

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-16 19:24:37 +00:00
Alex Cheema
5e28664c41 Fix draft release detection (attempt 3) (#1176)
## Motivation

Previous fix still failed in CI. Suspecting permissions issue with
GITHUB_TOKEN not being able to see draft releases via API.

## Changes

1. Add explicit `permissions: contents: write` to the job
2. Use `gh release list` first to check if draft exists (this uses a
different code path that might work better)
3. Add debug echo statements

## Test Plan

Delete v1.0.63 tag and re-push after merging.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:26:06 +00:00
Alex Cheema
ae0a804ccb Fix draft release detection query (#1175)
## Motivation

Fixes the draft release detection that failed on the v1.0.63 release
attempt.

## Changes

The jq query was piped to `head -1` which truncated multi-line JSON
output to just `{`, causing the empty check to fail.

Changed to use `first // empty` in jq instead.

## Test Plan

Tested locally:
```bash
GITHUB_REF_NAME="v1.0.63"
gh api repos/exo-explore/exo/releases --jq "[.[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")] | first // empty"
# Returns the full draft release JSON (2711 chars)
```

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:05:24 +00:00
Alex Cheema
07cf2c1aa1 Add GitHub releases with Sparkle release notes integration (#1172)
## Motivation

Closes #1140

Currently releases are uploaded to S3 for Sparkle updates but there's no
GitHub Release created, and Sparkle update dialogs don't show release
notes. Users have no visibility into what changed.

## Changes

- Added release workflow documentation comment at top of `build-app.yml`
- Added "Fetch release notes for Sparkle" step that converts markdown
from draft GitHub release to HTML
- Added "Inject release notes into appcast" step that embeds HTML in
appcast.xml with CDATA
- Added "Publish GitHub Release" step that attaches DMG and publishes
the draft

## Why It Works

- Sparkle's `<description>` tag supports HTML wrapped in CDATA for
rendering in update dialogs
- GitHub's markdown API (`/markdown`) converts the release notes to HTML
with proper formatting
- Draft releases allow writing polished notes before the build, then the
workflow publishes them automatically
- The workflow fails if no draft release exists, ensuring release notes
are always provided

## Test Plan

### Manual Testing
1. Create a draft GitHub release for a new tag with markdown release
notes
2. Push the tag to trigger the workflow
3. Verify the GitHub release is published with DMG attached
4. Download appcast.xml from S3 and verify
`<description><![CDATA[...]]></description>` contains HTML
5. Test Sparkle update dialog on macOS to confirm release notes appear

### Automated Testing
No automated tests added - this is CI workflow configuration.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 16:47:33 +00:00
Evan
83c5285a80 reduce logs
previous commits logs were too verbose, this tones them down a bit
2026-01-16 14:05:47 +00:00
Evan Quiney
39ee2bf7bd switch from synchronous threaded pinging to an async implementation (#1170)
still seeing churn in our networking - lets properly rate limit it

## changes

added an httpx client with max connections with a persistent AsyncClient

## testing

deployed on cluster, discovery VASTLY more stable (the only deleted
edges were those discovered by mdns)
2026-01-16 13:20:03 +00:00
Sami Khan
991adfbd6f fix local network warning (#1136)
## Motivation

Local network warning banner was showing on fresh install even though
mDNS was working. The check would fail before the user had a chance to
grant permission via the macOS prompt.

## Changes

- Added `hasWorkedBefore` flag persisted in UserDefaults
- Only show warning if permission previously worked but now doesn't

## Why It Works

On fresh install, the check may fail (no permission yet), but
`hasWorkedBefore` is false so no warning shows. Once the user grants
permission and a check succeeds, we record it. Future failures (zombie
permission after restart) will show the warning since `hasWorkedBefore`
is now true.

## Test Plan

### Manual Testing
Run locally

### Automated Testing
N/A
2026-01-16 13:10:50 +00:00
rltakashige
4b3de6b984 Fix exo bench for transformers 5.x (#1168)
## Motivation
Prompt Sizer was broken as transformers 5.x tokenizers create
BatchEncodings which are essentially a dictionary of {input_ids: []}
instead of the list of input ids.

## Test Plan

### Manual Testing
Tested that exo bench runs as expected.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-16 12:39:22 +00:00
Evan
c8de3b90ea quiet rust logs
rust logs were too verbose - now only warnings propagate to python

entirely happy not to merge this and to clean up rust logging instead,
but this felt saner right now
2026-01-16 12:34:28 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplification of API + moving model specific code to the runner

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (as
this is unstable for MLX LM too)
<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->


## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor and single node work on both 20B and
120B models
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? lets also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
Evan Quiney
c22dad8a7d dashboard: add peer: true to package lock (#1162)
this happens every time i run npm install - lets upstream it

## testing
dashboard builds and renders
2026-01-15 17:01:43 +00:00
Evan
4bc4d50685 rust: remove dead code
the system custodian has been made unnecessary with the swift app - we
can remove it

## testing
everything still builds
2026-01-15 16:51:46 +00:00
Jake Hillion
e0aab46fd8 model_cards.py: clean up commented out code
Clean up the commented out code and make sure the comments are unified.
Carrying around the commented out code means people making changes to
model_cards are supposed to update it, but that's not clear and won't be
picked up by type checking etc. Drop it for now - it's in the git
history.

Also make the rest of the comments a bit more uniform, and place
comments about a specific model card inside the model card (instead of
above) so they don't get lost when code is added/moved around.

Test plan:
- my eyes
2026-01-15 13:21:58 +00:00
Evan Quiney
82ba42bae9 add glm-47, minimax-m21 (#1147)
Adds support glm 4.7 and MiniMax M2.1

Manual testing:
Tensor + Pipeline execution of both models.

Closes #1141 and #1142
2026-01-14 16:33:17 +00:00
Jake Hillion
3671528fa4 nix: add dashboard build with dream2nix
Continue working towards a fully Nix based build by building the
dashboard with Nix. Continuing the theme of using the existing lock
files, use dream2nix to parse the lock file and build the tree of
dependency derivations.

dream2nix doesn't like the bundleDependencies, so we apply a small patch
to the lock file that drops all dependencies that are bundled. This
should ideally be contributed upstream but that can be done later.

Use this new dashboard build in the build-app CI workflow, meaning
future macOS apps will include this reproducible dashboard.

Test plan:
- Built a DMG, shipped to a cluster, loaded in a browser with no cache
  and the dashboard looks good.

- Directory layout is as expected:
```
$ nix build .#dashboard
$ find result/
...
result/_app/immutable/entry
result/_app/immutable/entry/app.CTPAnMjf.js
result/_app/immutable/entry/start.fUSEa-2O.js
result/_app/immutable/nodes
result/_app/immutable/nodes/3.DqQr1Obm.js
result/_app/immutable/nodes/0.DgEY44RO.js
result/_app/immutable/nodes/2.BjZg_lJh.js
result/_app/immutable/nodes/1.D6vGUYYT.js
result/_app/env.js
result/_app/version.json
result/exo-logo.png
result/favicon.ico
result/index.html
```
2026-01-14 15:58:16 +01:00
Jake Hillion
e6434ec446 nix: add Rust builds with crane and fenix
The Rust workspace lacked Nix build support, making it difficult to
build packages reproducibly or run checks in CI.

Added a flake-parts module at rust/parts.nix that uses crane for Rust
builds and fenix for the nightly toolchain. The source filter isolates
rust/ and root Cargo files to prevent Python/docs changes from
triggering Rust rebuilds. Exports packages (system_custodian,
exo_pyo3_bindings wheel, exo-rust-workspace) and checks (cargo-nextest,
cargo-doc) for all three target platforms.

The devShell now uses inputsFrom to inherit build dependencies from the
workspace package, removing the need for manual pkg-config/openssl setup.

Test plan:
- Ran `nix flake check` successfully
- Built `nix build ".#checks.x86_64-linux.cargo-nextest"` and tests pass
- Built `nix build ".#exo_pyo3_bindings"` and wheel is produced
2026-01-14 11:52:29 +00:00
Jake Hillion
bdb43e1dbb nix: drop noisy echos from devshell
Drop all the printing when entering a devshell. It's annoying, and not a
super accurate description of how to develop exo anyway.
2026-01-14 10:04:57 +00:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input as Swift is currently broken
in nixpkgs on Linux for `swift-format` as we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Evan Quiney
1200a7db64 Add tensor sharding for GPT-OSS (#1144)
## Motivation

GPT OSS did not previously support tensor sharding

## Changes

Add GPT sharding support in tensor_auto_parallel.
Code is mostly @rltakashige's

## Test Plan

### Manual Testing
Tested GPT-OSS - MLX Fast Sync causes issues in Tensor RDMA - this is a general problem at the moment.
2026-01-13 17:25:52 +00:00
Evan Quiney
47ceb54bc1 up the rlimit (#1148)
Fixes #1117 

Manual testing:
Launched 100 instances. worked. yay.
2026-01-13 15:00:54 +00:00
Jake Hillion
f8112fdf25 nix: convert to flake-parts
Preparing to add a flake-parts module for Rust builds. The flake-utils
library doesn't support the module system needed for cleanly separating
the Rust build configuration.

Converted from flake-utils to flake-parts, switching to the treefmt-nix
flakeModule import pattern. The devShell and formatter outputs remain
functionally equivalent.

Test plan:
- Ran `nix flake check` successfully
- Verified `nix develop` provides the same environment
2026-01-13 15:06:44 +01:00
Alex Cheema
e388f59480 docs: add AGENTS.md for AI coding agents guidance (#1132)
## Motivation

Add documentation to help AI coding agents (Claude Code, Cursor, GitHub
Copilot, etc.) understand the exo codebase and contribute effectively.

## Changes

- Add `AGENTS.md` with guidance for AI agents working on the codebase
- Add symlink `CLAUDE.md -> AGENTS.md` for backwards compatibility with
Claude Code

## Why It Works

`AGENTS.md` is becoming a standard convention for AI agent instructions.
The symlink ensures Claude Code (which looks for `CLAUDE.md`) continues
to work while supporting the broader `AGENTS.md` convention.

## Test Plan

### Manual Testing
- Verified symlink works correctly

### Automated Testing
- N/A (documentation only)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 13:05:47 +00:00
Alex Cheema
e5e74e1eef Upgrade mlx-lm to 0.30.2 with transformers 5.x compatibility (#1125)
## Motivation

Upgrade mlx-lm to version 0.30.2 which requires transformers 5.0.0rc2 as
a prerelease dependency. This enables support for newer models like Kimi
K2 Thinking while maintaining compatibility with existing models.

The transformers 5.x release includes breaking changes that affect
custom tokenizers like Kimi's TikTokenTokenizer, requiring compatibility
fixes.

## Changes

### Core Changes
- **mlx-lm upgrade**: Bump to 0.30.2 with locked exact versions for
mlx/mlx-lm to prevent breaking changes
- **transformers 5.x compatibility**: Enable prerelease transformers
dependency

### Kimi K2 Tokenizer Fixes
- Add `bytes_to_unicode` monkey-patch to restore function moved in
transformers 5.0.0rc2
- Load `TikTokenTokenizer` directly instead of via `AutoTokenizer` to
bypass transformers 5.x bug with `auto_map` fallback
- Patch `encode()` to use tiktoken directly with `allowed_special="all"`
to handle special tokens from chat templates

### Other Changes
- Dashboard: Show disk usage for completed model downloads
- CI: Add `workflow_dispatch` trigger to build-app workflow
- Docs: Add basic API documentation

### Testing
- Add comprehensive tokenizer unit tests for all supported models
- Tests verify encode/decode, special token handling, and chat template
encoding

## Why It Works

**bytes_to_unicode issue**: transformers 5.0.0rc2 moved
`bytes_to_unicode` from `transformers.models.gpt2.tokenization_gpt2` to
`transformers.convert_slow_tokenizer`. Kimi's `tokenization_kimi.py`
imports from the old location. The monkey-patch restores it at module
load time.

**AutoTokenizer issue**: transformers 5.x has a bug where
`tokenizer_class_from_name('TikTokenTokenizer')` returns `None` for
custom tokenizers with `auto_map`. Loading the tokenizer directly
bypasses this.

**encode() issue**: transformers 5.x's `pad()` method fails for slow
tokenizers. Using tiktoken's encode directly with
`allowed_special="all"` avoids this path and properly handles special
tokens like `<|im_user|>` from chat templates.

## Test Plan

### Manual Testing
- Hardware: 2x Mac Studios connected via Thunderbolt 5 (mike22 and
james21)
- Tested Kimi K2 Thinking, GPT-OSS-120B, GPT-OSS-20B, LLama-3.1-8B-bf16, qwen3-30B-A3B-8bit model with pipeline parallelism across both
nodes
- Verified warmup inference completes successfully
- Verified chat completions work with special tokens

### Automated Testing
- Added `test_tokenizers.py` with 31 tests covering:
- Basic encode/decode for all model families (deepseek, kimi, llama,
qwen, gpt-oss, glm)
  - Special token encoding (critical for chat templates)
  - Chat template application and encoding
  - Kimi-specific and GLM-specific edge cases
- All tests pass: `uv run pytest
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py`

### Failing Tests
RDMA with all models.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-13 12:06:04 +00:00
Jake Hillion
b968d6f0a0 ci: remove old commented out job 2026-01-13 12:42:04 +01:00
Jake Hillion
3bfffd9b4f ci: build all Nix outputs on all platforms and push to cachix
The CI was only running `nix flake check` on ubuntu-latest, missing
builds for other platforms and not caching packages or devShells.

Added a matrix-based `nix-build` job that runs on macos-26 (aarch64-darwin),
ubuntu-latest (x86_64-linux), and ubuntu-24.04-arm (aarch64-linux). Each
job enumerates all packages and devShells via `nix flake show --json`,
builds them in a single `nix build` call for parallelization, then runs
`nix flake check`. The cachix-action pushes all built outputs automatically.

This ensures all Nix outputs are built and cached for every supported
platform, speeding up local development and CI runs.

Test plan:
- Tested jq enumeration command locally, correctly outputs devShell paths
- Verified xargs pipeline works with the enumerated outputs
2026-01-13 12:37:12 +01:00
Jake Hillion
007eb80029 nix: enable cachix
Enable cachix and push to it in the pipeline.yml workflow. This won't
cache a huge amount yet but will automatically extend our caching as we
build more of the repo with Nix in CI. It can also be used by local
users by accepting our cache to improve the speed of local builds.

Test plan:
- CI
2026-01-12 17:24:59 +01:00
Jake Hillion
8d7b6789b3 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but provided no indication of how much disk space each model or the
total models on a node were using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 16:34:29 +01:00
Jake Hillion
3c5b7ea670 ci: add workflow_dispatch trigger to build-app
Build app is the most convenient way to get a DMG for testing, but
currently it's a bit limited. You have to push to test-app every time
which is far from ideal and requires a bit too much force pushing for my
liking.

Add the workflow_dispatch trigger. This adds a button in the actions UI
to trigger a workflow for a named branch, which means you can use your
normal dev branch instead of having to push to test-app. We'll leave
that behaviour there for now too, though it may change in future.

Filter on `"${{ github.event_name }}" == "workflow_dispatch"` and set
those to alpha as well. Will verify by pushing the first version from
`main` just in case. Unfortunately we do have to merge this before we
can test it.

Test plan:
- Looking really hard.
2026-01-12 12:14:21 +01:00
PG
b74a610537 Add a basic documentation to the api interface (#1122)
## Motivation

Adds basic api documentation

## Changes

- Add docs/api.md
- Modify README.md
2026-01-11 18:44:40 +00:00
Jake Hillion
18c4e49f91 nix: put treefmt in devshell
treefmt is a useful to be able to access directly for some formatters like
`jj fix`. Expose it in the devshell.

Test plan:
- Used with `jj fix` on a large branch. It worked.
2026-01-09 17:53:50 +01:00
102 changed files with 6594 additions and 2512 deletions

View File

@@ -1,5 +1,16 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -11,8 +22,10 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.8.1
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -87,6 +100,52 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -113,11 +172,22 @@ jobs:
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Configure Cachix
uses: cachix/cachix-action@v14
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
cd dashboard
npm ci
npm run build
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
- name: Install Sparkle CLI
run: |
@@ -293,6 +363,28 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -325,3 +417,26 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -20,6 +20,12 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
@@ -88,9 +94,19 @@ jobs:
- uses: ./.github/actions/typecheck
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -101,83 +117,20 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Run nix flake check
run: |
nix flake check
shell: bash
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
- name: Run nix flake check
run: nix flake check

View File

@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -103,37 +104,55 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
def load(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -165,4 +177,7 @@ def load_tokenizer(
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...

121
AGENTS.md Normal file
View File

@@ -0,0 +1,121 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

19
Cargo.lock generated
View File

@@ -4340,25 +4340,6 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

41
MISSED_THINGS.md Normal file
View File

@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.8.1;
minimumVersion = 2.9.0-beta.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],

View File

@@ -56,6 +56,11 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
connection?.cancel()
connection = nil
@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
conn.stateUpdateHandler = { state in
switch state {
case .ready:
resumeOnce(.working)
@@ -108,6 +106,7 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -127,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
try? await Task.sleep(nanoseconds: timeout)
let state = conn.state
switch state {
case .ready:

View File

@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -241,6 +266,9 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -296,6 +324,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -317,7 +351,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -396,7 +430,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -438,7 +472,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)

60
dashboard/dashboard.nix Normal file
View File

@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
# Read and parse the lock file
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
# For packages with bundleDependencies, filter out deps that are bundled
# (bundled deps are inside the tarball, not separate lockfile entries)
fixedPackages = lib.mapAttrs
(path: entry:
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
then entry // {
dependencies = lib.filterAttrs
(name: _: !(lib.elem name entry.bundleDependencies))
(entry.dependencies or { });
}
else entry
)
(rawLockFile.packages or { });
fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
imports = [
dream2nix.modules.dream2nix.nodejs-package-lock-v3
dream2nix.modules.dream2nix.nodejs-granular-v3
];
name = "exo-dashboard";
version = "1.0.0";
mkDerivation = {
src = config.deps.dashboardSrc;
buildPhase = ''
runHook preBuild
npm run build
runHook postBuild
'';
installPhase = ''
runHook preInstall
cp -r build $out/build
runHook postInstall
'';
};
deps = { nixpkgs, ... }: {
inherit (nixpkgs) stdenv;
dashboardSrc = null; # Injected by parts.nix
};
nodejs-package-lock-v3 = {
# Don't use packageLockFile - provide the fixed lock content directly
packageLock = fixedLockFile;
};
}

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

44
dashboard/parts.nix Normal file
View File

@@ -0,0 +1,44 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include dashboard directory
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
inDashboardDir =
(lib.hasInfix "/dashboard/" path)
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|| (baseName == "dashboard" && type == "directory");
in
inDashboardDir;
};
# Build the dashboard with dream2nix (includes node_modules in output)
dashboardFull = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Inject the filtered source
{
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
}
];
};
in
{
# Extract just the static site from the full build
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
cp -r ${dashboardFull}/build $out
'';
};
}

View File

@@ -60,12 +60,39 @@
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -1,14 +1,16 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
regenerateLastResponse
regenerateLastResponse,
regenerateFromToken
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import TokenHeatmap from './TokenHeatmap.svelte';
interface Props {
class?: string;
@@ -95,6 +97,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty view state - tracks which messages show token heatmap
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
function toggleUncertaintyView(messageId: string) {
const newSet = new Set(uncertaintyViewMessageIds);
if (newSet.has(messageId)) {
newSet.delete(messageId);
} else {
newSet.add(messageId);
}
uncertaintyViewMessageIds = newSet;
}
function isUncertaintyViewEnabled(messageId: string): boolean {
return uncertaintyViewMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString('en-US', {
hour12: false,
@@ -366,7 +385,17 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
<!-- Uncertainty heatmap view -->
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading}
onRegenerateFrom={(tokenIndex) => regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<!-- Normal markdown view -->
<MarkdownContent content={message.content || (loading ? response : '')} />
{/if}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -419,7 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
</button>
{/if}
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleUncertaintyView(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</button>
{/if}
<!-- Delete button -->
<button
onclick={() => handleDeleteClick(message.id)}

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -0,0 +1,192 @@
<script lang="ts">
import type { TokenData } from '$lib/stores/app.svelte';
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let { tokens, class: className = '', isGenerating = false, onRegenerateFrom }: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? { token: tokens[hoveredTokenIndex], index: hoveredTokenIndex, ...hoveredPosition }
: null
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return 'text-inherit'; // Expected tokens - blend in
if (probability > 0.5) return 'bg-gray-500/10 text-inherit'; // Slight hint
if (probability > 0.2) return 'bg-amber-500/15 text-amber-200/90'; // Subtle warmth
return 'bg-red-500/20 text-red-200/90'; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return 'border-transparent'; // No border for expected
if (probability > 0.5) return 'border-gray-500/20';
if (probability > 0.2) return 'border-amber-500/30';
return 'border-red-500/40';
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(event: MouseEvent, token: TokenData, index: number) {
clearHideTimeout();
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 100;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + '%';
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return 'text-gray-300';
if (probability > 0.5) return 'text-gray-400';
if (probability > 0.2) return 'text-amber-400';
return 'text-red-400';
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}
>{tokenData.token}</span>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48">
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2">{formatProbability(hoveredToken.token.probability)}</span>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

@@ -99,7 +99,7 @@ interface RawNodeProfile {
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
nodeProfile: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -110,19 +110,9 @@ interface RawTopologyConnection {
| string;
}
// Connection can be an object or a tuple [source, target, metadata]
type RawConnectionItem =
| RawTopologyConnection
| [
string,
string,
{ sinkMultiaddr?: { ip_address?: string; address?: string } }?,
];
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
connections?: RawConnectionItem[];
nodes: RawTopologyNode[];
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
@@ -192,6 +182,20 @@ export interface MessageAttachment {
mimeType?: string;
}
// Token-level data for uncertainty visualization
export interface TopLogprob {
token: string;
logprob: number;
bytes?: number[];
}
export interface TokenData {
token: string;
logprob: number;
probability: number; // exp(logprob)
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -201,6 +205,7 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
tokens?: TokenData[]; // Token-level data for uncertainty visualization
}
export interface Conversation {
@@ -223,18 +228,9 @@ function transformTopology(
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
@@ -283,7 +279,7 @@ function transformTopology(
}
}
nodes[nodeId] = {
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
@@ -311,39 +307,14 @@ function transformTopology(
};
}
// Handle connections - can be objects with localNodeId/sendBackNodeId or tuples [source, target, metadata]
for (const conn of raw.connections || []) {
let localNodeId: string | undefined;
let sendBackNodeId: string | undefined;
let sendBackMultiaddr:
| { multiaddr?: string; address?: string; ip_address?: string }
| string
| undefined;
// Check if it's a tuple format [source, target, metadata]
if (Array.isArray(conn)) {
localNodeId = conn[0] as string;
sendBackNodeId = conn[1] as string;
const metadata = conn[2] as
| { sinkMultiaddr?: { ip_address?: string; address?: string } }
| undefined;
if (metadata?.sinkMultiaddr) {
sendBackMultiaddr = metadata.sinkMultiaddr;
}
} else {
// Object format with localNodeId/sendBackNodeId
localNodeId = conn.localNodeId;
sendBackNodeId = conn.sendBackNodeId;
sendBackMultiaddr = conn.sendBackMultiaddr;
}
if (!localNodeId || !sendBackNodeId) continue;
if (localNodeId === sendBackNodeId) continue;
if (!nodes[localNodeId] || !nodes[sendBackNodeId]) continue;
if (!conn.localNodeId || !conn.sendBackNodeId) continue;
if (conn.localNodeId === conn.sendBackNodeId) continue;
if (!nodes[conn.localNodeId] || !nodes[conn.sendBackNodeId]) continue;
let sendBackIp: string | undefined;
if (sendBackMultiaddr) {
const multi = sendBackMultiaddr;
if (conn.sendBackMultiaddr) {
const multi = conn.sendBackMultiaddr;
if (typeof multi === "string") {
sendBackIp = extractIpFromMultiaddr(multi);
} else {
@@ -355,8 +326,8 @@ function transformTopology(
}
edges.push({
source: localNodeId,
target: sendBackNodeId,
source: conn.localNodeId,
target: conn.sendBackNodeId,
sendBackIp,
});
}
@@ -412,6 +383,21 @@ class AppStore {
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private currentRequestController: AbortController | null = null;
/**
* Abort any in-flight generation request
*/
abortCurrentRequest(): boolean {
if (this.currentRequestController) {
this.currentRequestController.abort();
this.currentRequestController = null;
this.isLoading = false;
this.currentResponse = "";
return true;
}
return false;
}
constructor() {
if (browser) {
@@ -1442,6 +1428,10 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
// Create abort controller for this request
const controller = new AbortController();
this.currentRequestController = controller;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -1452,7 +1442,10 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
signal: controller.signal,
});
if (!response.ok) {
@@ -1468,6 +1461,7 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let buffer = "";
const collectedTokens: TokenData[] = [];
while (true) {
const { done, value } = await reader.read();
@@ -1489,8 +1483,8 @@ class AppStore {
try {
const parsed = JSON.parse(data);
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
const delta = parsed.choices?.[0]?.delta?.content;
if (delta) {
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -1507,7 +1501,30 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
fullContent += tokenContent;
// Extract logprobs for uncertainty visualization
const logprobsData = parsed.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += delta;
// Strip thinking tags for display and extract thinking content
const { displayContent, thinkingContent } =
@@ -1521,6 +1538,7 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1568,9 +1586,16 @@ class AppStore {
if (this.tps !== null) {
this.messages[idx].tps = this.tps;
}
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {
// Don't show error for aborted requests (user cancelled)
if (error instanceof Error && error.name === "AbortError") {
return;
}
console.error("Error sending message:", error);
// Update the assistant message with error
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
@@ -1580,6 +1605,237 @@ class AppStore {
}
this.persistActiveConversation();
} finally {
// Clean up controller if this is still the active request
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
}
}
/**
* Regenerate from a specific token in an assistant message.
* Keeps content up to and including the specified token, then continues generation.
* If a generation is already in progress, it will be aborted first.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
// Abort any in-flight request first
this.abortCurrentRequest();
const messageIdx = this.messages.findIndex((m) => m.id === messageId);
if (messageIdx === -1) return;
const message = this.messages[messageIdx];
if (message.role !== "assistant" || !message.tokens) return;
// Get tokens up to and including the specified index
const tokensToKeep = message.tokens.slice(0, tokenIndex + 1);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, messageIdx + 1);
// Update the message to show the prefix
this.messages[messageIdx].content = prefixText;
this.messages[messageIdx].tokens = tokensToKeep;
// Set up for continuation
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
// Get all messages up to and including the one we're regenerating from
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
// Determine model
let modelToUse = this.selectedChatModel;
if (!modelToUse) {
for (const [, instanceWrapper] of Object.entries(this.instances)) {
if (instanceWrapper && typeof instanceWrapper === "object") {
const keys = Object.keys(
instanceWrapper as Record<string, unknown>,
);
if (keys.length === 1) {
const instance = (instanceWrapper as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string } };
if (instance?.shardAssignments?.modelId) {
modelToUse = instance.shardAssignments.modelId;
break;
}
}
}
}
}
if (!modelToUse) {
throw new Error("No model available");
}
// Start timing
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
// Create abort controller
const controller = new AbortController();
this.currentRequestController = controller;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
continue_from_prefix: true,
}),
signal: controller.signal,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
const decoder = new TextDecoder();
let fullContent = prefixText;
let buffer = "";
const collectedTokens: TokenData[] = [...tokensToKeep];
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed || trimmed === "data: [DONE]") continue;
if (trimmed.startsWith("data: ")) {
try {
const json = JSON.parse(trimmed.slice(6));
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (
firstTokenTime !== null &&
tokenCount > tokensToKeep.length
) {
const elapsed = performance.now() - firstTokenTime;
this.tps =
((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
// Extract logprobs
const logprobsData = json.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
collectedTokens.push({
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
});
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking =
thinkingContent || undefined;
this.messages[messageIdx].tokens = [...collectedTokens];
this.persistActiveConversation();
}
} catch {
// Skip malformed JSON
}
}
}
}
// Final update
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking = thinkingContent || undefined;
this.messages[messageIdx].tokens = collectedTokens;
if (this.ttftMs !== null) {
this.messages[messageIdx].ttftMs = this.ttftMs;
}
if (this.tps !== null) {
this.messages[messageIdx].tps = this.tps;
}
this.persistActiveConversation();
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
return;
}
console.error("Error regenerating from token:", error);
this.messages[messageIdx].content =
`${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
this.persistActiveConversation();
} finally {
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
@@ -1659,6 +1915,8 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -895,7 +915,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});

View File

@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

185
flake.lock generated
View File

@@ -1,5 +1,42 @@
{
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"dream2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
"owner": "nix-community",
"repo": "dream2nix",
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "dream2nix",
"type": "github"
}
},
"fenix": {
"inputs": {
"nixpkgs": [
@@ -8,11 +45,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -21,25 +58,59 @@
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"inputs": {
"nixpkgs-lib": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -50,27 +121,74 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"purescript-overlay": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": [
"dream2nix",
"nixpkgs"
],
"slimlock": "slimlock"
},
"locked": {
"lastModified": 1728546539,
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crane": "crane",
"dream2nix": "dream2nix",
"fenix": "fenix",
"flake-utils": "flake-utils",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -80,18 +198,25 @@
"type": "github"
}
},
"systems": {
"slimlock": {
"inputs": {
"nixpkgs": [
"dream2nix",
"purescript-overlay",
"nixpkgs"
]
},
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"lastModified": 1688756706,
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
"owner": "thomashoneyman",
"repo": "slimlock",
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"owner": "thomashoneyman",
"repo": "slimlock",
"type": "github"
}
},
@@ -102,11 +227,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

210
flake.nix
View File

@@ -3,132 +3,134 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
crane.url = "github:ipetkov/crane";
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
nixConfig = {
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
outputs =
inputs:
let
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
imports = [
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
];
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = config.rust.toolchain;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
rustfmt = {
enable = true;
package = (fenixToolchain system).rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
};
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# FORMATTING
treefmtEval.config.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
config.rust.toolchain
maturin
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ lib.optionals stdenv.isLinux [
unixtools.ifconfig
]
++ lib.optionals stdenv.isDarwin [
macmon
];
OPENSSL_NO_VENDOR = "1";
shellHook = ''
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
''}
'';
};
};
}
);
};
}

View File

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -17,12 +17,13 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]
@@ -33,6 +34,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -98,6 +100,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",

145
rust/parts.nix Normal file
View File

@@ -0,0 +1,145 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
"rust-analyzer"
];
# Crane with fenix toolchain
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
# Source filtering - only include rust/ directory and root Cargo files
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
parentDir = builtins.dirOf path;
inRustDir =
(lib.hasInfix "/rust/" path)
|| (lib.hasSuffix "/rust" parentDir)
|| (baseName == "rust" && type == "directory");
isRootCargoFile =
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
&& (builtins.dirOf path == toString inputs.self);
in
isRootCargoFile
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
};
# Common arguments for all Rust builds
commonArgs = {
inherit src;
pname = "exo-rust";
version = "0.0.1";
strictDeps = true;
nativeBuildInputs = [
pkgs.pkg-config
pkgs.python313 # Required for pyo3-build-config
];
buildInputs = [
pkgs.openssl
pkgs.python313 # Required for pyo3 tests
];
OPENSSL_NO_VENDOR = "1";
# Required for pyo3 tests to find libpython
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
};
# Build dependencies once for caching
cargoArtifacts = craneLib.buildDepsOnly (
commonArgs
// {
cargoExtraArgs = "--workspace";
}
);
in
{
# Export toolchain for use in treefmt and devShell
options.rust = {
toolchain = lib.mkOption {
type = lib.types.package;
default = rustToolchain;
description = "The Rust toolchain to use";
};
};
config = {
packages = {
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
pname = "exo_pyo3_bindings";
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
pkgs.maturin
];
buildPhaseCargoCommand = ''
maturin build \
--release \
--manylinux off \
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
--features "pyo3/extension-module,pyo3/experimental-async" \
--interpreter ${pkgs.python313}/bin/python \
--out dist
'';
# Don't use crane's default install behavior
doNotPostBuildInstallCargoBinaries = true;
installPhaseCommand = ''
mkdir -p $out
cp dist/*.whl $out/
'';
}
);
};
checks = {
# Full workspace build (all crates)
cargo-build = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Run tests with nextest
cargo-nextest = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Build documentation
cargo-doc = craneLib.cargoDoc (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
};
};
};
}

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,6 +1,7 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,6 +196,8 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
@@ -202,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -215,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -256,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,175 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionResponse,
ChatCompletionTaskParams,
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def chat_request_to_internal(request: ChatCompletionTaskParams) -> ResponsesRequest:
"""Convert Chat Completions API request to ResponsesRequest (canonical internal format).
Extracts system message as instructions, converts messages to input.
"""
instructions: str | None = None
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
else:
# Convert to ResponseInputMessage (only user, assistant, developer roles)
if msg.role in ("user", "assistant", "developer"):
input_messages.append(
ResponseInputMessage(role=msg.role, content=content)
)
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
continue_from_prefix=request.continue_from_prefix,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from TokenChunks."""
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)

View File

@@ -0,0 +1,190 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_internal(request: ClaudeMessagesRequest) -> ResponsesRequest:
"""Convert Claude Messages API request to ResponsesRequest (canonical internal format).
Converts Claude's system parameter to instructions,
and messages to input.
"""
# Handle system message
instructions: str | None = None
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
# List of text blocks
instructions = "".join(block.text for block in request.system)
# Convert messages to input
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
# Claude uses "user" and "assistant" roles
input_messages.append(ResponseInputMessage(role=msg.role, content=content))
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=[ClaudeTextBlock(text=combined_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,173 @@
"""OpenAI Responses API adapter for converting requests/responses.
ResponsesRequest is the canonical internal format. Responses API is the most featureful,
making it the natural choice for the internal format. All other API formats (Chat
Completions, Claude) are converted TO ResponsesRequest.
"""
from collections.abc import AsyncGenerator
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[output_item],
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -1,25 +1,33 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.adapters.chat_completions import (
chat_request_to_internal,
collect_chat_response,
generate_chat_stream,
)
from exo.master.adapters.claude import (
claude_request_to_internal,
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
@@ -32,19 +40,24 @@ from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionTaskParams,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ErrorInfo,
ErrorResponse,
GenerationStats,
ModelList,
ModelListModel,
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -55,11 +68,19 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -67,25 +88,6 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
@@ -123,6 +125,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -153,6 +156,20 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -176,6 +193,8 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -236,7 +255,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -292,7 +310,6 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -383,52 +400,21 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
"""Yield `TokenChunk`s for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -443,69 +429,26 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
import time
from exo.shared.types.api import FinishReason
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -543,66 +486,154 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
"""OpenAI Chat Completions API - adapter."""
internal_params = chat_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(
request_params=payload,
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
try:
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
# Convert to internal format (BenchChatCompletionTaskParams extends ChatCompletionTaskParams)
internal_params = chat_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
payload.stream = False
internal_params.stream = False
command = ChatCompletion(request_params=payload)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Claude Messages API - adapter."""
internal_params = claude_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""OpenAI Responses API - native format (no conversion needed)."""
model_meta = await resolve_model_meta(payload.model)
# Update model to resolved model_id
request_params = payload.model_copy(update={"model": model_meta.model_id})
if not any(
instance.shard_assignments.model_id == request_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(request_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {request_params.model}",
)
command = ChatCompletion(request_params=request_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -655,14 +686,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -158,7 +158,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -201,7 +200,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:

View File

@@ -6,10 +6,9 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
@@ -20,11 +19,10 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,16 +52,19 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
cycles = topology.get_cycles() + [[node] for node in all_nodes]
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
@@ -93,15 +94,13 @@ def place_instance(
smallest_tb_cycles = [
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(
[node.node_id for node in cycle]
).is_thunderbolt_cycle([node.node_id for node in cycle])
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node.node_id) for node in cycle)
@@ -110,7 +109,11 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node.node_profile.memory.ram_available for node in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
@@ -119,16 +122,14 @@ def place_instance(
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
[node.node_id for node in selected_cycle]
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -136,18 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle[0].node_id,
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:

View File

@@ -1,4 +1,5 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[list[NodeId]],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[list[NodeWithProfile]]:
filtered_cycles: list[list[NodeWithProfile]] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(
[
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
for node in cycle
]
)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[list[NodeWithProfile]],
) -> list[list[NodeWithProfile]]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
selected_cycle: list[NodeInfo],
sharding: Sharding,
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for src, sink, connection in cycle_digraph.list_connections():
if not isinstance(connection, SocketConnection):
continue
if src == current_node and sink == next_node:
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
@@ -195,7 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
return hosts
def get_mlx_jaccl_devices_matrix(
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -204,7 +204,6 @@ def get_mlx_jaccl_devices_matrix(
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
"""
selected_cycle = list(cycle_digraph.list_nodes())
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
@@ -215,38 +214,71 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeWithProfile,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -255,7 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node: NodeWithProfile, other_node: NodeWithProfile, cycle_digraph: Topology
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -266,7 +298,7 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node.node_id, other_node.node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
@@ -292,7 +324,7 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeWithProfile],
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
@@ -329,7 +361,7 @@ def get_mlx_ring_hosts_by_node(
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -343,30 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
selected_cycle = list(cycle_digraph.list_nodes())
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, coordinator, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
return {
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}

View File

@@ -1,36 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
# TODO: this is a hack to get the port for the send_back_multiaddr
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -0,0 +1,283 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_internal,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to ResponsesRequest."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_internal(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
@@ -19,13 +18,16 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -134,13 +143,9 @@ async def test_master():
command=(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
request_params=ResponsesRequest(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
input="Hello, how are you?",
),
)
),
@@ -154,7 +159,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
@@ -191,11 +196,9 @@ async def test_master():
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
assert events[2].event.task.task_params == ResponsesRequest(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
input="Hello, how are you?",
)
await master.shutdown()

View File

@@ -0,0 +1,293 @@
"""Tests for OpenAI Responses API types.
ResponsesRequest is the canonical internal type used throughout the pipeline.
No conversion is needed for Responses API requests.
"""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestAsCanonicalType:
"""Tests for ResponsesRequest as the canonical internal type."""
def test_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello, how are you?",
)
assert request.model == "gpt-4o"
assert request.input == "Hello, how are you?"
assert request.instructions is None
def test_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 3
assert request.input[0].role == "user"
assert request.input[0].content == "Hello"
assert request.input[1].role == "assistant"
assert request.input[1].content == "Hi there!"
assert request.input[2].role == "user"
assert request.input[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
assert request.input == "Hello"
assert request.instructions == "You are a helpful assistant. Be concise."
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
assert request.max_output_tokens == 500
assert request.temperature == 0.8
assert request.top_p == 0.95
assert request.stream is True
def test_request_with_new_fields(self):
"""Test the additional fields added for internal use."""
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
top_k=40,
seed=42,
stop=["STOP", "END"],
tools=[{"type": "function", "function": {"name": "test"}}],
)
assert request.top_k == 40
assert request.seed == 42
assert request.stop == ["STOP", "END"]
assert request.tools == [{"type": "function", "function": {"name": "test"}}]
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "system"
assert request.input[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "developer"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -1,3 +1,5 @@
from typing import Callable
import pytest
from loguru import logger
@@ -5,20 +7,14 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_connection,
create_node_profile,
create_rdma_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import SocketConnection
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -76,36 +77,34 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(2))
topology.add_connection(node_id_c, node_id_a, create_connection(3))
topology.add_connection(node_id_c, node_id_b, create_connection(4))
topology.add_connection(node_id_a, node_id_c, create_connection(5))
topology.add_connection(node_id_b, node_id_a, create_connection(6))
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
@@ -131,11 +130,12 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -146,7 +146,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -157,11 +157,12 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -172,7 +173,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -183,11 +184,12 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -200,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -245,103 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
model_meta.storage_size = Memory.from_bytes(1000)
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology
topology.add_connection(node_id_a, node_id_b, create_connection(1))
topology.add_connection(node_id_b, node_id_a, create_connection(1))
topology.add_connection(node_id_b, node_id_c, create_connection(1))
topology.add_connection(node_id_c, node_id_b, create_connection(1))
topology.add_connection(node_id_c, node_id_d, create_connection(1))
topology.add_connection(node_id_d, node_id_c, create_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
logger.info(list(topology.list_connections()))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
# act
placements = place_instance(cic, topology, {}, profiles)
# Act
placements = place_instance(cic, topology, {})
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(node_id_c, node_id_d)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="192.168.1.100",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a, node_b, create_rdma_connection(3))
topology.add_connection(node_b, node_c, create_rdma_connection(4))
topology.add_connection(node_c, node_a, create_rdma_connection(5))
topology.add_connection(node_b, node_a, create_rdma_connection(3))
topology.add_connection(node_c, node_b, create_rdma_connection(4))
topology.add_connection(node_a, node_c, create_rdma_connection(5))
topology.add_connection(node_a, node_b, ethernet_conn)
topology.add_connection(node_b, node_c, ethernet_conn)
topology.add_connection(node_c, node_a, ethernet_conn)
topology.add_connection(node_a, node_c, ethernet_conn)
topology.add_connection(node_b, node_a, ethernet_conn)
topology.add_connection(node_c, node_b, ethernet_conn)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
@@ -351,7 +429,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
min_nodes=1,
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -359,10 +437,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
@@ -371,15 +449,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3

View File

@@ -1,48 +1,56 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
NodeWithProfile,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_connection, create_node_profile
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
topology = Topology()
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(1)
connection2 = create_connection(2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(node1_id, node2_id, connection1)
topology.add_connection(node2_id, node1_id, connection2)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -176,25 +179,19 @@ def test_get_shard_assignments(
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_c_id, create_connection(2))
topology.add_connection(node_c_id, node_a_id, create_connection(3))
topology.add_connection(node_b_id, node_a_id, create_connection(4))
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
@@ -204,11 +201,7 @@ def test_get_shard_assignments(
hidden_size=1000,
supports_tensor=True,
)
cycles = [
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
for cycle in topology.get_cycles()
]
cycles = topology.get_cycles()
selected_cycle = cycles[0]
# act
@@ -237,21 +230,28 @@ def test_get_shard_assignments(
)
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -259,47 +259,108 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip=("169.254.0.2"), port=1234),
Host(ip=("169.254.0.3"), port=1234),
Host(ip=("169.254.0.4"), port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_connection(node_a_id, node_b_id, create_connection(1))
topology.add_connection(node_b_id, node_a_id, create_connection(2))
topology.add_connection(node_a_id, node_c_id, create_connection(3))
topology.add_connection(node_c_id, node_b_id, create_connection(4))
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
conn_a_b = create_connection(1)
conn_b_a = create_connection(2)
conn_b_c = create_connection(3)
conn_c_b = create_connection(4)
conn_c_a = create_connection(5)
conn_a_c = create_connection(6)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
topology.add_connection(node_a_id, node_b_id, conn_a_b)
topology.add_connection(node_b_id, node_a_id, conn_b_a)
topology.add_connection(node_b_id, node_c_id, conn_b_c)
topology.add_connection(node_c_id, node_b_id, conn_c_b)
topology.add_connection(node_c_id, node_a_id, conn_c_a)
topology.add_connection(node_a_id, node_c_id, conn_a_c)
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_b)
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id, coordinator_port=5000, cycle_digraph=topology
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -328,11 +389,11 @@ def test_get_mlx_jaccl_coordinators():
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
"node_b should use the IP from conn_b_a"
)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
"node_c should use the IP from conn_c_a"
)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"

View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(conn for _, _, conn in topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(node_a, node_b, connection)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(node_a, node_b, connection)
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert {node for node in nodes} == {node_a, node_b}
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -11,8 +11,10 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacTBConnections,
MacTBIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# TODO: should be broken up into individual events instead of this monster
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
# TODO: makes me slightly sad
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacTBIdentifiers():
profile.tb_interfaces = info.idents
case MacTBConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
(
conn_map[tb_conn.sink_uuid][0],
RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.source, event.sink, event.edge)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.sink, event.source, event.edge)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"

View File

@@ -24,13 +24,16 @@ class _InterceptHandler(logging.Handler):
except ValueError:
level = record.levelno
return
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -70,63 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
# short_id="deepseek-v3.2",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# name="DeepSeek V3.2 (8-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# pretty_name="DeepSeek V3.2 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
# short_id="deepseek-v3.2-4bit",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# name="DeepSeek V3.2 (4-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# pretty_name="DeepSeek V3.2 (4-bit)",
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# deepseek r1
# "deepseek-r1-0528-4bit": ModelCard(
# short_id="deepseek-r1-0528-4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# . hidden_size=7168,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -508,23 +425,24 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
@@ -554,19 +472,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})
@@ -39,4 +43,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -11,16 +11,17 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(node_a, node_b, connection)
state.topology.add_connection(connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert state.topology.to_snapshot().nodes == restored_state.topology.to_snapshot().nodes
assert set(state.topology.to_snapshot().connections) == set(restored_state.topology.to_snapshot().connections)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,215 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
# the _graph can be used as a int -> NodeId map.
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.list_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source, sink, conn in snapshot.connections:
topology.add_connection(source, sink, conn)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(
self, node_id: NodeId
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
(self._graph[nid], conn)
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
)
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
source: NodeId,
sink: NodeId,
connection: SocketConnection | RDMAConnection,
connection: Connection,
) -> None:
if connection in self.get_all_connections_between(source, sink):
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, connection)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
return (
(
self._graph[src_id],
self._graph[sink_id],
connection,
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_tb_connections(
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for sink, conn in new_connections:
self.add_connection(source, sink, conn)
def remove_connection(
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
) -> None:
if source not in self._vertex_indices or sink not in self._vertex_indices:
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[source], self._vertex_indices[sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == edge:
self._graph.remove_edge_from_index(conn_idx)
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeId]]:
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_cycles_tb(self) -> list[list[NodeId]]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[list[NodeId]] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for source, sink, connection in self.list_connections():
if source in node_ids and sink in node_ids:
topology.add_connection(source, sink, connection)
for connection in self.list_connections():
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -146,10 +157,13 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):

View File

@@ -1,6 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,8 +20,11 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):

View File

@@ -0,0 +1,168 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: str
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -17,7 +17,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ResponsesRequest
class PlaceInstance(BaseCommand):

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import SocketConnection
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this naem
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):
class TopologyEdgeCreated(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
source: NodeId
sink: NodeId
edge: SocketConnection
edge: Connection
Event = (
@@ -119,8 +125,10 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| TopologyEdgeCreated

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -0,0 +1,190 @@
"""OpenAI Responses API types for request/response conversion.
ResponsesRequest serves as both:
1. The external API request type for /v1/responses
2. The canonical internal type used throughout the inference pipeline
All external API formats (Chat Completions, Claude) are converted to
ResponsesRequest at the API boundary.
"""
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API.
This is also used as the internal message format throughout the pipeline.
"""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API.
This is also the canonical internal task params format used throughout
the inference pipeline. All external API formats are converted to this
format at the API boundary.
Field mapping from other APIs:
- input: Replaces 'messages' from Chat Completions
- instructions: System message, extracted from messages or Claude's 'system'
- max_output_tokens: Replaces 'max_tokens' from Chat Completions
"""
model: str
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
stream: bool = False
# Tools support
tools: list[dict[str, Any]] | None = None
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import TBIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[TBIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
pass
throughput: float
latency: float
jitter: float

View File

@@ -2,8 +2,8 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, Id
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
@@ -50,7 +50,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ResponsesRequest
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)

View File

@@ -1,75 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class TBConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class TBIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class TBReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class TBConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces # doesn't need to be an assertion but im confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
class TBConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[TBConnectivityData] = []
@classmethod
async def gather(cls) -> list[TBConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType

View File

@@ -1,32 +1,37 @@
from enum import Enum
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
logger.warning("duh")
return True
# TODO
class LinkType(str, Enum):
Thunderbolt = "Thunderbolt"
Ethernet = "Ethernet"
WiFi = "WiFi"
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")

View File

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
pass
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
@@ -13,7 +13,8 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -1,232 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import TBConnection, TBConnectivity, TBIdentifier
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacTBIdentifiers(TaggedModel):
idents: Sequence[TBIdentifier]
class MacTBConnections(TaggedModel):
conns: Sequence[TBConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
# TODO
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacTBIdentifiers
| MacTBConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
if IS_DARWIN:
tg.start_soon(self._monitor_system_profiler)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await TBConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacTBIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacTBConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))

View File

@@ -1,81 +0,0 @@
import http.client
from collections.abc import Mapping
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology,
profiles: Mapping[NodeId, NodePerformanceProfile],
self_node_id: NodeId,
) -> dict[NodeId, set[str]]:
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node_id in topology.list_nodes():
if node_id not in profiles:
continue
for iface in profiles[node_id].network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node_id,
self_node_id,
reachable,
)
return reachable

View File

@@ -1,18 +0,0 @@
import pytest
import sys
from exo.shared.types.thunderbolt import (
TBConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import _gather_iface_map # pyright: ignore[reportPrivateUsage]
@pytest.mark.anyio
@pytest.mark.skipif(sys.platform != "darwin", reason="TB info can only be gathered on macos")
async def test_tb_parsing():
data = await TBConnectivity.gather()
ifaces = await _gather_iface_map()
assert ifaces
assert data
for datum in data:
datum.ident(ifaces)
datum.conn()

View File

@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -40,4 +40,6 @@ class TokenizerWrapper:
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
) -> str: ...

View File

@@ -10,18 +10,24 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.cache import (
_BaseCache, # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
class _LayerCallable(Protocol):
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(model, DeepseekV3Model) and hasattr(
inner_model_instance, "num_layers"
):
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
"""
inner_model_instance: nn.Module = _inner_model(model)
# Handle both model.layers and model.h cases
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
group=group,
)
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
shard_inplace,
sharding="all-to-sharded",
group=group,
)
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding="sharded-to-all",
sharding=_all_to_sharded, # type: ignore
group=group,
)
if isinstance(model, LlamaModel):
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
return -1, segments
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding=_sharded_to_all, # type: ignore
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, DeepseekV3Model):
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -231,7 +253,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Qwen3MoeModel):
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -239,6 +269,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -284,13 +323,38 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif isinstance(model, Qwen3MoeModel):
logger.info(
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.num_hidden_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
# Unfortunately, q_lora_rank can be None despite typing hints.
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
@@ -305,7 +369,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, DeepseekV3MLP):
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -339,6 +403,35 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(Qwen3MoeModel, model)
@@ -353,11 +446,13 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
if isinstance(
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # type: ignore
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -381,3 +476,50 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -8,13 +8,12 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -53,14 +52,9 @@ def warmup_inference(
warmup_prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=ChatCompletionTaskParams(
task_params=ResponsesRequest(
model="",
messages=[
ChatCompletionMessage(
role="user",
content=content,
)
],
input=content,
),
)
@@ -118,11 +112,11 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
task: ResponsesRequest,
is_bench: bool = False,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.info(f"task_params: {task}")
@@ -132,7 +126,7 @@ def mlx_generate(
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
task_params=task,
)
caches = make_kv_cache(model=model)
@@ -146,9 +140,20 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
max_tokens = task.max_tokens or MAX_TOKENS
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -161,11 +166,34 @@ def mlx_generate(
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
return_logprob=True,
return_top_logprobs=5,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -173,22 +201,41 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
# Extract logprobs if available
logprob: float | None = getattr(out, "logprob", None)
top_logprobs_raw: list[tuple[int, float]] | None = getattr(
out, "top_logprobs", None
)
top_logprobs: list[TopLogprobItem] | None = None
if top_logprobs_raw is not None:
top_logprobs = [
TopLogprobItem(
token=text if i == 0 else tokenizer.decode([tok_id]),
logprob=float(lp),
)
for i, (tok_id, lp) in enumerate(top_logprobs_raw)
]
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?

View File

@@ -1,12 +1,28 @@
import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -18,7 +34,7 @@ from exo.worker.engines.mlx.constants import (
try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
from mlx_lm.tokenizer_utils import load as load_tokenizer
import contextlib
import mlx.core as mx
@@ -26,10 +42,9 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
@@ -68,6 +83,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -131,22 +185,20 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(jaccl_devices[i][i] is None for i in range(len(jaccl_devices)))
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)
with open(coordination_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}")
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
@@ -176,7 +228,9 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -190,7 +244,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -204,6 +260,7 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -240,7 +297,15 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
@@ -254,63 +319,127 @@ def shard_and_load(
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
# TODO: Let's move away from this custom logic to mlx_lm.load()
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [163586]
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
elif "glm" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [151336, 151329, 151338]
else:
eos_token_ids = None
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
),
Some models require explicit EOS token configuration that isn't in their
tokenizer config. This function returns the known EOS token IDs for such models.
Args:
model_id: The HuggingFace model ID
Returns:
List of EOS token IDs, or None if the model uses standard tokenizer config
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
This is the core tokenizer loading logic, handling special cases for different
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
Args:
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
model_path: Local path where the model/tokenizer files are stored
Returns:
TokenizerWrapper instance configured for the model
"""
model_id_lower = model_id.lower()
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
sys.path.insert(0, str(model_path))
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)
assert isinstance(tokenizer, TokenizerWrapper)
return tokenizer
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
task_params: ResponsesRequest,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
"""Convert ResponsesRequest to a chat template prompt.
Converts the internal format (input + instructions) to a messages list
that can be processed by the tokenizer's chat template.
"""
formatted_messages: list[dict[str, Any]] = []
for message in messages:
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) == 0:
logger.warning("Received prompt with no content, skipping")
continue
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
continue
# Null values are not valid when applying templates in tokenizer
# Add system message (instructions) if present
if task_params.instructions:
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
{"role": "system", "content": task_params.instructions}
)
prompt: str = tokenizer.apply_chat_template( # type: ignore
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
# Convert input to messages
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
return prompt # type: ignore
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
# This keeps the final assistant message open without EOS tokens
# Note: explicitly set add_generation_prompt=False when using continue_final_message
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
prompt: str
if task_params.continue_from_prefix:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
continue_final_message=True,
add_generation_prompt=False,
tools=task_params.tools,
)
else:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
)
logger.info(prompt)
return prompt
class NullKVCache(KVCache):
@@ -341,6 +470,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -16,7 +16,8 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
@@ -24,6 +25,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -32,7 +34,7 @@ from exo.shared.types.tasks import (
Task,
TaskStatus,
)
from exo.shared.types.topology import SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -43,14 +45,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -84,7 +86,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -96,13 +98,37 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +142,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -146,6 +161,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -201,7 +217,9 @@ class Worker:
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
@@ -234,7 +252,8 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
@@ -251,24 +270,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
async def _nack_request(self, since_idx: int) -> None:
@@ -317,6 +336,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -346,7 +366,11 @@ class Worker:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
@@ -375,6 +399,7 @@ class Worker:
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -397,9 +422,7 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(
self.state.topology, self.state.node_profiles, self.node_id
)
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
@@ -407,31 +430,26 @@ class Worker:
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for nid, conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn, SocketConnection):
continue
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
if (
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
@@ -445,7 +463,9 @@ class Worker:
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:

View File

@@ -17,15 +17,23 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -1,9 +1,20 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -11,6 +22,8 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -39,6 +52,7 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -48,6 +62,33 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -109,7 +150,20 @@ def main(
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -126,7 +180,7 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -139,8 +193,6 @@ def main(
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -149,33 +201,48 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
assert model
assert tokenizer
_check_for_debug_prompts(task_params)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
@@ -207,22 +274,65 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(
prompt: str | ChatCompletionMessageText | list[ChatCompletionMessageText],
):
if isinstance(prompt, list):
if len(prompt) == 0:
logger.debug("Empty message prompt received in debug prompt")
return
prompt = prompt[0]
def _check_for_debug_prompts(task_params: ResponsesRequest) -> None:
"""Check for debug prompt triggers in the input.
if isinstance(prompt, ChatCompletionMessageText):
prompt = prompt.text
Extracts the first user input text and checks for debug triggers.
"""
prompt: str
if isinstance(task_params.input, str):
prompt = task_params.input
else:
# List of InputMessage - get first message content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")

View File

@@ -0,0 +1,386 @@
"""
Unit tests for tokenizer loading and functionality across all supported models.
This test downloads only tokenizer-related files (not full model weights) to verify
that tokenizers can be loaded and used correctly for encoding/decoding.
"""
import asyncio
import contextlib
from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
)
# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"vocab.txt",
"merges.txt",
"tiktoken.model",
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
]
def is_tokenizer_file(filename: str) -> bool:
"""Check if a file is needed for tokenizer functionality."""
for pattern in TOKENIZER_FILE_PATTERNS:
if "*" in pattern:
prefix = pattern.split("*")[0]
suffix = pattern.split("*")[1]
if filename.startswith(prefix) and filename.endswith(suffix):
return True
elif filename == pattern:
return True
return False
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
if not tokenizer_files:
pytest.skip(f"No tokenizer files found for {model_id}")
for file_entry in tokenizer_files:
with contextlib.suppress(FileNotFoundError):
await download_file_with_retry(
model_id, "main", file_entry.path, target_dir
)
return target_dir
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
@pytest.fixture(scope="module")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
# Verify required files exist
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
# Load tokenizer
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Test basic encoding
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Test decoding
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
assert test_text in decoded or decoded.strip() == test_text.strip(), (
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
)
# Test with longer text
long_text = "The quick brown fox jumps over the lazy dog. " * 10
long_encoded = tokenizer.encode(long_text)
assert len(long_encoded) > len(encoded), (
f"Longer text should produce more tokens for {model_id}"
)
# Test empty string
empty_encoded = tokenizer.encode("")
assert isinstance(empty_encoded, list), (
f"encode('') should return a list for {model_id}"
)
# Test special characters
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
special_encoded = tokenizer.encode(special_text)
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
# Test unicode
unicode_text = "Hello 世界 🌍"
unicode_encoded = tokenizer.encode(unicode_text)
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Check for vocabulary size
empty_vocab: dict[str, int] = {}
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
)
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
# Check for EOS token (either from tokenizer or explicitly provided)
has_eos = (
eos_token_ids is not None
or getattr(tokenizer, "eos_token_id", None) is not None
or getattr(tokenizer, "eos_token", None) is not None
)
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Get special tokens from the tokenizer
special_tokens: list[str] = []
# Try to get special tokens from various sources
if hasattr(tokenizer, "all_special_tokens"):
special_tokens.extend(tokenizer.all_special_tokens)
elif hasattr(tokenizer, "_tokenizer") and hasattr(
tokenizer._tokenizer,
"all_special_tokens",
):
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
# Also check for common special token attributes
for attr in [
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"sep_token",
"cls_token",
]:
token = getattr(tokenizer, attr, None)
if token is None and hasattr(tokenizer, "_tokenizer"):
token = getattr(tokenizer._tokenizer, attr, None)
if token and isinstance(token, str) and token not in special_tokens:
special_tokens.append(token)
# If we found special tokens, test encoding text that contains them
if special_tokens:
# Create text with special tokens interspersed
test_with_special = f"{special_tokens[0]}Hello world"
if len(special_tokens) > 1:
test_with_special += f"{special_tokens[1]}"
encoded = tokenizer.encode(test_with_special)
assert isinstance(encoded, list), (
f"encode() with special tokens should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with special tokens should return non-empty list for {model_id}"
)
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Verify we can decode
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
# Test with angle-bracket tokens (common format for special tokens)
# These should not raise errors even if they're not actual special tokens
angle_bracket_text = "<|test|>Hello<|end|>"
encoded = tokenizer.encode(angle_bracket_text)
assert isinstance(encoded, list), (
f"encode() with angle brackets should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with angle brackets should be non-empty for {model_id}"
)
# Specifically test Kimi tokenizer since it has special handling
@pytest.mark.asyncio
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
# Ensure the custom tokenizer file exists
if not (model_path / "tokenization_kimi.py").exists():
pytest.skip("tokenization_kimi.py not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode cycle
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "Kimi tokenizer should encode text"
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
# Test that the patched encode works (returns list of ints)
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
# Test encoding text with special tokens (like from chat templates)
# This is critical - the warmup inference uses prompts with special tokens
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
special_encoded = tokenizer.encode(special_token_text)
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
assert all(isinstance(t, int) for t in special_encoded), (
"Special token encoding should return integers"
)
# Verify EOS token is set
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (model_path / "tokenizer.json").exists() or (
model_path / "tokenizer_config.json"
).exists()
if not has_tokenizer:
pytest.skip("GLM tokenizer files not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "GLM tokenizer should encode text"
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
# Verify EOS tokens
assert eos_token_ids == [
151336,
151329,
151338,
], "GLM EOS tokens should be correct"

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}
result = plan_mod.plan(
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}

View File

@@ -1,7 +1,7 @@
from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
@@ -59,7 +59,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -107,7 +107,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -152,7 +152,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -201,7 +201,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
other_task_id = TaskId("other-task")

View File

@@ -0,0 +1,50 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -5,7 +5,6 @@ from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -14,9 +13,9 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
@@ -85,11 +84,11 @@ SHUTDOWN_TASK = Shutdown(
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
CHAT_PARAMS = ResponsesRequest(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
input="hello",
stream=True,
max_tokens=4,
max_output_tokens=4,
temperature=0.0,
)

View File

@@ -0,0 +1,6 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
__all__ = [
"start_polling_node_metrics",
"start_polling_memory_metrics",
]

View File

@@ -0,0 +1,108 @@
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
remote_node_id = NodeId(body)
break
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if remote_node_id is None:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
reachable,
client,
)
return reachable

View File

@@ -0,0 +1,114 @@
import asyncio
import os
import platform
from typing import Any, Callable, Coroutine
import anyio
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from .macmon import (
MacMonError,
Metrics,
)
from .macmon import (
get_metrics_async as macmon_get_metrics_async,
)
from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
)
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
if platform.system().lower() == "darwin":
return await macmon_get_metrics_async()
def get_memory_profile() -> MemoryPerformanceProfile:
"""Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
async def start_polling_memory_metrics(
callback: Callable[[MemoryPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 0.5,
) -> None:
"""Continuously poll and emit memory-only metrics at a faster cadence.
Parameters
- callback: coroutine called with a fresh MemoryPerformanceProfile each tick
- poll_interval_s: interval between polls
"""
while True:
try:
mem = get_memory_profile()
await callback(mem)
except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be joined but realistically they should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)

View File

@@ -13,10 +13,10 @@ from pydantic import BaseModel
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -34,8 +34,7 @@ from exo.shared.types.worker.instances import (
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.channels import MpReceiver, MpSender, mp_channel
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
@@ -50,14 +49,12 @@ class Tests(BaseModel):
kind: typing.Literal["init", "warmup", "inference"]
hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
logger.info(hn)
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
@@ -68,7 +65,6 @@ async def main():
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
@@ -80,32 +76,44 @@ async def main():
shutdown.set()
async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, ring_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
global hn
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if hn.startswith(test.devs[i][0]):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
@@ -113,6 +121,8 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
@@ -142,10 +152,10 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance):
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance)
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
@@ -159,16 +169,10 @@ async def execute_test(test: Tests, instance: Instance):
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
task_params=ResponsesRequest(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
@@ -192,17 +196,19 @@ async def execute_test(test: Tests, instance: Instance):
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, jaccl_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
global hn
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
for name, _ in test.devs:
if hn.startswith(name):
hn = name
break
return MlxJacclInstance(
instance_id=iid,
@@ -231,6 +237,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)

Some files were not shown because too many files have changed in this diff Show More